diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml
index 4a5b87b3e69ed..e4d1b91bab736 100644
--- a/.github/workflows/codeql.yml
+++ b/.github/workflows/codeql.yml
@@ -47,6 +47,14 @@ jobs:
         # Details on CodeQL's query packs refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
         queries: security-extended,security-and-quality
 
+    # Set up Java so the build uses a version that is new enough for the project
+    - if: ${{ matrix.language == 'java' }}
+      name: Setup Java 11
+      uses: actions/setup-java@v4
+      with:
+        java-version: '11'
+        distribution: 'microsoft'
+
     # Autobuild attempts to build any compiled languages  (C/C++, C#, or Java).
     # If this step fails, then you should remove it and run the build manually (see below)
     - if: ${{ matrix.language != 'cpp' }}
diff --git a/.github/workflows/gradle-wrapper-validation.yml b/.github/workflows/gradle-wrapper-validation.yml
index 03ea773a25130..bc2d8117930bc 100644
--- a/.github/workflows/gradle-wrapper-validation.yml
+++ b/.github/workflows/gradle-wrapper-validation.yml
@@ -11,4 +11,4 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v4
-      - uses: gradle/wrapper-validation-action@v1
+      - uses: gradle/wrapper-validation-action@v2
diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml
index 936ab0de899a2..a196226a4b836 100644
--- a/.github/workflows/labeler.yml
+++ b/.github/workflows/labeler.yml
@@ -3,6 +3,9 @@ on:
   issues:
     types: [opened, edited]
 
+permissions:
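+  # adding labels to issues requires write access to issues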
+  issues: write
+
 jobs:
   triage:
     runs-on: ubuntu-latest
diff --git a/.github/workflows/publish-csharp-apidocs.yml b/.github/workflows/publish-csharp-apidocs.yml
index c03399f4693be..5bc21595bf882 100644
--- a/.github/workflows/publish-csharp-apidocs.yml
+++ b/.github/workflows/publish-csharp-apidocs.yml
@@ -37,7 +37,7 @@ jobs:
         wget https://github.com/dotnet/docfx/releases/download/v${DOCFXVERSION}/docfx-linux-x64-v${DOCFXVERSION}.zip -O build/docfx/docfx.zip
         unzip build/docfx/docfx.zip -d build/docfx
     - name: Install NuGet
-      uses: nuget/setup-nuget@v1
+      uses: nuget/setup-nuget@v2
     - name: Build Documentation
       run: |
         build/docfx/docfx metadata csharp/ApiDocs/docfx.json
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
index c94e3fa5bcb8c..181f3fb17d332 100644
--- a/.github/workflows/stale.yml
+++ b/.github/workflows/stale.yml
@@ -13,7 +13,7 @@ jobs:
       issues: write
       pull-requests: write
     steps:
-      - uses: actions/stale@v9.0.0
+      - uses: actions/stale@v8
         with:
           # Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale
           exempt-issue-labels: contributions welcome, feature request, regression
diff --git a/.lintrunner.toml b/.lintrunner.toml
index 4e5d077b08ff4..be95e03479cf9 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -132,6 +132,7 @@ exclude_patterns = [
     'onnxruntime/core/flatbuffers/schema/*.fbs.h', # Generated code
     'onnxruntime/core/graph/contrib_ops/quantization_defs.cc',
     'onnxruntime/core/mlas/**', # Contains assembly code
+    'onnxruntime/core/mickey/cutlass_ext/**', # CUTLASS lib recommends NO automatic code formatting
     'winml/lib/Api.Image/shaders/**',  # Contains data chunks
 ]
 command = [
diff --git a/.pipelines/windowsai-steps.yml b/.pipelines/windowsai-steps.yml
index ff5179e6135c2..855573de753b0 100644
--- a/.pipelines/windowsai-steps.yml
+++ b/.pipelines/windowsai-steps.yml
@@ -80,11 +80,11 @@ jobs:
 
     # must call vsdevcmd first to add cmake to PATH
     - script: |
-        curl -O -L https://github.com/Kitware/CMake/releases/download/v3.26.3/cmake-3.26.3-windows-x86_64.zip
-        7z x cmake-3.26.3-windows-x86_64.zip
+        curl -O -L https://github.com/Kitware/CMake/releases/download/v3.28.3/cmake-3.28.3-windows-x86_64.zip
+        7z x cmake-3.28.3-windows-x86_64.zip
         set PYTHONHOME=$(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools
         set PYTHONPATH=$(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools
-        $(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools\python.exe "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --parallel --use_binskim_compliant_compile_flags --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" CMAKE_SYSTEM_VERSION=10.0.19041.0 --cmake_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\ctest.exe
+        $(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools\python.exe "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --parallel --use_binskim_compliant_compile_flags --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos --windows_sdk_version "10.0.22621.0" $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE"  --cmake_path $(Build.BinariesDirectory)\cmake-3.28.3-windows-x86_64\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake-3.28.3-windows-x86_64\bin\ctest.exe
       workingDirectory: '$(Build.BinariesDirectory)'
       displayName: 'Generate cmake config'
 
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 3e2b1f31dd6cf..98d23090fd474 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -21,5 +21,8 @@
     "cpplint.filters": [
         "-build/include_subdir",
         "-runtime/references"
-    ]
+    ],
+    "files.associations": {
+        "span": "cpp"
+    }
 }
diff --git a/CITATION.cff b/CITATION.cff
index 82bcac5a7b750..10b7290022aef 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -3,8 +3,7 @@ title: ONNX Runtime
 message: "Please use this information to cite ONNX Runtime in
   research or other publications."
 authors:
-  - affiliation: Microsoft Corporation
-    given-names: ONNX Runtime developers
+  - name: ONNX Runtime developers
 date-released: 2018-11-29
 url: "https://onnxruntime.ai"
 repository-code: "https://github.com/microsoft/onnxruntime"
diff --git a/cgmanifests/generate_cgmanifest.py b/cgmanifests/generate_cgmanifest.py
index 81181d3ccfb20..3cecbb0cc977f 100644
--- a/cgmanifests/generate_cgmanifest.py
+++ b/cgmanifests/generate_cgmanifest.py
@@ -115,8 +115,8 @@ def normalize_path_separators(path):
 submodule_lines = proc.stdout.splitlines()
 for submodule_line in submodule_lines:
     (absolute_path, url, commit) = submodule_line.split(" ")
-    git_deps[GitDep(commit, url)] = "git submodule at {}".format(
-        normalize_path_separators(os.path.relpath(absolute_path, REPO_DIR))
+    git_deps[GitDep(commit, url)] = (
+        f"git submodule at {normalize_path_separators(os.path.relpath(absolute_path, REPO_DIR))}"
     )
 
 with open(os.path.join(SCRIPT_DIR, "..", "cmake", "deps.txt")) as f:
diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json
index efd901787fdb7..3e13a567b1eaa 100644
--- a/cgmanifests/generated/cgmanifest.json
+++ b/cgmanifests/generated/cgmanifest.json
@@ -86,7 +86,7 @@
       "component": {
         "type": "git",
         "git": {
-          "commitHash": "6df40a2471737b27271bdd9b900ab5f3aec746c7",
+          "commitHash": "0100f6a5779831fa7a651e4b67ef389a8752bd9b",
           "repositoryUrl": "https://github.com/google/flatbuffers.git"
         },
         "comments": "flatbuffers"
@@ -116,7 +116,7 @@
       "component": {
         "type": "git",
         "git": {
-          "commitHash": "361e8d1cfe0c6c36d30b39f1b61302ece5507320",
+          "commitHash": "344117638c8ff7e239044fd0fa7085839fc03021",
           "repositoryUrl": "https://github.com/google/benchmark.git"
         },
         "comments": "google_benchmark"
@@ -206,7 +206,7 @@
       "component": {
         "type": "git",
         "git": {
-          "commitHash": "c11386eb632eec7c1c2aa323142f73519f946e2a",
+          "commitHash": "150e7527d5286ddd3a995c228dedf8d76a7a86bc",
           "repositoryUrl": "https://github.com/intel/neural-speed.git"
         },
         "comments": "neural_speed"
diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index 90fe8276ea9c7..ee1959bb357fe 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -88,7 +88,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF)
 option(onnxruntime_USE_SNPE "Build with SNPE support" OFF)
 option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF)
 option(onnxruntime_USE_DNNL "Build with DNNL support" OFF)
-option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" OFF)
+option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" ON)
 option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF)
 option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON)
 option(onnxruntime_BUILD_CSHARP "Build C# library" OFF)
@@ -117,9 +117,7 @@ option(onnxruntime_CROSS_COMPILING "Cross compiling onnx runtime" OFF)
 option(onnxruntime_GCOV_COVERAGE "Compile with options necessary to run code coverage" OFF)
 option(onnxruntime_DONT_VECTORIZE "Do not vectorize operations in Eigen" OFF)
 
-#It's preferred to turn it OFF when onnxruntime is dynamically linked to PROTOBUF. But Tensort always required the full version of protobuf.
-cmake_dependent_option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF "NOT onnxruntime_USE_TENSORRT" ON)
-option(tensorflow_C_PACKAGE_PATH "Path to tensorflow C package installation dir")
+option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF)
 option(onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS "Enable operator implemented in language other than cpp" OFF)
 option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump debug information about node inputs and outputs when executing the model." OFF)
 cmake_dependent_option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS_ENABLE_DUMP_TO_SQLDB "Build dump debug information about node inputs and outputs with support for sql database." OFF "onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS" OFF)
@@ -325,17 +323,29 @@ if (onnxruntime_USE_ROCM)
   endif()
 
   # replicate strategy used by pytorch to get ROCM_VERSION
-  # https://github.com/pytorch/pytorch/blob/8eb21488fdcdb8b0e6fa2e46179b5fa6c42e75af/cmake/public/LoadHIP.cmake#L153-L173
-  file(READ "${onnxruntime_ROCM_HOME}/.info/version-dev" ROCM_VERSION_DEV_RAW)
-  string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_DEV_MATCH ${ROCM_VERSION_DEV_RAW})
-  if (ROCM_VERSION_DEV_MATCH)
+  # https://github.com/pytorch/pytorch/blob/5c5b71b6eebae76d744261715231093e62f0d090/cmake/public/LoadHIP.cmake
+  # with modification
+  if (EXISTS "${onnxruntime_ROCM_HOME}/.info/version")
+    file(READ "${onnxruntime_ROCM_HOME}/.info/version" ROCM_VERSION_DEV_RAW)
+    string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_MATCH ${ROCM_VERSION_DEV_RAW})
+  elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm_version.h")
+    file(READ "${onnxruntime_ROCM_HOME}/include/rocm_version.h" ROCM_VERSION_H_RAW)
+    string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW})
+  elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h")
+    file(READ "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h" ROCM_VERSION_H_RAW)
+    string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW})
+  endif()
+
+  if (ROCM_VERSION_MATCH)
     set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1})
     set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2})
     set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3})
     set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}")
     math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}")
+  else()
+    message(FATAL_ERROR "Cannot determine ROCm version string")
   endif()
-  message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/.info/version-dev ****\n")
+  message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/.info/version ****\n")
   message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}")
   message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}")
   message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}")
@@ -716,6 +726,9 @@ if (onnxruntime_USE_CUDA)
     set(onnxruntime_USE_FLASH_ATTENTION OFF)
     set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
   endif()
+  if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4)
+    message( FATAL_ERROR "Build failed: the CUDA compiler version must be at least 11.4")
+  endif()
 else()
   set(onnxruntime_USE_FLASH_ATTENTION OFF)
   set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
@@ -736,8 +749,8 @@ if (onnxruntime_USE_CUDA)
       list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1)
       list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_MEMORY_EFFICIENT_ATTENTION=1)
     endif()
-
 endif()
+
 if (onnxruntime_USE_VITISAI)
     list(APPEND ORT_PROVIDER_FLAGS -DUSE_VITISAI=1)
     list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_VITISAI=1)
@@ -1193,7 +1206,7 @@ if (onnxruntime_USE_DNNL)
   add_compile_definitions(DNNL_OPENMP)
 endif()
 
-if (onnxruntime_USE_NEURAL_SPEED AND NOT onnxruntime_MINIMAL_BUILD)
+if (onnxruntime_USE_NEURAL_SPEED AND NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_USE_TVM)
   include(neural_speed)
   if (USE_NEURAL_SPEED)
     list(APPEND onnxruntime_EXTERNAL_LIBRARIES neural_speed::bestla)
@@ -1240,17 +1253,15 @@ if (onnxruntime_USE_TVM)
     $<TARGET_PROPERTY:tvm,INTERFACE_INCLUDE_DIRECTORIES>)
 
   set(onnxruntime_tvm_libs onnxruntime_providers_tvm)
-
-  # needs to link with stdc++fs in Linux
-  if (UNIX)
-    if (NOT APPLE)
-      set(FS_STDLIB stdc++fs)
-    endif()
-  endif()
-  list(APPEND onnxruntime_EXTERNAL_LIBRARIES tvm ${FS_STDLIB})
+  list(APPEND onnxruntime_EXTERNAL_LIBRARIES tvm)
   list(APPEND onnxruntime_EXTERNAL_DEPENDENCIES tvm)
 endif()
 
+# std::filesystem requires an explicit link against stdc++fs with GCC < 9 on Linux
+if (UNIX AND "${CMAKE_C_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_LESS 9)
+  set(FS_STDLIB stdc++fs)
+endif()
+list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${FS_STDLIB})
 
 # onnxruntime-extensions
 if (onnxruntime_USE_EXTENSIONS)
@@ -1260,11 +1271,7 @@ endif()
 #Dependencies end. In the next we'll enable "treat warning as error"
 
 #Adjust warning flags
-if (onnxruntime_USE_CUDA)
-  set_msvc_c_cpp_compiler_warning_level(3)
-else()
-  set_msvc_c_cpp_compiler_warning_level(4)
-endif()
+set_msvc_c_cpp_compiler_warning_level(4)
 
 set(onnxruntime_DELAYLOAD_FLAGS "")
 
@@ -1283,34 +1290,6 @@ if (onnxruntime_USE_OPENVINO)
 
   add_definitions(-DUSE_OPENVINO=1)
 
-  if (EXISTS "$ENV{INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/version.txt")
-    file(READ $ENV{INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/version.txt VER)
-  endif()
-
-  if (NOT DEFINED ENV{INTEL_OPENVINO_DIR})
-    message(FATAL_ERROR "[Couldn't locate OpenVINO] OpenVINO may not have been initialized")
-  endif()
-
-  # Check OpenVINO version for support
-  if ($ENV{INTEL_OPENVINO_DIR} MATCHES "2023.0")
-    set(OPENVINO_VERSION "2023.0")
-    add_definitions(-DOPENVINO_2023_0=1)
-  elseif ($ENV{INTEL_OPENVINO_DIR} MATCHES "2023.1")
-    set(OPENVINO_VERSION "2023.1")
-    add_definitions(-DOPENVINO_2023_1=1)
-  elseif ($ENV{INTEL_OPENVINO_DIR} MATCHES "2023.2")
-    set(OPENVINO_VERSION "2023.2")
-    add_definitions(-DOPENVINO_2023_2=1)
-  elseif ($ENV{INTEL_OPENVINO_DIR} MATCHES "2023.3")
-    set(OPENVINO_VERSION "2023.3")
-    add_definitions(-DOPENVINO_2023_3=1)
-  elseif ($ENV{INTEL_OPENVINO_DIR} MATCHES "openvino")
-    set(OPENVINO_VERSION "2023.3")
-    add_definitions(-DOPENVINO_2023_3=1)
-  else()
-    message(FATAL_ERROR "Unsupported OpenVINO version: ${INTEL_OPENVINO_DIR}")
-  endif()
-
   if (onnxruntime_USE_OPENVINO_GPU_FP32)
     add_definitions(-DOPENVINO_CONFIG_GPU_FP32=1)
   endif()
@@ -1327,6 +1306,10 @@ if (onnxruntime_USE_OPENVINO)
     add_definitions(-DOPENVINO_CONFIG_CPU_FP16=1)
   endif()
 
+  if (onnxruntime_USE_OPENVINO_NPU)
+    add_definitions(-DOPENVINO_CONFIG_NPU=1)
+  endif()
+
   if (onnxruntime_USE_OPENVINO_GPU_FP32_NP)
     add_definitions(-DOPENVINO_CONFIG_GPU_FP32=1)
     add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1)
@@ -1347,6 +1330,11 @@ if (onnxruntime_USE_OPENVINO)
     add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1)
   endif()
 
+  if (onnxruntime_USE_OPENVINO_NPU_NP)
+    add_definitions(-DOPENVINO_CONFIG_NPU=1)
+    add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1)
+  endif()
+
   if (onnxruntime_USE_OPENVINO_HETERO)
     add_definitions(-DOPENVINO_CONFIG_HETERO=1)
     add_definitions(-DDEVICE_NAME="${onnxruntime_USE_OPENVINO_DEVICE}")
@@ -1401,6 +1389,10 @@ endif()
 if (onnxruntime_USE_CUDA)
   set(CMAKE_CUDA_RUNTIME_LIBRARY Shared)
   set(CMAKE_CUDA_STANDARD 17)
+  if(onnxruntime_CUDA_HOME)
+    file(TO_CMAKE_PATH ${onnxruntime_CUDA_HOME} CUDAToolkit_ROOT)
+  endif()
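+  # find_package(CUDAToolkit) honors CUDAToolkit_ROOT when onnxruntime_CUDA_HOME is provided.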
+  find_package(CUDAToolkit REQUIRED)
   if(onnxruntime_CUDNN_HOME)
     file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME)
   endif()
@@ -1601,7 +1593,7 @@ if (UNIX AND onnxruntime_USE_NCCL)
 else()
   set(onnxruntime_USE_NCCL OFF)
   set(onnxruntime_USE_MPI OFF)
-message( WARNING "MPI and NCCL disabled on Win build." )
+  message( WARNING "MPI and NCCL are disabled because the build is on Windows or USE_NCCL is set to OFF." )
 endif()
 
 if (onnxruntime_USE_MPI)
@@ -1730,14 +1722,12 @@ if(onnxruntime_BUILD_KERNEL_EXPLORER)
 endif()
 
 # When GDK_PLATFORM is set then WINAPI_FAMILY is defined in gdk_toolchain.cmake (along with other relevant flags/definitions).
-if (WIN32 AND NOT GDK_PLATFORM)
+if (WIN32 AND NOT GDK_PLATFORM AND NOT CMAKE_CROSSCOMPILING)
   if (NOT CMAKE_CXX_STANDARD_LIBRARIES MATCHES kernel32.lib)
     # On onecore, link to the onecore build of the MSVC runtime
     get_filename_component(msvc_path "${CMAKE_C_COMPILER}/../../../.." ABSOLUTE)
     link_directories(BEFORE "${msvc_path}/lib/onecore/${onnxruntime_target_platform}")
-    # The .lib files in the MSVC runtime have a DEFAULITLIB entry for onecore.lib, which in turn links to reverse forwarders.
-    # We ignore that entry and use onecore_apiset.lib instead, since system components must not rely on reverse forwarders.
-    add_link_options("/NODEFAULTLIB:onecore.lib")
+    # The .lib files in the MSVC runtime have a DEFAULTLIB entry for onecore.lib, but it should not cause any conflict with onecoreuap.lib
   endif()
 endif()
 
diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake
index a56864ebf4644..9a3bc3302cc2b 100644
--- a/cmake/adjust_global_compile_flags.cmake
+++ b/cmake/adjust_global_compile_flags.cmake
@@ -8,6 +8,15 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Android")
   string(APPEND CMAKE_ASM_FLAGS_RELEASE " -O3")
 endif()
 
+# Suggested by https://gitlab.kitware.com/cmake/cmake/-/issues/20132
+# MacCatalyst is not well supported in CMake
+# The error that can emerge without this flag can look like:
+# "clang : error : overriding '-mmacosx-version-min=11.0' option with '-target x86_64-apple-ios14.0-macabi' [-Werror,-Woverriding-t-option]"
+if (PLATFORM_NAME STREQUAL "macabi")
+  add_compile_options(-Wno-overriding-t-option)
+  add_link_options(-Wno-overriding-t-option)
+endif()
+
 # Enable space optimization for gcc/clang
 # Cannot use "-ffunction-sections -fdata-sections" if we enable bitcode (iOS)
 if (NOT MSVC AND NOT onnxruntime_ENABLE_BITCODE)
@@ -92,13 +101,8 @@ if (onnxruntime_MINIMAL_BUILD)
   endif()
 endif()
 
-# Enable stream for all the non-minimal build, except for DML. There's currently a bug
-# in the allocation planner when reusing buffers and more than one streams are used that
-# make it possible (although rarely) to reach a reference count of 0 for a buffer that is
-# still being used. Since DML doesn't benefit from multiple streams, disabling it is the
-# safest option for now.
-# https://github.com/microsoft/onnxruntime/issues/19480
-if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_USE_DML)
+# Enable stream for all the non-minimal build
+if (NOT onnxruntime_MINIMAL_BUILD)
   add_compile_definitions(ORT_ENABLE_STREAM)
 endif()
 
@@ -210,7 +214,7 @@ endif()
 
 
 macro(check_nvcc_compiler_flag _FLAG _RESULT)
-    execute_process(COMMAND ${onnxruntime_CUDA_HOME}/bin/nvcc "${_FLAG}" RESULT_VARIABLE NVCC_OUT ERROR_VARIABLE NVCC_ERROR)
+    execute_process(COMMAND ${CUDAToolkit_BIN_DIR}/nvcc "${_FLAG}" RESULT_VARIABLE NVCC_OUT ERROR_VARIABLE NVCC_ERROR)
     message("NVCC_ERROR = ${NVCC_ERROR}")
     message("NVCC_OUT = ${NVCC_OUT}")
     if ("${NVCC_OUT}" MATCHES "0")
diff --git a/cmake/deps.txt b/cmake/deps.txt
index cb431f8c77397..22ad9338ea59a 100644
--- a/cmake/deps.txt
+++ b/cmake/deps.txt
@@ -23,10 +23,10 @@ dlpack;https://github.com/dmlc/dlpack/archive/refs/tags/v0.6.zip;4d565dd2e5b3132
 # Until the 3.4.1 release this is the best option we have.
 # Issue link: https://gitlab.com/libeigen/eigen/-/issues/2744
 eigen;https://gitlab.com/libeigen/eigen/-/archive/e7248b26a1ed53fa030c5c459f7ea095dfd276ac/eigen-e7248b26a1ed53fa030c5c459f7ea095dfd276ac.zip;be8be39fdbc6e60e94fa7870b280707069b5b81a
-flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v1.12.0.zip;ba0a75fd12dbef8f6557a74e611b7a3d0c5fe7bf
+flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v23.5.26.zip;59422c3b5e573dd192fead2834d25951f1c1670c
 fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip;b985f6985a05a1c03ff1bb71190f66d8f98a1494
 fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1
-google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.7.0.zip;e97c368b176e8614e3f1bf13dd9abcf6a7ad9908
+google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.8.3.zip;bf9870756ee3f8d2d3b346b24ee3600a41c74d3d
 google_nsync;https://github.com/google/nsync/archive/refs/tags/1.26.0.zip;5e7c00ef6bf5b787386fc040067903ec774e2752
 googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034
 googlexnnpack;https://github.com/google/XNNPACK/archive/0da379fc4808f9601faef392352018c741c0f297.zip;663883491e380b628e0a5b162b5f2658032fae73
@@ -35,10 +35,10 @@ microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36
 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
 mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41
 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063
-neural_speed;https://github.com/intel/neural-speed/archive/refs/tags/bestlav0.1.1.zip;65b0f7a0d04f72f0d5a8d48af70f0366f2ab3939
+neural_speed;https://github.com/intel/neural-speed/archive/refs/tags/v0.3.zip;5ec64e3071edc7347ebd8a81679cf06e2bb9b851
 onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.15.0.zip;54c3f960a0541c5d8d3e60c2933e11f5d3688a11
-#use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459)
-onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/a43ce67187bab219520fd80f21af8bbd4354bc8c.zip;572535aefef477050f86744dfab1fef840198035
+#use the commit of Final DDS removal. DDS output is now supported by ORT TRT.
+onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/bacfaaa951653cd4e72efe727a543567cb38f7de.zip;26434329612e804164ab7baa6ae629ada56c1b26
 protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa
 protoc_win64;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip;b4521f7ada5b260380f94c4bd7f1b7684c76969a
 protoc_win32;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win32.zip;3688010318192c46ce73213cdfb6b3e5656da874
diff --git a/cmake/deps_update_and_upload.py b/cmake/deps_update_and_upload.py
index d357284d91225..63df3f6f03869 100644
--- a/cmake/deps_update_and_upload.py
+++ b/cmake/deps_update_and_upload.py
@@ -1,56 +1,109 @@
-# in case deps.txt is updated, run this file to update and upload the dependencies so that CI can use them.
-# Before running the script, increase the version number found at:
+# If deps.txt is updated, run this file to update and upload the dependencies so that CI can use them.
+#
+# Before running the script, find the latest version number at:
 # https://aiinfra.visualstudio.com/Lotus/_artifacts/feed/Lotus/UPack/onnxruntime_build_dependencies/versions
+# Increment it to obtain a new version number to use.
+#
 # Run without --do-upload once to verify downloading. Use --do-upload when you are ready to publish.
-# python cmake/deps_update_and_upload.py --root-path C:/temp/onnxruntime_deps --version 1.0.82 --do-upload
-# update version number in tools\ci_build\github\azure-pipelines\templates\download-deps.yml
+# E.g.:
+#   python cmake/deps_update_and_upload.py --root-path C:/temp/onnxruntime_deps --version 1.0.82
+#   # check contents of C:/temp/onnxruntime_deps
+#   python cmake/deps_update_and_upload.py --root-path C:/temp/onnxruntime_deps --version 1.0.82 --no-download --do-upload
+#
+# Next, update the version number in tools/ci_build/github/azure-pipelines/templates/download-deps.yml.
+
+import argparse
+import contextlib
+import pathlib
 import re
 import subprocess
-import os
-import argparse
 import tempfile
 
+script_dir = pathlib.Path(__file__).parent
+
 parser = argparse.ArgumentParser(description="Update dependencies and publish to Azure Artifacts")
 parser.add_argument(
-    "--root-path", type=str, default=tempfile.gettempdir(), help="Target root path for downloaded files"
+    "--root-path",
+    type=pathlib.Path,
+    help="Target root path for downloaded files. If not provided, a temporary directory is used.",
+)
+parser.add_argument(
+    "--version",
+    type=str,
+    help="Package version to publish",
+)
+parser.add_argument(
+    "--do-upload",
+    action="store_true",
+    dest="upload",
+    help="Upload the package to Azure Artifacts",
+)
+parser.add_argument(
+    "--no-download",
+    action="store_false",
+    dest="download",
+    help="Skip downloading the dependency files. "
+    "Use with '--do-upload' and '--root-path' to upload the package from existing dependency files.",
 )
-parser.add_argument("--version", type=str, default="1.0.82", help="Package version to publish")
-parser.add_argument("--do-upload", action="store_true", help="Upload the package to Azure Artifacts")
 args = parser.parse_args()
 
-with open("cmake/deps.txt") as file:
+if args.upload:
+    assert args.version is not None, "'--version' must be specified if uploading."
+
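+# Downloading and uploading in separate runs requires the dependency files to persist between runs,
+# so a caller-provided --root-path is required instead of a temporary directory.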
+if args.upload != args.download:
+    assert args.root_path is not None, "'--root-path' must be specified if only downloading or uploading."
+
+deps_path = script_dir / "deps.txt"
+with open(deps_path) as file:
     text = file.read()
 
 lines = [line for line in text.split("\n") if not line.startswith("#") and ";" in line]
 
-root_path = args.root_path
-
-for line in lines:
-    url = re.sub("^[^;]+?;https://([^;]+?);.*", r"https://\1", line)
-    filename = re.sub("^[^;]+?;https://([^;]+?);.*", r"\1", line)
-    full_path = os.path.join(root_path, filename)
-    subprocess.run(["curl", "-sSL", "--create-dirs", "-o", full_path, url])  # noqa: PLW1510
-
-package_name = "onnxruntime_build_dependencies"
-version = args.version
-
-# Check if the user is logged in to Azure
-result = subprocess.run("az account show", shell=True, capture_output=True, text=True)  # noqa: PLW1510
-if "No subscriptions found" in result.stderr:
-    # Prompt the user to log in to Azure
-    print("You are not logged in to Azure. Please log in to continue.")
-    subprocess.run("az login", shell=True)  # noqa: PLW1510
-
-# Publish the package to Azure Artifacts if --no-upload is not specified
-
-cmd = f'az artifacts universal publish --organization https://dev.azure.com/onnxruntime --feed onnxruntime --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}'
-if args.do_upload:
-    subprocess.run(cmd, shell=True)  # noqa: PLW1510
-else:
-    print("would have run: " + cmd)
-
-cmd = f'az artifacts universal publish --organization https://dev.azure.com/aiinfra --feed Lotus --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}'
-if args.do_upload:
-    subprocess.run(cmd, shell=True)  # noqa: PLW1510
-else:
-    print("would have run: " + cmd)
+with contextlib.ExitStack() as context_stack:
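+    # The ExitStack cleans up the temporary directory, if one is created below, when this block exits.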
+    if args.root_path is not None:
+        root_path = args.root_path.resolve()
+        root_path.mkdir(parents=True, exist_ok=True)
+    else:
+        temp_dir_name = context_stack.enter_context(tempfile.TemporaryDirectory())
+        root_path = pathlib.Path(temp_dir_name)
+
+    if args.download:
+        print(f"Downloading dependencies to directory: {root_path}")
+
+        dep_pattern = re.compile(r"^[^;]+;https://([^;]+);.*$")
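+        # deps.txt entries look like "<name>;<url>;<sha1>"; capture the URL portion after "https://".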
+
+        for line in lines:
+            match = dep_pattern.fullmatch(line)
+            if match is None:
+                continue
+
+            dep_path = match[1]
+            url = f"https://{dep_path}"
+            full_path = root_path / dep_path
+
+            subprocess.run(["curl", "-sSL", "--create-dirs", "-o", str(full_path), url], check=True)
+
+    package_name = "onnxruntime_build_dependencies"
+    version = args.version if args.version is not None else "VERSION_PLACEHOLDER"
+
+    if args.upload:
+        # Check if the user is logged in to Azure
+        result = subprocess.run("az account show", shell=True, capture_output=True, text=True, check=False)
+        if "No subscriptions found" in result.stderr:
+            # Prompt the user to log in to Azure
+            print("You are not logged in to Azure. Please log in to continue.")
+            subprocess.run("az login", shell=True, check=True)
+
+    # Publish the package to Azure Artifacts if --do-upload is specified
+
+    cmd = f'az artifacts universal publish --organization https://dev.azure.com/onnxruntime --feed onnxruntime --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}'
+    if args.upload:
+        subprocess.run(cmd, shell=True, check=True)
+    else:
+        print("would have run: " + cmd)
+
+    cmd = f'az artifacts universal publish --organization https://dev.azure.com/aiinfra --feed Lotus --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}'
+    if args.upload:
+        subprocess.run(cmd, shell=True, check=True)
+    else:
+        print("would have run: " + cmd)
diff --git a/cmake/external/neural_speed.cmake b/cmake/external/neural_speed.cmake
index ed711351403a7..3fe9c660f89d6 100644
--- a/cmake/external/neural_speed.cmake
+++ b/cmake/external/neural_speed.cmake
@@ -9,6 +9,7 @@ if(USE_NEURAL_SPEED)
       neural_speed
       URL ${DEP_URL_neural_speed}
       URL_HASH SHA1=${DEP_SHA1_neural_speed}
+      PATCH_COMMAND ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/neural_speed/150e7527d5286ddd3a995c228dedf8d76a7a86bc.patch
   )
   set(BTLA_USE_OPENMP OFF)
   onnxruntime_fetchcontent_makeavailable(neural_speed)
diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake
index 22d12b128dc1f..8839dbc8fda4f 100644
--- a/cmake/external/onnxruntime_external_deps.cmake
+++ b/cmake/external/onnxruntime_external_deps.cmake
@@ -14,6 +14,16 @@ foreach(ONNXRUNTIME_DEP IN LISTS ONNXRUNTIME_DEPS_LIST)
     set(DEP_URL_${ONNXRUNTIME_DEP_NAME} ${ONNXRUNTIME_DEP_URL})
     # The third column is SHA1 hash value
     set(DEP_SHA1_${ONNXRUNTIME_DEP_NAME} ${ONNXRUNTIME_DEP})
+
+    if(ONNXRUNTIME_DEP_URL MATCHES "^https://")
+      # Search a local mirror folder
+      string(REGEX REPLACE "^https://" "${REPO_ROOT}/mirror/" LOCAL_URL "${ONNXRUNTIME_DEP_URL}")
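+      # e.g. https://github.com/org/pkg.zip maps to ${REPO_ROOT}/mirror/github.com/org/pkg.zip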
+
+      if(EXISTS "${LOCAL_URL}")
+        cmake_path(ABSOLUTE_PATH LOCAL_URL)
+        set(DEP_URL_${ONNXRUNTIME_DEP_NAME} "${LOCAL_URL}")
+      endif()
+    endif()
   endif()
 endforeach()
 
@@ -37,8 +47,13 @@ if (onnxruntime_BUILD_UNIT_TESTS)
     set(gtest_disable_pthreads ON)
   endif()
   set(INSTALL_GTEST OFF CACHE BOOL "" FORCE)
-  if (CMAKE_SYSTEM_NAME STREQUAL "iOS")
-    # Needs to update onnxruntime/test/xctest/xcgtest.mm
+  if (IOS OR ANDROID)
+    # On mobile platforms the absl flags class strips the flag names (presumably to reduce binary size), which breaks
+    # passing any args to gtest executables, such as using --gtest_filter to debug a specific test.
+    # Processing of compile definitions:
+    # https://github.com/abseil/abseil-cpp/blob/8dc90ff07402cd027daec520bb77f46e51855889/absl/flags/config.h#L21
+    # If set, this code throws away the flag name and does nothing on registration, which results in no flags being known:
+    # https://github.com/abseil/abseil-cpp/blob/8dc90ff07402cd027daec520bb77f46e51855889/absl/flags/flag.h#L205-L217
     set(GTEST_HAS_ABSL OFF CACHE BOOL "" FORCE)
   else()
     set(GTEST_HAS_ABSL ON CACHE BOOL "" FORCE)
@@ -104,7 +119,7 @@ FetchContent_Declare(
     URL ${DEP_URL_flatbuffers}
     URL_HASH SHA1=${DEP_SHA1_flatbuffers}
     PATCH_COMMAND ${ONNXRUNTIME_FLATBUFFERS_PATCH_COMMAND}
-    FIND_PACKAGE_ARGS 1.12.0...<2.0.0 NAMES Flatbuffers
+    FIND_PACKAGE_ARGS 23.5.9 NAMES Flatbuffers
 )
 
 # Download a protoc binary from Internet if needed
@@ -256,14 +271,7 @@ if (onnxruntime_ENABLE_CPUINFO)
       set(CPUINFO_SUPPORTED TRUE)
     endif()
     if (WIN32)
-      # Exclude Windows ARM build and Windows Store
-      if (${onnxruntime_target_platform} MATCHES "^(ARM.*|arm.*)$" )
-        message(WARNING "Cpuinfo not included for compilation problems with Windows ARM.")
-        set(CPUINFO_SUPPORTED FALSE)
-      elseif (WIN32 AND NOT CMAKE_CXX_STANDARD_LIBRARIES MATCHES kernel32.lib)
-        message(WARNING "Cpuinfo not included non-Desktop builds")
-        set(CPUINFO_SUPPORTED FALSE)
-      endif()
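+      # cpuinfo is now enabled for all Windows targets, including ARM and non-desktop builds.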
+      set(CPUINFO_SUPPORTED TRUE)
     elseif (NOT ${onnxruntime_target_platform} MATCHES "^(i[3-6]86|AMD64|x86(_64)?|armv[5-8].*|aarch64|arm64)$")
       message(WARNING
         "Target processor architecture \"${onnxruntime_target_platform}\" is not supported in cpuinfo. "
@@ -307,13 +315,23 @@ if (CPUINFO_SUPPORTED)
   set(CPUINFO_BUILD_UNIT_TESTS OFF CACHE INTERNAL "")
   set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE INTERNAL "")
   set(CPUINFO_BUILD_BENCHMARKS OFF CACHE INTERNAL "")
-
-  FetchContent_Declare(
-    pytorch_cpuinfo
-    URL ${DEP_URL_pytorch_cpuinfo}
-    URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo}
-    FIND_PACKAGE_ARGS NAMES cpuinfo
-  )
+  if(onnxruntime_target_platform STREQUAL "ARM64EC")
+      message("Applying a patch for Windows ARM64EC in cpuinfo")
+      FetchContent_Declare(
+        pytorch_cpuinfo
+        URL ${DEP_URL_pytorch_cpuinfo}
+        URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo}
+        PATCH_COMMAND ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/9bb12d342fd9479679d505d93a478a6f9cd50a47.patch
+        FIND_PACKAGE_ARGS NAMES cpuinfo
+      )
+  else()
+      FetchContent_Declare(
+        pytorch_cpuinfo
+        URL ${DEP_URL_pytorch_cpuinfo}
+        URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo}
+        FIND_PACKAGE_ARGS NAMES cpuinfo
+      )
+  endif()
   set(ONNXRUNTIME_CPUINFO_PROJ pytorch_cpuinfo)
 endif()
 
@@ -556,16 +574,15 @@ message("Finished fetching external dependencies")
 set(onnxruntime_LINK_DIRS )
 if (onnxruntime_USE_CUDA)
       #TODO: combine onnxruntime_CUDNN_HOME and onnxruntime_CUDA_HOME, assume they are the same
+      find_package(CUDAToolkit REQUIRED)
       if (WIN32)
         if(onnxruntime_CUDNN_HOME)
           list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib/x64)
         endif()
-        list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDA_HOME}/x64/lib64)
       else()
         if(onnxruntime_CUDNN_HOME)
           list(APPEND onnxruntime_LINK_DIRS  ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib64)
         endif()
-        list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDA_HOME}/lib64)
       endif()
 endif()
 
diff --git a/cmake/maccatalyst_prepare_objects_for_prelink.py b/cmake/maccatalyst_prepare_objects_for_prelink.py
new file mode 100644
index 0000000000000..34664b4e05237
--- /dev/null
+++ b/cmake/maccatalyst_prepare_objects_for_prelink.py
@@ -0,0 +1,72 @@
+#!/usr/bin/env python3
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import os
+import shutil
+import sys
+
+
+# Note: This script is mainly a sanity check that the .o files contained in each onnxruntime .a library match the .o
+# files in the corresponding source dir, handling source files that share a name under different subdirectories of a
+# library. (Only applicable when doing a Mac Catalyst build.)
+def main():
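+    # Arguments: the cmake build dir containing the library's .o files, the destination dir to copy them to,
+    # and a text file listing the .o entries in the built .a library.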
+    source_dir = sys.argv[1]
+    dest_dir = sys.argv[2]
+    files_from_static_lib = sys.argv[3]
+    files_from_source_dir = []
+    for subdir, _, files in os.walk(source_dir):
+        for file_name in files:
+            if file_name.endswith(".o"):
+                files_from_source_dir.append(file_name.strip())
+                dest_name_without_extension, _ = os.path.splitext(file_name)
+                counter = 0
+
+                dest_file = f"{dest_name_without_extension}.o"
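+                # If an .o with this name was already copied from another subdirectory, append a numeric
+                # suffix instead of overwriting it.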
+                while os.path.exists(os.path.join(dest_dir, dest_file)):
+                    print("Duplicate file name from source: " + os.path.join(source_dir, subdir, file_name))
+                    counter += 1
+                    dest_file = f"{dest_name_without_extension}_{counter}.o"
+                    print("Renamed file name in destination: " + os.path.join(dest_dir, dest_file))
+
+                destination_path = os.path.join(dest_dir, dest_file)
+                source_file = os.path.join(source_dir, subdir, file_name)
+                shutil.copy(source_file, destination_path)
+
+    # Sanity check to ensure the number of .o object from the original cmake source directory matches with the number
+    # of .o files extracted from each .a onnxruntime library
+    file_lists_from_static_lib = []
+    with open(files_from_static_lib) as file:
+        filenames = file.readlines()
+    for filename in filenames:
+        file_lists_from_static_lib.append(filename.strip())
+
+    sorted_list1 = sorted(file_lists_from_static_lib)
+    sorted_list2 = sorted(files_from_source_dir)
+
+    if len(sorted_list1) != len(sorted_list2):
+        print(
+            "Caught a mismatch in the number of .o object files from the original cmake source directory: ",
+            len(sorted_list1),
+            "the number of .o files extracted from the static onnxruntime lib: ",
+            len(sorted_list2),
+            "for: ",
+            os.path.basename(source_dir),
+        )
+
+    if sorted_list1 == sorted_list2:
+        print(
+            "Sanity check passed: object files from original source directory matches with files extracted "
+            "from static library for: ",
+            os.path.basename(source_dir),
+        )
+    else:
+        print(
+            "Error: Mismatch between object files from original source directory "
+            "and the .o files extracted from static library for: ",
+            os.path.basename(source_dir),
+        )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake
index 2ead13e554197..e15c8a046dc20 100644
--- a/cmake/onnxruntime.cmake
+++ b/cmake/onnxruntime.cmake
@@ -281,7 +281,13 @@ endif()
 
 # Assemble the Apple static framework (iOS and macOS)
 if(onnxruntime_BUILD_APPLE_FRAMEWORK)
-  set(STATIC_FRAMEWORK_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}-${CMAKE_OSX_SYSROOT})
+  # When building for Mac Catalyst, CMAKE_OSX_SYSROOT is also set to MacOSX. To avoid a name collision,
+  # we use `-macabi` in the name of the output static Apple framework directory instead.
+  if (PLATFORM_NAME STREQUAL "macabi")
+    set(STATIC_FRAMEWORK_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}-macabi)
+  else()
+    set(STATIC_FRAMEWORK_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}-${CMAKE_OSX_SYSROOT})
+  endif()
 
   # Setup the various directories required. Remove any existing ones so we start with a clean directory.
   set(STATIC_LIB_DIR ${CMAKE_CURRENT_BINARY_DIR}/static_libraries)
@@ -299,18 +305,34 @@ if(onnxruntime_BUILD_APPLE_FRAMEWORK)
   # to enforce symbol visibility. doing it this way limits the symbols included from the .a files to symbols used
   # by the ORT .o files.
 
-  # If it's an onnxruntime library, extract .o files to a separate directory for each library to avoid any clashes
-  # with filenames (e.g. utils.o)
+  # If it's an onnxruntime library, extract .o files from the original cmake build path to a separate directory for
+  # each library to avoid any clashes with filenames (e.g. utils.o)
   foreach(_LIB ${onnxruntime_INTERNAL_LIBRARIES} )
     GET_TARGET_PROPERTY(_LIB_TYPE ${_LIB} TYPE)
     if(_LIB_TYPE STREQUAL "STATIC_LIBRARY")
       set(CUR_STATIC_LIB_OBJ_DIR ${STATIC_LIB_TEMP_DIR}/$<TARGET_LINKER_FILE_BASE_NAME:${_LIB}>)
       add_custom_command(TARGET onnxruntime POST_BUILD
                          COMMAND ${CMAKE_COMMAND} -E make_directory ${CUR_STATIC_LIB_OBJ_DIR})
-
-      add_custom_command(TARGET onnxruntime POST_BUILD
-                         COMMAND ar ARGS -x $<TARGET_FILE:${_LIB}>
-                         WORKING_DIRECTORY ${CUR_STATIC_LIB_OBJ_DIR})
+      if (PLATFORM_NAME STREQUAL "macabi")
+        # Several source files have duplicate names under different subdirectories within
+        # each onnxruntime library (e.g. onnxruntime/contrib_ops/cpu/element_wise_ops.o
+        # vs. onnxruntime/providers/core/cpu/math/element_wise_ops.o).
+        # In that case, using 'ar ARGS -x' to extract the .o files from the .a lib could overwrite files with the same
+        # name and lead to undefined symbol errors in the generated binary.
+        # So we use the python script below as a sanity check: it recursively finds all .o files in
+        # ${CUR_TARGET_CMAKE_SOURCE_LIB_DIR}, verifies that they match the contents of the .a, and then copies them from the source dir.
+        # TODO: The copying step isn't really necessary. A future fix could have the script extract from the ar archive
+        # and rename, so that maccatalyst and other builds do the same thing.
+        set(CUR_TARGET_CMAKE_SOURCE_LIB_DIR ${CMAKE_CURRENT_BINARY_DIR}/CMakeFiles/${_LIB}.dir)
+        add_custom_command(TARGET onnxruntime POST_BUILD
+                          COMMAND ar -t $<TARGET_FILE:${_LIB}> | grep "\.o$"  > ${_LIB}.object_file_list.txt
+                          COMMAND ${CMAKE_COMMAND} -E env python3 ${CMAKE_CURRENT_SOURCE_DIR}/maccatalyst_prepare_objects_for_prelink.py ${CUR_TARGET_CMAKE_SOURCE_LIB_DIR} ${CUR_STATIC_LIB_OBJ_DIR} ${CUR_STATIC_LIB_OBJ_DIR}/${_LIB}.object_file_list.txt
+                          WORKING_DIRECTORY ${CUR_STATIC_LIB_OBJ_DIR})
+      else()
+        add_custom_command(TARGET onnxruntime POST_BUILD
+        COMMAND ar ARGS -x $<TARGET_FILE:${_LIB}>
+        WORKING_DIRECTORY ${CUR_STATIC_LIB_OBJ_DIR})
+      endif()
     endif()
   endforeach()
 
diff --git a/cmake/onnxruntime_common.cmake b/cmake/onnxruntime_common.cmake
index 6b8c2560b1714..fb56e3f3445d4 100644
--- a/cmake/onnxruntime_common.cmake
+++ b/cmake/onnxruntime_common.cmake
@@ -201,10 +201,6 @@ endif()
 
 
 if (RISCV64 OR ARM64 OR ARM OR X86 OR X64 OR X86_64)
-  if((WIN32 AND NOT CMAKE_CXX_STANDARD_LIBRARIES MATCHES kernel32.lib) OR ((ARM64 OR ARM) AND MSVC))
-    # msvc compiler report syntax error with cpuinfo arm source files
-    # and cpuinfo does not have code for getting arm uarch info under windows
-  else()
     # Link cpuinfo if supported
     # Using it mainly in ARM with Android.
     # Its functionality in detecting x86 cpu features are lacking, so is support for Windows.
@@ -212,7 +208,6 @@ if (RISCV64 OR ARM64 OR ARM OR X86 OR X64 OR X86_64)
       onnxruntime_add_include_to_target(onnxruntime_common cpuinfo::cpuinfo)
       list(APPEND onnxruntime_EXTERNAL_LIBRARIES cpuinfo::cpuinfo ${ONNXRUNTIME_CLOG_TARGET_NAME})
     endif()
-  endif()
 endif()
 
 if (NOT onnxruntime_BUILD_SHARED_LIB)
diff --git a/cmake/onnxruntime_graph.cmake b/cmake/onnxruntime_graph.cmake
index 3f532ec2c3261..4d51325b8414e 100644
--- a/cmake/onnxruntime_graph.cmake
+++ b/cmake/onnxruntime_graph.cmake
@@ -7,8 +7,26 @@ file(GLOB_RECURSE onnxruntime_graph_src CONFIGURE_DEPENDS
   "${ONNXRUNTIME_ROOT}/core/graph/*.cc"
   )
 
-# create empty list for any excludes
+# start with empty training srcs list
+set(orttraining_graph_src)
+
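+# Training-ops-only builds need just the training op schema definitions;
+# full training builds glob all of the training graph sources.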
+if (onnxruntime_ENABLE_TRAINING_OPS AND NOT onnxruntime_ENABLE_TRAINING)
+  set(orttraining_graph_src
+      "${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.cc"
+      "${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.h"
+      )
+endif()
+
+if (onnxruntime_ENABLE_TRAINING)
+  file(GLOB_RECURSE orttraining_graph_src CONFIGURE_DEPENDS
+      "${ORTTRAINING_SOURCE_DIR}/core/graph/*.h"
+      "${ORTTRAINING_SOURCE_DIR}/core/graph/*.cc"
+      )
+endif()
+
+# create empty lists for any excludes
 set(onnxruntime_graph_src_exclude_patterns)
+set(orttraining_graph_src_exclude_patterns)
 
 if (onnxruntime_MINIMAL_BUILD)
   # remove schema registration support
@@ -22,11 +40,18 @@ if (onnxruntime_MINIMAL_BUILD)
     "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/onnx_function_util.cc"
     "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/shape_inference_functions.h"
     "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/shape_inference_functions.cc"
+    "${ONNXRUNTIME_ROOT}/core/graph/dml_ops/dml_defs.h"
+    "${ONNXRUNTIME_ROOT}/core/graph/dml_ops/dml_defs.cc"
     "${ONNXRUNTIME_ROOT}/core/graph/function_template.h"
     "${ONNXRUNTIME_ROOT}/core/graph/function_utils.h"
     "${ONNXRUNTIME_ROOT}/core/graph/function_utils.cc"
   )
 
+  list(APPEND orttraining_graph_src_exclude_patterns
+    "${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.h"
+    "${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.cc"
+  )
+
   # no Function support initially
   list(APPEND onnxruntime_graph_src_exclude_patterns
     "${ONNXRUNTIME_ROOT}/core/graph/function*"
@@ -64,30 +89,12 @@ endif()
 file(GLOB onnxruntime_graph_src_exclude ${onnxruntime_graph_src_exclude_patterns})
 list(REMOVE_ITEM onnxruntime_graph_src ${onnxruntime_graph_src_exclude})
 
-file(GLOB_RECURSE onnxruntime_ir_defs_src CONFIGURE_DEPENDS
-  "${ONNXRUNTIME_ROOT}/core/defs/*.cc"
-)
-
-if (onnxruntime_ENABLE_TRAINING_OPS AND NOT onnxruntime_ENABLE_TRAINING)
-  set(orttraining_graph_src
-      "${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.cc"
-      "${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.h"
-      )
-endif()
-
-if (onnxruntime_ENABLE_TRAINING)
-  file(GLOB_RECURSE orttraining_graph_src CONFIGURE_DEPENDS
-      "${ORTTRAINING_SOURCE_DIR}/core/graph/*.h"
-      "${ORTTRAINING_SOURCE_DIR}/core/graph/*.cc"
-      )
-endif()
-
-set(onnxruntime_graph_lib_src ${onnxruntime_graph_src} ${onnxruntime_ir_defs_src})
 if (onnxruntime_ENABLE_TRAINING_OPS)
-    list(APPEND onnxruntime_graph_lib_src ${orttraining_graph_src})
+  file(GLOB orttraining_graph_src_exclude ${orttraining_graph_src_exclude_patterns})
+  list(REMOVE_ITEM orttraining_graph_src ${orttraining_graph_src_exclude})
 endif()
 
-onnxruntime_add_static_library(onnxruntime_graph ${onnxruntime_graph_lib_src})
+onnxruntime_add_static_library(onnxruntime_graph ${onnxruntime_graph_src} ${orttraining_graph_src})
 add_dependencies(onnxruntime_graph onnx_proto flatbuffers::flatbuffers)
 onnxruntime_add_include_to_target(onnxruntime_graph onnxruntime_common ${WIL_TARGET} onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers safeint_interface Boost::mp11)
 
@@ -120,7 +127,7 @@ endif()
 
 set_target_properties(onnxruntime_graph PROPERTIES FOLDER "ONNXRuntime")
 set_target_properties(onnxruntime_graph PROPERTIES LINKER_LANGUAGE CXX)
-source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_graph_src} ${onnxruntime_ir_defs_src})
+source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_graph_src})
 if (onnxruntime_ENABLE_TRAINING_OPS)
     source_group(TREE ${ORTTRAINING_ROOT} FILES ${orttraining_graph_src})
 endif()
diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 17de2aa4aaea6..6b7d4402be8eb 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -631,6 +631,12 @@ if (WIN32)
   endif()
 endif()
 
+if (PLATFORM_NAME STREQUAL "macabi")
+  # Needed for Mac Catalyst C compilation,
+  # i.e. these flags add "--target=x86_64-apple-ios14.0-macabi -ffunction-sections -fdata-sections"
+  target_compile_options(onnxruntime_mlas PRIVATE ${CMAKE_C_FLAGS})
+endif()
+
 if (NOT onnxruntime_BUILD_SHARED_LIB)
     install(TARGETS onnxruntime_mlas
             ARCHIVE   DESTINATION ${CMAKE_INSTALL_LIBDIR}
diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake
index c6c9d8f4894c5..7e7819ac31a19 100644
--- a/cmake/onnxruntime_providers.cmake
+++ b/cmake/onnxruntime_providers.cmake
@@ -66,11 +66,7 @@ if(onnxruntime_USE_CUDA)
   set(PROVIDERS_CUDA onnxruntime_providers_cuda)
 endif()
 if(onnxruntime_USE_COREML)
-  if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS")
-    set(PROVIDERS_COREML onnxruntime_providers_coreml coreml_proto)
-  else()
-    set(PROVIDERS_COREML onnxruntime_providers_coreml)
-  endif()
+  set(PROVIDERS_COREML onnxruntime_providers_coreml coreml_proto)
 endif()
 if(onnxruntime_USE_NNAPI_BUILTIN)
   set(PROVIDERS_NNAPI onnxruntime_providers_nnapi)
diff --git a/cmake/onnxruntime_providers_coreml.cmake b/cmake/onnxruntime_providers_coreml.cmake
index 2ca4a22aca7d2..b8ebc4ca53239 100644
--- a/cmake/onnxruntime_providers_coreml.cmake
+++ b/cmake/onnxruntime_providers_coreml.cmake
@@ -7,6 +7,27 @@ endif()
 
 add_compile_definitions(USE_COREML=1)
 
+# Check if we can build the coremltools code for creating an mlpackage with an mlprogram.
+# The coremltools source requires std::filesystem::path which is only available from iOS 13 on.
+set(_enable_ML_PROGRAM ON)
+if (IOS AND CMAKE_OSX_DEPLOYMENT_TARGET VERSION_LESS 13.0)
+  message(WARNING "CoreML ML Program is not supported on iOS < 13.0. Excluding ML Program support from build.")
+  set(_enable_ML_PROGRAM OFF)
+elseif(LINUX)
+  # uuid-dev is required. We don't bother installing it on CIs as it's really for manual developer testing.
+  find_library(LibUUID_LIBRARY NAMES uuid)
+  find_path(LibUUID_INCLUDE_DIR NAMES uuid/uuid.h)
+  if (NOT LibUUID_INCLUDE_DIR)
+    message(STATUS "uuid/uuid.h was not found as is required for ML Program support. "
+                    "Run `sudo apt install uuid-dev` if you need to test ML Program related CoreML EP code. ")
+    set(_enable_ML_PROGRAM OFF)
+  endif()
+endif()
+
+if (_enable_ML_PROGRAM)
+  add_compile_definitions(COREML_ENABLE_MLPROGRAM=1)
+endif()
+
 # Compile CoreML proto definition to ${CMAKE_CURRENT_BINARY_DIR}/coreml_proto
 set(COREML_PROTO_ROOT ${coremltools_SOURCE_DIR}/mlmodel/format)
 file(GLOB coreml_proto_srcs "${COREML_PROTO_ROOT}/*.proto")
@@ -19,8 +40,8 @@ target_compile_definitions(coreml_proto
                            PUBLIC $<TARGET_PROPERTY:${PROTOBUF_LIB},INTERFACE_COMPILE_DEFINITIONS>)
 set_target_properties(coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility=hidden")
 set_target_properties(coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility-inlines-hidden")
-set(_src_sub_dir "coreml_proto/")
 
+set(_src_sub_dir "coreml_proto/")
 onnxruntime_protobuf_generate(
   APPEND_PATH
   GEN_SRC_SUB_DIR ${_src_sub_dir}
@@ -49,12 +70,16 @@ list(FILTER coreml_proto_generated_srcs INCLUDE REGEX "\.pb\.(h|cc)$")
 source_group(TREE ${CMAKE_CURRENT_BINARY_DIR} PREFIX coreml_proto_generated FILES ${coreml_proto_generated_srcs})
 
 # These are shared utils,
-# TODO, move this to a separated lib when used by EPs other than NNAPI and CoreML
-file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS
+# TODO, move this to a separate lib when used by EPs other than NNAPI and CoreML
+file(GLOB onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS
   "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h"
   "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc"
 )
 
+file(GLOB onnxruntime_providers_coreml_public_headers CONFIGURE_DEPENDS
+  "${ONNXRUNTIME_INCLUDE_DIR}/core/providers/coreml/*.h"
+)
+
 file(GLOB
   onnxruntime_providers_coreml_cc_srcs_top CONFIGURE_DEPENDS
   "${ONNXRUNTIME_ROOT}/core/providers/coreml/*.h"
@@ -67,15 +92,38 @@ file(GLOB_RECURSE
   "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.h"
   "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.cc"
 )
-if (NOT CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND NOT CMAKE_SYSTEM_NAME STREQUAL "iOS")
-  list(REMOVE_ITEM onnxruntime_providers_coreml_cc_srcs_nested
-  "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/model_builder.h"
-  "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/model_builder.cc"
+
+if(_enable_ML_PROGRAM)
+  # Add helpers to create mlpackage weights. Limit to just the files we need to minimize the changes needed to make
+  # them build on Windows and Linux.
+  file(GLOB
+    onnxruntime_providers_coreml_milblob_cc_srcs CONFIGURE_DEPENDS
+    "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/*.hpp"
+    "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/*.cpp"
+    "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Util/*.hpp"
+    "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/BlobDataType.hpp"
+    "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/StorageFormat.hpp"
+    "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/FileWriter.?pp"
+    "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/StorageWriter.?pp"
+  )
+
+  # Add helpers to create mlpackage
+  file(GLOB
+    onnxruntime_providers_coreml_modelpackage_cc_srcs CONFIGURE_DEPENDS
+    "${coremltools_SOURCE_DIR}/modelpackage/src/ModelPackage.?pp"
+    "${coremltools_SOURCE_DIR}/modelpackage/src/utils/JsonMap.?pp"
   )
+
+  set(coremltools_srcs
+    ${onnxruntime_providers_coreml_milblob_cc_srcs}
+    ${onnxruntime_providers_coreml_modelpackage_cc_srcs}
+  )
+
+  source_group(TREE ${coremltools_SOURCE_DIR} PREFIX coremltools FILES ${coremltools_srcs})
 endif()
 
 # Add CoreML objective c++ source code
-if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS")
+if (APPLE)
   file(GLOB
     onnxruntime_providers_coreml_objcc_srcs CONFIGURE_DEPENDS
     "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.h"
@@ -83,26 +131,79 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS")
     "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.h"
     "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.mm"
   )
+else()
+  # add the Model implementation that uses the protobuf types but excludes any actual CoreML dependencies
+  # by using stub implementations on non-Apple platforms.
+  file(GLOB
+    onnxruntime_providers_coreml_objcc_srcs CONFIGURE_DEPENDS
+    "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.h"
+    "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils_stub.cc"
+    "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.h"
+    "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model_stub.cc"
+  )
 endif()
 
 set(onnxruntime_providers_coreml_cc_srcs
   ${onnxruntime_providers_coreml_cc_srcs_top}
   ${onnxruntime_providers_coreml_cc_srcs_nested}
   ${onnxruntime_providers_shared_utils_cc_srcs}
+  ${onnxruntime_providers_coreml_objcc_srcs}
 )
 
-source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_coreml_cc_srcs})
+source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_providers_coreml_cc_srcs})
+source_group(TREE ${ONNXRUNTIME_INCLUDE_DIR} FILES ${onnxruntime_providers_coreml_public_headers})
+
 onnxruntime_add_static_library(onnxruntime_providers_coreml
-  ${onnxruntime_providers_coreml_cc_srcs} ${onnxruntime_providers_coreml_objcc_srcs}
+  ${onnxruntime_providers_coreml_public_headers}
+  ${onnxruntime_providers_coreml_cc_srcs}
+  ${coremltools_srcs}
 )
+
 onnxruntime_add_include_to_target(onnxruntime_providers_coreml
-  onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB}  flatbuffers::flatbuffers Boost::mp11 safeint_interface
+  onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11
+  safeint_interface
 )
-if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS")
-  onnxruntime_add_include_to_target(onnxruntime_providers_coreml coreml_proto)
-  target_link_libraries(onnxruntime_providers_coreml PRIVATE coreml_proto "-framework Foundation" "-framework CoreML")
-  add_dependencies(onnxruntime_providers_coreml coreml_proto)
+
+onnxruntime_add_include_to_target(onnxruntime_providers_coreml coreml_proto)
+target_link_libraries(onnxruntime_providers_coreml PRIVATE coreml_proto)
+add_dependencies(onnxruntime_providers_coreml coreml_proto)
+
+if (APPLE)
+  target_compile_definitions(onnxruntime_providers_coreml PRIVATE __APPLE__)
 endif()
+
+if (_enable_ML_PROGRAM)
+  # Setup coremltools fp16 and json dependencies for creating an mlpackage.
+  #
+  # These are also used by external/xnnpack.cmake. fp16 depends on psimd
+  FetchContent_Declare(psimd URL ${DEP_URL_psimd} URL_HASH SHA1=${DEP_SHA1_psimd})
+  onnxruntime_fetchcontent_makeavailable(psimd)
+  set(PSIMD_SOURCE_DIR ${psimd_SOURCE_DIR})
+  FetchContent_Declare(fp16 URL ${DEP_URL_fp16} URL_HASH SHA1=${DEP_SHA1_fp16})
+  set(FP16_BUILD_TESTS OFF CACHE INTERNAL "")
+  set(FP16_BUILD_BENCHMARKS OFF CACHE INTERNAL "")
+  onnxruntime_fetchcontent_makeavailable(fp16)
+
+  # need to tweak the include paths to match what the coreml source code expects
+  target_include_directories(onnxruntime_providers_coreml PRIVATE
+                            ${fp16_SOURCE_DIR}/include
+                            ${nlohmann_json_SOURCE_DIR}/single_include/nlohmann
+                            ${coremltools_SOURCE_DIR}
+                            ${coremltools_SOURCE_DIR}/mlmodel/src/
+                            ${coremltools_SOURCE_DIR}/modelpackage/src/
+  )
+
+  add_dependencies(onnxruntime_providers_coreml nlohmann_json::nlohmann_json fp16)
+
+  if (LINUX)
+    target_link_libraries(onnxruntime_providers_coreml PRIVATE uuid)
+  endif()
+endif()
+
+if (APPLE)
+  target_link_libraries(onnxruntime_providers_coreml PRIVATE "-framework Foundation" "-framework CoreML")
+endif()
+
 add_dependencies(onnxruntime_providers_coreml ${onnxruntime_EXTERNAL_DEPENDENCIES})
 
 set_target_properties(onnxruntime_providers_coreml PROPERTIES CXX_STANDARD_REQUIRED ON)
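
Note on the platform checks used above: the Darwin/iOS string comparisons are replaced with CMake's built-in platform variables. A small illustrative sketch (not part of the diff; the messages are placeholders) shows the equivalence and the CMake versions it assumes:

    # APPLE covers all Apple platforms (macOS, iOS, tvOS, ...); the IOS variable needs CMake >= 3.14
    # and the LINUX variable needs CMake >= 3.25 - older CMake must compare CMAKE_SYSTEM_NAME instead.
    if (APPLE)      # previously: CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS"
      message(STATUS "build the Objective-C++ CoreML model implementation")
    elseif (LINUX)  # previously: CMAKE_SYSTEM_NAME STREQUAL "Linux"
      message(STATUS "link libuuid for the coremltools mlpackage helpers")
    endif()
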
diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake
index 9887d615c92d7..aeeac10ead27d 100644
--- a/cmake/onnxruntime_providers_cuda.cmake
+++ b/cmake/onnxruntime_providers_cuda.cmake
@@ -141,18 +141,22 @@
     if (HAS_GUARD_CF)
       target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler /guard:cf>")
     endif()
+
     if (HAS_QSPECTRE)
       target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler /Qspectre>")
     endif()
+
     foreach(ORT_FLAG ${ORT_WARNING_FLAGS})
         target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler \"${ORT_FLAG}\">")
     endforeach()
+
     # CUDA 11.3+ supports parallel compilation
     # https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-guiding-compiler-driver-threads
     if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.3)
       option(onnxruntime_NVCC_THREADS "Number of threads that NVCC can use for compilation." 1)
       target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--threads \"${onnxruntime_NVCC_THREADS}\">")
     endif()
+
     if (UNIX)
       target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler -Wno-reorder>"
                   "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:-Wno-reorder>")
@@ -162,6 +166,13 @@
       #mutex.cuh(91): warning C4834: discarding return value of function with 'nodiscard' attribute
       target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler /wd4834>")
       target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler /wd4127>")
+      if (MSVC)
+        # The VS 'Conditional Expression is Constant' warning (C4127) is spurious here because the compiler does not
+        # account for multiple conditions: e.g. `if (std::is_same_v<T, float> && not_a_const)` generates the warning
+        # even though `if constexpr` cannot be used due to `&& not_a_const`. This affects too many places to make
+        # disabling it at a finer granularity reasonable.
+        target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:/wd4127>")
+      endif()
     endif()
 
     onnxruntime_add_include_to_target(${target} onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers)
@@ -178,9 +189,10 @@
     add_dependencies(${target} onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES})
     if(onnxruntime_CUDA_MINIMAL)
       target_compile_definitions(${target} PRIVATE USE_CUDA_MINIMAL)
-      target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface)
+      target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface CUDA::cudart)
     else()
-      target_link_libraries(${target} PRIVATE cublasLt cublas cudnn curand cufft ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface)
+      target_link_libraries(${target} PRIVATE CUDA::cublasLt CUDA::cublas cudnn CUDA::curand CUDA::cufft CUDA::cudart
+              ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface)
       if(onnxruntime_CUDNN_HOME)
           target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include)
           target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib)
@@ -196,25 +208,24 @@
       target_include_directories(${target} PRIVATE ${triton_kernel_header_dir})
       target_link_libraries(${target} PUBLIC -Wl,--whole-archive ${triton_kernel_obj_file} -Wl,--no-whole-archive)
       # lib cuda needed by cuLaunchKernel
-      target_link_libraries(${target} PRIVATE cuda)
+      target_link_libraries(${target} PRIVATE CUDA::cuda_driver)
     endif()
 
     include(cutlass)
-    target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples)
+    target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include)
 
-    target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR}  ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
+    target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR}  ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES}
+     PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
     # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found
     set_target_properties(${target} PROPERTIES LINKER_LANGUAGE CUDA)
     set_target_properties(${target} PROPERTIES FOLDER "ONNXRuntime")
 
     if (onnxruntime_ENABLE_CUDA_PROFILING) # configure cupti for cuda profiling
-      target_include_directories(${target} PRIVATE ${onnxruntime_CUDA_HOME}/extras/CUPTI/include)
-      target_link_directories(${target} PRIVATE ${onnxruntime_CUDA_HOME}/extras/CUPTI/lib64)
-      target_link_libraries(${target} PRIVATE cupti)
+      target_link_libraries(${target} PRIVATE CUDA::cupti)
     endif()
 
-    if (onnxruntime_ENABLE_NVTX_PROFILE AND NOT WIN32)
-      target_link_libraries(${target} PRIVATE nvToolsExt)
+    if (onnxruntime_ENABLE_NVTX_PROFILE)
+      target_link_libraries(${target} PRIVATE CUDA::nvtx3)
     endif()
 
     if (onnxruntime_ENABLE_TRAINING_OPS)
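
The moves from bare library names (cudart, cublas, cupti, nvToolsExt) to CUDA:: imported targets and from CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES to CUDAToolkit_INCLUDE_DIRS both rely on CMake's FindCUDAToolkit module having been run. A minimal sketch of that pattern, with a hypothetical consumer target (the actual find_package call lives elsewhere in the onnxruntime build):

    # find_package(CUDAToolkit) defines CUDAToolkit_INCLUDE_DIRS and imported targets such as
    # CUDA::cudart, CUDA::cublas, CUDA::cupti and CUDA::cuda_driver.
    find_package(CUDAToolkit REQUIRED)

    add_library(my_cuda_consumer SHARED kernel.cu)   # hypothetical target, for illustration only
    # Imported targets carry their own include directories and link options, so no manual
    # include-path or link-directory handling for the toolkit is needed.
    target_link_libraries(my_cuda_consumer PRIVATE CUDA::cudart CUDA::cublas)
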
diff --git a/cmake/onnxruntime_providers_nnapi.cmake b/cmake/onnxruntime_providers_nnapi.cmake
index 5ac25a3b76efb..b718a976eb26f 100644
--- a/cmake/onnxruntime_providers_nnapi.cmake
+++ b/cmake/onnxruntime_providers_nnapi.cmake
@@ -49,12 +49,10 @@
   endif()
 
   # These are shared utils,
-  # TODO, move this to a separated lib when used by EPs other than NNAPI and CoreML
+  # TODO, move this to a separate lib when used by EPs other than NNAPI and CoreML
   list(APPEND onnxruntime_provider_nnapi_cc_src_patterns
     "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h"
     "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc"
-    "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h"
-    "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc"
   )
 
   file(GLOB onnxruntime_providers_nnapi_cc_srcs CONFIGURE_DEPENDS ${onnxruntime_provider_nnapi_cc_src_patterns})
@@ -81,4 +79,4 @@
             LIBRARY   DESTINATION ${CMAKE_INSTALL_LIBDIR}
             RUNTIME   DESTINATION ${CMAKE_INSTALL_BINDIR}
             FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR})
-  endif()
\ No newline at end of file
+  endif()
diff --git a/cmake/onnxruntime_providers_openvino.cmake b/cmake/onnxruntime_providers_openvino.cmake
index e26f0bfc0b751..5876b2b5c448b 100644
--- a/cmake/onnxruntime_providers_openvino.cmake
+++ b/cmake/onnxruntime_providers_openvino.cmake
@@ -16,23 +16,19 @@
   endif()
 
   # Header paths
-  find_package(InferenceEngine REQUIRED)
-  find_package(ngraph REQUIRED)
-
-  if (OPENVINO_2022_1 OR OPENVINO_2022_2)
   find_package(OpenVINO REQUIRED COMPONENTS Runtime ONNX)
-  list (OV_20_LIBS openvino::frontend::onnx openvino::runtime)
+  if(OpenVINO_VERSION VERSION_LESS 2023.0)
+    message(FATAL_ERROR "OpenVINO 2023.0 or newer is required. Please use the latest OpenVINO release.")
   endif()
 
   if (WIN32)
     unset(CMAKE_MAP_IMPORTED_CONFIG_RELWITHDEBINFO)
   endif()
 
+  list(APPEND OPENVINO_LIB_LIST openvino::frontend::onnx openvino::runtime ${PYTHON_LIBRARIES})
   if ((DEFINED ENV{OPENCL_LIBS}) AND (DEFINED ENV{OPENCL_INCS}))
     add_definitions(-DIO_BUFFER_ENABLED=1)
-    list(APPEND OPENVINO_LIB_LIST $ENV{OPENCL_LIBS} ${OV_20_LIBS} ${InferenceEngine_LIBRARIES} ${NGRAPH_LIBRARIES} ngraph::onnx_importer ${PYTHON_LIBRARIES})
-  else()
-    list(APPEND OPENVINO_LIB_LIST ${OV_20_LIBS} ${InferenceEngine_LIBRARIES} ${NGRAPH_LIBRARIES} ngraph::onnx_importer ${PYTHON_LIBRARIES})
+    list(APPEND OPENVINO_LIB_LIST $ENV{OPENCL_LIBS})
   endif()
 
   source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_openvino_cc_srcs})
@@ -75,7 +71,14 @@
     message(FATAL_ERROR "onnxruntime_providers_openvino unknown platform, need to specify shared library exports for it")
   endif()
 
-  install(TARGETS onnxruntime_providers_openvino
-          ARCHIVE  DESTINATION ${CMAKE_INSTALL_LIBDIR}
-          LIBRARY  DESTINATION ${CMAKE_INSTALL_LIBDIR}
-          RUNTIME  DESTINATION ${CMAKE_INSTALL_BINDIR})
\ No newline at end of file
+  if (CMAKE_OPENVINO_LIBRARY_INSTALL_DIR)
+    install(TARGETS onnxruntime_providers_openvino
+            ARCHIVE  DESTINATION ${CMAKE_INSTALL_LIBDIR}
+            LIBRARY  DESTINATION ${CMAKE_OPENVINO_LIBRARY_INSTALL_DIR}
+            RUNTIME  DESTINATION ${CMAKE_INSTALL_BINDIR})
+  else()
+    install(TARGETS onnxruntime_providers_openvino
+            ARCHIVE  DESTINATION ${CMAKE_INSTALL_LIBDIR}
+            LIBRARY  DESTINATION ${CMAKE_INSTALL_LIBDIR}
+            RUNTIME  DESTINATION ${CMAKE_INSTALL_BINDIR})
+  endif()
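
The new CMAKE_OPENVINO_LIBRARY_INSTALL_DIR branch only takes effect when the variable is defined; a hedged sketch of how a packager might set it (the path below is hypothetical):

    # Typically passed on the command line, e.g.
    #   cmake -DCMAKE_OPENVINO_LIBRARY_INSTALL_DIR=/usr/lib/x86_64-linux-gnu/openvino ...
    # When it is unset, the library install destination falls back to CMAKE_INSTALL_LIBDIR as before.
    set(CMAKE_OPENVINO_LIBRARY_INSTALL_DIR "/usr/lib/x86_64-linux-gnu/openvino"
        CACHE PATH "Optional override for the OpenVINO EP shared library install directory")
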
diff --git a/cmake/onnxruntime_providers_qnn.cmake b/cmake/onnxruntime_providers_qnn.cmake
index a93a06e960c81..b68d84c23bb32 100644
--- a/cmake/onnxruntime_providers_qnn.cmake
+++ b/cmake/onnxruntime_providers_qnn.cmake
@@ -4,12 +4,10 @@
   add_compile_definitions(USE_QNN=1)
 
   # These are shared utils,
-  # TODO, move this to a separated lib when used by EPs other than QNN, NNAPI and CoreML
-  file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS
+  # TODO, move to a separate lib when used by EPs other than QNN, NNAPI and CoreML
+  file(GLOB onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS
     "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h"
     "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc"
-    "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h"
-    "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc"
   )
 
   file(GLOB_RECURSE
@@ -42,4 +40,4 @@
   # ignore the warning unknown-pragmas on "pragma region"
   if(NOT MSVC)
     target_compile_options(onnxruntime_providers_qnn PRIVATE "-Wno-unknown-pragmas")
-  endif()
\ No newline at end of file
+  endif()
diff --git a/cmake/onnxruntime_providers_tensorrt.cmake b/cmake/onnxruntime_providers_tensorrt.cmake
index 686a993de3a4a..15ffc29e79ff4 100644
--- a/cmake/onnxruntime_providers_tensorrt.cmake
+++ b/cmake/onnxruntime_providers_tensorrt.cmake
@@ -8,7 +8,7 @@
   set(BUILD_LIBRARY_ONLY 1)
   add_definitions("-DONNX_ML=1")
   add_definitions("-DONNX_NAMESPACE=onnx")
-  set(CUDA_INCLUDE_DIRS ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
+  set(CUDA_INCLUDE_DIRS ${CUDAToolkit_INCLUDE_DIRS})
   set(TENSORRT_ROOT ${onnxruntime_TENSORRT_HOME})
   set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
   set(PROTOBUF_LIBRARY ${PROTOBUF_LIB})
@@ -58,7 +58,7 @@
       URL_HASH SHA1=${DEP_SHA1_onnx_tensorrt}
     )
     if (NOT CUDA_INCLUDE_DIR)
-      set(CUDA_INCLUDE_DIR ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) # onnx-tensorrt repo needs this variable to build
+      set(CUDA_INCLUDE_DIR ${CUDAToolkit_INCLUDE_DIRS}) # onnx-tensorrt repo needs this variable to build
     endif()
     # The onnx_tensorrt repo contains a test program, getSupportedAPITest, which doesn't support Windows. It uses
     # unistd.h. So we must exclude it from our build. onnxruntime_fetchcontent_makeavailable is for the purpose.
@@ -102,11 +102,12 @@
   onnxruntime_add_include_to_target(onnxruntime_providers_tensorrt onnxruntime_common onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface)
   add_dependencies(onnxruntime_providers_tensorrt onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES})
   if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER)
-    target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${trt_link_libs} cudart ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS})
+    target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart)
   else()
-    target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${onnxparser_link_libs} ${trt_link_libs} cudart ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS})
+    target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart)
   endif()
-  target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
+  target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS}
+    PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
   if(onnxruntime_CUDNN_HOME)
     target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${onnxruntime_CUDNN_HOME}/include)
   endif()
diff --git a/cmake/onnxruntime_providers_xnnpack.cmake b/cmake/onnxruntime_providers_xnnpack.cmake
index 6342c24b2917e..796536ac9d12b 100644
--- a/cmake/onnxruntime_providers_xnnpack.cmake
+++ b/cmake/onnxruntime_providers_xnnpack.cmake
@@ -7,9 +7,6 @@
     "${ONNXRUNTIME_INCLUDE_DIR}/core/providers/xnnpack/*.h"
     "${ONNXRUNTIME_ROOT}/core/providers/xnnpack/*.h"
     "${ONNXRUNTIME_ROOT}/core/providers/xnnpack/*.cc"
-    # utils for handling QDQ models
-    "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.h"
-    "${ONNXRUNTIME_ROOT}/core/providers/shared/node_unit/node_unit.cc"
   )
 
   source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_xnnpack_cc_srcs})
diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake
index 3f20787e87425..23c6e5e430875 100644
--- a/cmake/onnxruntime_python.cmake
+++ b/cmake/onnxruntime_python.cmake
@@ -282,10 +282,7 @@ if (WIN32)
     get_filename_component(CUDNN_DLL_NAME ${CUDNN_DLL_PATH} NAME_WE)
     string(REPLACE "cudnn64_" "" CUDNN_VERSION "${CUDNN_DLL_NAME}")
     if(NOT onnxruntime_CUDA_VERSION)
-      message("Reading json file ${onnxruntime_CUDA_HOME}/version.json")
-      set(CUDA_SDK_JSON_FILE_PATH "${onnxruntime_CUDA_HOME}/version.json")
-      file(READ ${CUDA_SDK_JSON_FILE_PATH} CUDA_SDK_JSON_CONTENT)
-      string(JSON onnxruntime_CUDA_VERSION GET ${CUDA_SDK_JSON_CONTENT} "cuda" "version")
+      set(onnxruntime_CUDA_VERSION ${CUDAToolkit_VERSION})
       message("onnxruntime_CUDA_VERSION=${onnxruntime_CUDA_VERSION}")
     endif()
     file(APPEND "${VERSION_INFO_FILE}"
diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake
index d485abe6bb1a6..cadb06bb38707 100644
--- a/cmake/onnxruntime_rocm_hipify.cmake
+++ b/cmake/onnxruntime_rocm_hipify.cmake
@@ -20,10 +20,6 @@ set(contrib_ops_excluded_files
   "bert/fastertransformer_decoder_attention/*"
   "bert/multihead_attention.cc"
   "bert/multihead_attention.h"
-  "bert/fast_gelu_impl.cu"
-  "bert/fast_gelu_impl.h"
-  "bert/fast_gelu.cc"
-  "bert/fast_gelu.h"
   "bert/relative_attn_bias.cc"
   "bert/relative_attn_bias.h"
   "bert/relative_attn_bias_impl.cu"
@@ -44,12 +40,7 @@ set(contrib_ops_excluded_files
   "bert/packed_multihead_attention.cc"
   "bert/packed_multihead_attention_impl.h"
   "bert/packed_multihead_attention_impl.cu"
-  "diffusion/group_norm.cc"
   "diffusion/group_norm_impl.cu"
-  "diffusion/group_norm_impl.h"
-  "diffusion/group_norm_impl_kernel.cuh"
-  "diffusion/group_norm_common_base.h"
-  "diffusion/group_norm_common_base.cc"
   "diffusion/nhwc_conv.cc"
   "math/gemm_float8.cc"
   "math/gemm_float8.cu"
@@ -103,26 +94,18 @@ set(contrib_ops_excluded_files
   "bert/group_query_attention.cc"
   "bert/group_query_attention_impl.h"
   "bert/group_query_attention_impl.cu"
+  "collective/distributed_*"
+  "collective/shard*"
 )
 
-if (NOT onnxruntime_ENABLE_ATEN)
-  list(APPEND contrib_ops_excluded_files "aten_ops/aten_op.cc")
-endif()
 if (NOT onnxruntime_USE_NCCL)
   # Those are string patterns to exclude. Do NOT use stars such as
   # collective/*.cc or *.h.
   list(APPEND contrib_ops_excluded_files "collective/nccl_kernels.cc")
-  list(APPEND contrib_ops_excluded_files "collective/sharded_moe.h")
-  list(APPEND contrib_ops_excluded_files "collective/sharded_moe.cc")
-  list(APPEND contrib_ops_excluded_files "collective/sharding.cc")
-  list(APPEND contrib_ops_excluded_files "collective/sharding_spec.cc")
-  list(APPEND contrib_ops_excluded_files "collective/distributed_matmul.cc")
-  list(APPEND contrib_ops_excluded_files "collective/distributed_slice.cc")
-  list(APPEND contrib_ops_excluded_files "collective/distributed_reshape.cc")
-  list(APPEND contrib_ops_excluded_files "collective/distributed_expand.cc")
-  list(APPEND contrib_ops_excluded_files "collective/distributed_reduce.cc")
-  list(APPEND contrib_ops_excluded_files "collective/distributed_unsqueeze.cc")
-  list(APPEND contrib_ops_excluded_files "collective/distributed_squeeze.cc")
+endif()
+
+if (NOT onnxruntime_ENABLE_ATEN)
+  list(APPEND contrib_ops_excluded_files "aten_ops/aten_op.cc")
 endif()
 
 set(provider_excluded_files
diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index cfdad4c527761..649c29212f305 100644
--- a/cmake/onnxruntime_unittests.cmake
+++ b/cmake/onnxruntime_unittests.cmake
@@ -1,6 +1,6 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
-if (${CMAKE_SYSTEM_NAME} STREQUAL "iOS")
+if (IOS)
   find_package(XCTest REQUIRED)
 endif()
 
@@ -18,7 +18,7 @@ function(AddTest)
   cmake_parse_arguments(_UT "DYN" "TARGET" "LIBS;SOURCES;DEPENDS;TEST_ARGS" ${ARGN})
   list(REMOVE_DUPLICATES _UT_SOURCES)
 
-  if (${CMAKE_SYSTEM_NAME} STREQUAL "iOS")
+  if (IOS)
     onnxruntime_add_executable(${_UT_TARGET} ${TEST_SRC_DIR}/xctest/orttestmain.m)
   else()
     onnxruntime_add_executable(${_UT_TARGET} ${_UT_SOURCES})
@@ -67,7 +67,7 @@ function(AddTest)
     if(onnxruntime_USE_CUDA)
       #XXX: we should not need to do this. onnxruntime_test_all.exe should not have direct dependency on CUDA DLLs,
       # otherwise it will impact when CUDA DLLs can be unloaded.
-      target_link_libraries(${_UT_TARGET} PRIVATE cudart)
+      target_link_libraries(${_UT_TARGET} PRIVATE CUDA::cudart)
     endif()
     target_link_libraries(${_UT_TARGET} PRIVATE ${_UT_LIBS} GTest::gtest GTest::gmock ${onnxruntime_EXTERNAL_LIBRARIES})
   endif()
@@ -129,7 +129,7 @@ function(AddTest)
     endif()
   endif(onnxruntime_GENERATE_TEST_REPORTS)
 
-  if (${CMAKE_SYSTEM_NAME} STREQUAL "iOS")
+  if (IOS)
     # target_sources(${_UT_TARGET} PRIVATE ${TEST_SRC_DIR}/xctest/orttestmain.m)
     set_target_properties(${_UT_TARGET} PROPERTIES FOLDER "ONNXRuntimeTest"
       MACOSX_BUNDLE_BUNDLE_NAME ${_UT_TARGET}
@@ -567,11 +567,7 @@ if(onnxruntime_USE_ROCM)
 endif()
 
 if(onnxruntime_USE_COREML)
-  if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS")
-    list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto)
-  else()
-    list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml)
-  endif()
+  list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto)
 endif()
 
 if(onnxruntime_USE_ACL)
@@ -676,15 +672,9 @@ endif()
 
 if(onnxruntime_USE_COREML)
   list(APPEND onnxruntime_test_framework_src_patterns  ${TEST_SRC_DIR}/providers/coreml/*)
-  if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS")
-    list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml coreml_proto)
-    list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto)
-    list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_coreml coreml_proto)
-  else()
-    list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml)
-    list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml)
-    list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_coreml)
-  endif()
+  list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml coreml_proto)
+  list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto)
+  list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_coreml coreml_proto)
 endif()
 
 if(onnxruntime_USE_XNNPACK)
@@ -744,34 +734,37 @@ target_include_directories(onnxruntime_test_utils PUBLIC "${TEST_SRC_DIR}/util/i
 set_target_properties(onnxruntime_test_utils PROPERTIES FOLDER "ONNXRuntimeTest")
 source_group(TREE ${TEST_SRC_DIR} FILES ${onnxruntime_test_utils_src})
 
-set(onnx_test_runner_src_dir ${TEST_SRC_DIR}/onnx)
-file(GLOB onnx_test_runner_common_srcs CONFIGURE_DEPENDS
-    ${onnx_test_runner_src_dir}/*.h
-    ${onnx_test_runner_src_dir}/*.cc)
+if(NOT IOS)
+    set(onnx_test_runner_src_dir ${TEST_SRC_DIR}/onnx)
+    file(GLOB onnx_test_runner_common_srcs CONFIGURE_DEPENDS
+        ${onnx_test_runner_src_dir}/*.h
+        ${onnx_test_runner_src_dir}/*.cc)
 
-list(REMOVE_ITEM onnx_test_runner_common_srcs ${onnx_test_runner_src_dir}/main.cc)
+    list(REMOVE_ITEM onnx_test_runner_common_srcs ${onnx_test_runner_src_dir}/main.cc)
 
-onnxruntime_add_static_library(onnx_test_runner_common ${onnx_test_runner_common_srcs})
-if(MSVC)
-  target_compile_options(onnx_test_runner_common PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /utf-8>"
-          "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/utf-8>")
-else()
-  target_compile_definitions(onnx_test_runner_common PUBLIC -DNSYNC_ATOMIC_CPP11)
-  target_include_directories(onnx_test_runner_common PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT})
-  onnxruntime_add_include_to_target(onnx_test_runner_common nsync::nsync_cpp)
-endif()
-if (MSVC AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
-  #TODO: fix the warnings, they are dangerous
-  target_compile_options(onnx_test_runner_common PRIVATE "/wd4244")
-endif()
-onnxruntime_add_include_to_target(onnx_test_runner_common onnxruntime_common onnxruntime_framework
-        onnxruntime_test_utils onnx onnx_proto re2::re2 flatbuffers::flatbuffers Boost::mp11 safeint_interface)
+    onnxruntime_add_static_library(onnx_test_runner_common ${onnx_test_runner_common_srcs})
+    if(MSVC)
+      target_compile_options(onnx_test_runner_common PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /utf-8>"
+              "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/utf-8>")
+    else()
+      target_compile_definitions(onnx_test_runner_common PUBLIC -DNSYNC_ATOMIC_CPP11)
+      target_include_directories(onnx_test_runner_common PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT})
+      onnxruntime_add_include_to_target(onnx_test_runner_common nsync::nsync_cpp)
+    endif()
+    if (MSVC AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
+      #TODO: fix the warnings, they are dangerous
+      target_compile_options(onnx_test_runner_common PRIVATE "/wd4244")
+    endif()
+    onnxruntime_add_include_to_target(onnx_test_runner_common onnxruntime_common onnxruntime_framework
+            onnxruntime_test_utils onnx onnx_proto re2::re2 flatbuffers::flatbuffers Boost::mp11 safeint_interface)
 
-add_dependencies(onnx_test_runner_common onnx_test_data_proto ${onnxruntime_EXTERNAL_DEPENDENCIES})
-target_include_directories(onnx_test_runner_common PRIVATE ${eigen_INCLUDE_DIRS}
-        ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT})
+    add_dependencies(onnx_test_runner_common onnx_test_data_proto ${onnxruntime_EXTERNAL_DEPENDENCIES})
+    target_include_directories(onnx_test_runner_common PRIVATE ${eigen_INCLUDE_DIRS}
+            ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT})
 
-set_target_properties(onnx_test_runner_common PROPERTIES FOLDER "ONNXRuntimeTest")
+    set_target_properties(onnx_test_runner_common PROPERTIES FOLDER "ONNXRuntimeTest")
+    set(onnx_test_runner_common_lib onnx_test_runner_common)
+endif()
 
 set(all_tests ${onnxruntime_test_common_src} ${onnxruntime_test_ir_src} ${onnxruntime_test_optimizer_src}
         ${onnxruntime_test_framework_src} ${onnxruntime_test_providers_src} ${onnxruntime_test_quantiztion_src})
@@ -784,6 +777,7 @@ if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS)
   onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_ut ${onnxruntime_test_providers_cuda_ut_src} $<TARGET_OBJECTS:onnxruntime_providers_cuda_obj>)
   config_cuda_provider_shared_module(onnxruntime_providers_cuda_ut)
   onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock)
+  target_include_directories(onnxruntime_providers_cuda_ut PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey)
   target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common)
   list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_cuda_ut)
 endif()
@@ -834,6 +828,15 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
     "${TEST_SRC_DIR}/optimizer/qdq_transformer_test.cc")
 endif()
 
+if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten" OR IOS)
+   # We do not run these model tests in our web or iOS CI build pipelines, and some test code uses C++17
+   # filesystem functions that are not available in the iOS version we target.
+   message("Disable model tests in onnxruntime_test_all")
+   list(REMOVE_ITEM all_tests
+      "${TEST_SRC_DIR}/providers/cpu/model_tests.cc"
+    )
+endif()
+
 set(test_all_args)
 if (onnxruntime_USE_TENSORRT)
   # TRT EP CI takes much longer time when updating to TRT 8.2
@@ -851,7 +854,7 @@ AddTest(
   TARGET onnxruntime_test_all
   SOURCES ${all_tests} ${onnxruntime_unittest_main_src}
   LIBS
-    onnx_test_runner_common ${onnxruntime_test_providers_libs} ${onnxruntime_test_common_libs}
+    ${onnx_test_runner_common_lib} ${onnxruntime_test_providers_libs} ${onnxruntime_test_common_libs}
     onnx_test_data_proto
   DEPENDS ${all_dependencies}
   TEST_ARGS ${test_all_args}
@@ -889,7 +892,7 @@ endif()
 # the default logger tests conflict with the need to have an overall default logger
 # so skip in this type of
 target_compile_definitions(onnxruntime_test_all PUBLIC -DSKIP_DEFAULT_LOGGER_TESTS)
-if (CMAKE_SYSTEM_NAME STREQUAL "iOS")
+if (IOS)
   target_compile_definitions(onnxruntime_test_all_xc PUBLIC -DSKIP_DEFAULT_LOGGER_TESTS)
 endif()
 if(onnxruntime_RUN_MODELTEST_IN_DEBUG_MODE)
@@ -1002,7 +1005,9 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
     if (MSVC OR ${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
         file(GLOB QNN_LIB_FILES LIST_DIRECTORIES false "${onnxruntime_QNN_HOME}/lib/${QNN_ARCH_ABI}/*.so" "${onnxruntime_QNN_HOME}/lib/${QNN_ARCH_ABI}/*.dll")
         if (${QNN_ARCH_ABI} STREQUAL "aarch64-windows-msvc")
-          file(GLOB EXTRA_HTP_LIB LIST_DIRECTORIES false "${onnxruntime_QNN_HOME}/lib/hexagon-v68/unsigned/libQnnHtpV68Skel.so" "${onnxruntime_QNN_HOME}/lib/hexagon-v73/unsigned/libQnnHtpV73Skel.so")
+          file(GLOB EXTRA_HTP_LIB LIST_DIRECTORIES false "${onnxruntime_QNN_HOME}/lib/hexagon-v68/unsigned/libQnnHtpV68Skel.so"
+            "${onnxruntime_QNN_HOME}/lib/hexagon-v73/unsigned/libQnnHtpV73Skel.so"
+            "${onnxruntime_QNN_HOME}/lib/hexagon-v73/unsigned/libqnnhtpv73.cat")
           list(APPEND QNN_LIB_FILES ${EXTRA_HTP_LIB})
         endif()
         message(STATUS "QNN lib files: " ${QNN_LIB_FILES})
@@ -1060,45 +1065,42 @@ if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS)
   list(APPEND onnx_test_libs onnxruntime_language_interop onnxruntime_pyop)
 endif()
 
-onnxruntime_add_executable(onnx_test_runner ${onnx_test_runner_src_dir}/main.cc)
-if(MSVC)
-  target_compile_options(onnx_test_runner PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /utf-8>"
-          "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/utf-8>")
-endif()
-if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS")
-  set_target_properties(onnx_test_runner PROPERTIES
-    XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO"
-  )
-endif()
-if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
-  if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS)
-    set_target_properties(onnx_test_runner PROPERTIES LINK_FLAGS "-s NODERAWFS=1 -s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1")
-  else()
-    set_target_properties(onnx_test_runner PROPERTIES LINK_FLAGS "-s NODERAWFS=1 -s ALLOW_MEMORY_GROWTH=1")
-  endif()
-endif()
+if (NOT IOS)
+    onnxruntime_add_executable(onnx_test_runner ${onnx_test_runner_src_dir}/main.cc)
+    if(MSVC)
+      target_compile_options(onnx_test_runner PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /utf-8>"
+              "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/utf-8>")
+    endif()
+    if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
+      if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS)
+        set_target_properties(onnx_test_runner PROPERTIES LINK_FLAGS "-s NODERAWFS=1 -s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1")
+      else()
+        set_target_properties(onnx_test_runner PROPERTIES LINK_FLAGS "-s NODERAWFS=1 -s ALLOW_MEMORY_GROWTH=1")
+      endif()
+    endif()
 
-target_link_libraries(onnx_test_runner PRIVATE onnx_test_runner_common ${GETOPT_LIB_WIDE} ${onnx_test_libs} nlohmann_json::nlohmann_json)
-target_include_directories(onnx_test_runner PRIVATE ${ONNXRUNTIME_ROOT})
-if (onnxruntime_USE_ROCM)
-  target_include_directories(onnx_test_runner PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining)
-endif()
-if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
-  target_link_libraries(onnx_test_runner PRIVATE Python::Python)
-endif()
-set_target_properties(onnx_test_runner PROPERTIES FOLDER "ONNXRuntimeTest")
+    target_link_libraries(onnx_test_runner PRIVATE onnx_test_runner_common ${GETOPT_LIB_WIDE} ${onnx_test_libs} nlohmann_json::nlohmann_json)
+    target_include_directories(onnx_test_runner PRIVATE ${ONNXRUNTIME_ROOT})
+    if (onnxruntime_USE_ROCM)
+      target_include_directories(onnx_test_runner PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining)
+    endif()
+    if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
+      target_link_libraries(onnx_test_runner PRIVATE Python::Python)
+    endif()
+    set_target_properties(onnx_test_runner PROPERTIES FOLDER "ONNXRuntimeTest")
 
-if (onnxruntime_USE_TVM)
-  if (WIN32)
-    target_link_options(onnx_test_runner PRIVATE "/STACK:4000000")
-  endif()
-endif()
+    if (onnxruntime_USE_TVM)
+      if (WIN32)
+        target_link_options(onnx_test_runner PRIVATE "/STACK:4000000")
+      endif()
+    endif()
 
-install(TARGETS onnx_test_runner
-        ARCHIVE  DESTINATION ${CMAKE_INSTALL_LIBDIR}
-        LIBRARY  DESTINATION ${CMAKE_INSTALL_LIBDIR}
-        BUNDLE   DESTINATION ${CMAKE_INSTALL_LIBDIR}
-        RUNTIME  DESTINATION ${CMAKE_INSTALL_BINDIR})
+    install(TARGETS onnx_test_runner
+            ARCHIVE  DESTINATION ${CMAKE_INSTALL_LIBDIR}
+            LIBRARY  DESTINATION ${CMAKE_INSTALL_LIBDIR}
+            BUNDLE   DESTINATION ${CMAKE_INSTALL_LIBDIR}
+            RUNTIME  DESTINATION ${CMAKE_INSTALL_BINDIR})
+endif()
 
 if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
   if(onnxruntime_BUILD_BENCHMARKS)
@@ -1179,90 +1181,80 @@ endif()
 
 
 if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
-  #perf test runner
-  set(onnxruntime_perf_test_src_dir ${TEST_SRC_DIR}/perftest)
-  set(onnxruntime_perf_test_src_patterns
-  "${onnxruntime_perf_test_src_dir}/*.cc"
-  "${onnxruntime_perf_test_src_dir}/*.h")
+  if(NOT IOS)
+    #perf test runner
+    set(onnxruntime_perf_test_src_dir ${TEST_SRC_DIR}/perftest)
+    set(onnxruntime_perf_test_src_patterns
+    "${onnxruntime_perf_test_src_dir}/*.cc"
+    "${onnxruntime_perf_test_src_dir}/*.h")
 
-  if(WIN32)
-    list(APPEND onnxruntime_perf_test_src_patterns
-      "${onnxruntime_perf_test_src_dir}/windows/*.cc"
-      "${onnxruntime_perf_test_src_dir}/windows/*.h" )
-  else ()
-    list(APPEND onnxruntime_perf_test_src_patterns
-      "${onnxruntime_perf_test_src_dir}/posix/*.cc"
-      "${onnxruntime_perf_test_src_dir}/posix/*.h" )
-  endif()
+    if(WIN32)
+      list(APPEND onnxruntime_perf_test_src_patterns
+        "${onnxruntime_perf_test_src_dir}/windows/*.cc"
+        "${onnxruntime_perf_test_src_dir}/windows/*.h" )
+    else ()
+      list(APPEND onnxruntime_perf_test_src_patterns
+        "${onnxruntime_perf_test_src_dir}/posix/*.cc"
+        "${onnxruntime_perf_test_src_dir}/posix/*.h" )
+    endif()
 
-  file(GLOB onnxruntime_perf_test_src CONFIGURE_DEPENDS
-    ${onnxruntime_perf_test_src_patterns}
-    )
-  onnxruntime_add_executable(onnxruntime_perf_test ${onnxruntime_perf_test_src} ${ONNXRUNTIME_ROOT}/core/platform/path_lib.cc)
-  if(MSVC)
-    target_compile_options(onnxruntime_perf_test PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /utf-8>"
+    file(GLOB onnxruntime_perf_test_src CONFIGURE_DEPENDS
+      ${onnxruntime_perf_test_src_patterns}
+      )
+    onnxruntime_add_executable(onnxruntime_perf_test ${onnxruntime_perf_test_src} ${ONNXRUNTIME_ROOT}/core/platform/path_lib.cc)
+    if(MSVC)
+      target_compile_options(onnxruntime_perf_test PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /utf-8>"
             "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/utf-8>")
-  endif()
-  target_include_directories(onnxruntime_perf_test PRIVATE ${onnx_test_runner_src_dir} ${ONNXRUNTIME_ROOT}
+    endif()
+    target_include_directories(onnxruntime_perf_test PRIVATE ${onnx_test_runner_src_dir} ${ONNXRUNTIME_ROOT}
           ${eigen_INCLUDE_DIRS} ${onnxruntime_graph_header} ${onnxruntime_exec_src_dir}
           ${CMAKE_CURRENT_BINARY_DIR})
-  if (onnxruntime_USE_ROCM)
-    target_include_directories(onnxruntime_perf_test PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining)
-  endif()
-  if (WIN32)
-    target_compile_options(onnxruntime_perf_test PRIVATE ${disabled_warnings})
-    if (NOT DEFINED SYS_PATH_LIB)
-      set(SYS_PATH_LIB shlwapi)
+    if (onnxruntime_USE_ROCM)
+      target_include_directories(onnxruntime_perf_test PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining)
+    endif()
+    if (WIN32)
+      target_compile_options(onnxruntime_perf_test PRIVATE ${disabled_warnings})
+      if (NOT DEFINED SYS_PATH_LIB)
+        set(SYS_PATH_LIB shlwapi)
+      endif()
     endif()
-  endif()
-  if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS")
-    set_target_properties(onnxruntime_perf_test PROPERTIES
-      XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO"
-    )
-  endif()
 
-  if (onnxruntime_BUILD_SHARED_LIB)
-    #It will dynamically link to onnxruntime. So please don't add onxruntime_graph/onxruntime_framework/... here.
-    #onnxruntime_common is kind of ok because it is thin, tiny and totally stateless.
-    set(onnxruntime_perf_test_libs
+    if (onnxruntime_BUILD_SHARED_LIB)
+      # It will dynamically link to onnxruntime, so please don't add onnxruntime_graph/onnxruntime_framework/... here.
+      # onnxruntime_common is kind of OK because it is thin, tiny and totally stateless.
+      set(onnxruntime_perf_test_libs
             onnx_test_runner_common onnxruntime_test_utils onnxruntime_common
             onnxruntime onnxruntime_flatbuffers onnx_test_data_proto
             ${onnxruntime_EXTERNAL_LIBRARIES}
             ${GETOPT_LIB_WIDE} ${SYS_PATH_LIB} ${CMAKE_DL_LIBS})
-    if(NOT WIN32)
-      list(APPEND onnxruntime_perf_test_libs nsync::nsync_cpp)
-      if(onnxruntime_USE_SNPE)
-        list(APPEND onnxruntime_perf_test_libs onnxruntime_providers_snpe)
+      if(NOT WIN32)
+        list(APPEND onnxruntime_perf_test_libs nsync::nsync_cpp)
+        if(onnxruntime_USE_SNPE)
+          list(APPEND onnxruntime_perf_test_libs onnxruntime_providers_snpe)
+        endif()
       endif()
+      if (CMAKE_SYSTEM_NAME STREQUAL "Android")
+        list(APPEND onnxruntime_perf_test_libs ${android_shared_libs})
+      endif()
+      target_link_libraries(onnxruntime_perf_test PRIVATE ${onnxruntime_perf_test_libs} Threads::Threads)
+      if(WIN32)
+        target_link_libraries(onnxruntime_perf_test PRIVATE debug dbghelp advapi32)
+      endif()
+    else()
+      target_link_libraries(onnxruntime_perf_test PRIVATE onnx_test_runner_common ${GETOPT_LIB_WIDE} ${onnx_test_libs})
     endif()
-    if (CMAKE_SYSTEM_NAME STREQUAL "Android")
-      list(APPEND onnxruntime_perf_test_libs ${android_shared_libs})
-    endif()
-    target_link_libraries(onnxruntime_perf_test PRIVATE ${onnxruntime_perf_test_libs} Threads::Threads)
-    if(WIN32)
-      target_link_libraries(onnxruntime_perf_test PRIVATE debug dbghelp advapi32)
-    endif()
-    if(tensorflow_C_PACKAGE_PATH)
-      target_include_directories(onnxruntime_perf_test PRIVATE ${tensorflow_C_PACKAGE_PATH}/include)
-      target_link_directories(onnxruntime_perf_test PRIVATE ${tensorflow_C_PACKAGE_PATH}/lib)
-      target_link_libraries(onnxruntime_perf_test PRIVATE tensorflow)
-      target_compile_definitions(onnxruntime_perf_test PRIVATE HAVE_TENSORFLOW)
-    endif()
-  else()
-    target_link_libraries(onnxruntime_perf_test PRIVATE onnx_test_runner_common ${GETOPT_LIB_WIDE} ${onnx_test_libs})
-  endif()
-  set_target_properties(onnxruntime_perf_test PROPERTIES FOLDER "ONNXRuntimeTest")
+    set_target_properties(onnxruntime_perf_test PROPERTIES FOLDER "ONNXRuntimeTest")
 
-  if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS AND NOT onnxruntime_BUILD_SHARED_LIB)
-    target_link_libraries(onnxruntime_perf_test PRIVATE onnxruntime_language_interop onnxruntime_pyop)
-  endif()
+    if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS AND NOT onnxruntime_BUILD_SHARED_LIB)
+      target_link_libraries(onnxruntime_perf_test PRIVATE onnxruntime_language_interop onnxruntime_pyop)
+    endif()
 
-  if (onnxruntime_USE_TVM)
-    if (WIN32)
-      target_link_options(onnxruntime_perf_test PRIVATE "/STACK:4000000")
+    if (onnxruntime_USE_TVM)
+      if (WIN32)
+        target_link_options(onnxruntime_perf_test PRIVATE "/STACK:4000000")
+      endif()
     endif()
   endif()
-
   # shared lib
   if (onnxruntime_BUILD_SHARED_LIB)
     onnxruntime_add_static_library(onnxruntime_mocked_allocator ${TEST_SRC_DIR}/util/test_allocator.cc)
@@ -1283,7 +1275,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
       list(APPEND onnxruntime_shared_lib_test_LIBS cpuinfo)
     endif()
     if (onnxruntime_USE_CUDA)
-      list(APPEND onnxruntime_shared_lib_test_LIBS cudart)
+      list(APPEND onnxruntime_shared_lib_test_LIBS CUDA::cudart)
     endif()
     if (onnxruntime_USE_ROCM)
       list(APPEND onnxruntime_shared_lib_test_LIBS hip::host)
@@ -1317,7 +1309,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
       target_compile_definitions(onnxruntime_shared_lib_test PRIVATE USE_DUMMY_EXA_DEMANGLE=1)
     endif()
 
-    if (CMAKE_SYSTEM_NAME STREQUAL "iOS")
+    if (IOS)
       add_custom_command(
         TARGET onnxruntime_shared_lib_test POST_BUILD
         COMMAND ${CMAKE_COMMAND} -E copy_directory
@@ -1404,7 +1396,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
       target_compile_options(onnxruntime_mlas_test PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd26426>"
                   "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd26426>")
     endif()
-    if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS")
+    if(IOS)
       set_target_properties(onnxruntime_mlas_test PROPERTIES
         XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO"
       )
@@ -1605,7 +1597,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
             DEPENDS ${all_dependencies}
     )
 
-    if (CMAKE_SYSTEM_NAME STREQUAL "iOS")
+    if (IOS)
       add_custom_command(
         TARGET onnxruntime_customopregistration_test POST_BUILD
         COMMAND ${CMAKE_COMMAND} -E copy_directory
diff --git a/cmake/patches/cpuinfo/9bb12d342fd9479679d505d93a478a6f9cd50a47.patch b/cmake/patches/cpuinfo/9bb12d342fd9479679d505d93a478a6f9cd50a47.patch
new file mode 100644
index 0000000000000..afb19a45ce0f4
--- /dev/null
+++ b/cmake/patches/cpuinfo/9bb12d342fd9479679d505d93a478a6f9cd50a47.patch
@@ -0,0 +1,22 @@
+diff --git a/include/cpuinfo.h b/include/cpuinfo.h
+index c46b65e..8b83a64 100644
+--- a/include/cpuinfo.h
++++ b/include/cpuinfo.h
+@@ -18,7 +18,7 @@
+ 	#define CPUINFO_ARCH_X86 1
+ #endif
+ 
+-#if defined(__x86_64__) || defined(__x86_64) || defined(_M_X64) || defined(_M_AMD64)
++#if defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)) || (defined(_M_AMD64) && !defined(_M_ARM64EC))
+ 	#define CPUINFO_ARCH_X86_64 1
+ #endif
+ 
+@@ -26,7 +26,7 @@
+ 	#define CPUINFO_ARCH_ARM 1
+ #endif
+ 
+-#if defined(__aarch64__) || defined(_M_ARM64)
++#if defined(__aarch64__) || defined(_M_ARM64) || defined(_M_ARM64EC)
+ 	#define CPUINFO_ARCH_ARM64 1
+ #endif
+ 
diff --git a/cmake/patches/flatbuffers/flatbuffers.patch b/cmake/patches/flatbuffers/flatbuffers.patch
index f141d358c54b6..fbe8db37ecb0e 100644
--- a/cmake/patches/flatbuffers/flatbuffers.patch
+++ b/cmake/patches/flatbuffers/flatbuffers.patch
@@ -2,35 +2,11 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt
 index 3987eac9..5e5462f1 100644
 --- a/CMakeLists.txt
 +++ b/CMakeLists.txt
-@@ -223,7 +223,7 @@ elseif(CMAKE_COMPILER_IS_GNUCXX)
-       "${CMAKE_CXX_FLAGS} -std=c++0x")
-   endif(CYGWIN)
-   set(CMAKE_CXX_FLAGS
--    "${CMAKE_CXX_FLAGS} -Wall -pedantic -Werror -Wextra -Werror=shadow")
-+    "${CMAKE_CXX_FLAGS} -Wall -pedantic -Wextra -Werror=shadow -Wno-error=stringop-overflow")
-   set(FLATBUFFERS_PRIVATE_CXX_FLAGS "-Wold-style-cast")
-   if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.4)
-     if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
-diff --git a/src/idl_gen_rust.cpp b/src/idl_gen_rust.cpp
-index 55b8439b..dc03e8a8 100644
---- a/src/idl_gen_rust.cpp
-+++ b/src/idl_gen_rust.cpp
-@@ -406,7 +406,8 @@ class RustGenerator : public BaseGenerator {
-     // example: f(A, D::E)          -> super::D::E
-     // does not include leaf object (typically a struct type).
- 
--    size_t i = 0;
-+    // fix unused but set variable warning
-+    //size_t i = 0;
-     std::stringstream stream;
- 
-     auto s = src->components.begin();
-@@ -417,7 +418,7 @@ class RustGenerator : public BaseGenerator {
-       if (*s != *d) { break; }
-       ++s;
-       ++d;
--      ++i;
-+      //++i;
-     }
- 
-     for (; s != src->components.end(); ++s) { stream << "super::"; }
+@@ -279,5 +279,5 @@
+ # Append FLATBUFFERS_CXX_FLAGS to CMAKE_CXX_FLAGS.
+ if(DEFINED FLATBUFFERS_CXX_FLAGS)
+   message(STATUS "extend CXX_FLAGS with ${FLATBUFFERS_CXX_FLAGS}")
+-  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${FLATBUFFERS_CXX_FLAGS}")
++  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${FLATBUFFERS_CXX_FLAGS} -Wno-error=stringop-overflow")
+ endif()
+ message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
diff --git a/cmake/patches/neural_speed/150e7527d5286ddd3a995c228dedf8d76a7a86bc.patch b/cmake/patches/neural_speed/150e7527d5286ddd3a995c228dedf8d76a7a86bc.patch
new file mode 100644
index 0000000000000..e503a512a74ff
--- /dev/null
+++ b/cmake/patches/neural_speed/150e7527d5286ddd3a995c228dedf8d76a7a86bc.patch
@@ -0,0 +1,30 @@
+diff --git a/bestla/bestla/bestla_prologue_b.h b/bestla/bestla/bestla_prologue_b.h
+index 99f3ccc..a11de9d 100644
+--- a/bestla/bestla/bestla_prologue_b.h
++++ b/bestla/bestla/bestla_prologue_b.h
+@@ -456,9 +456,8 @@ class WeightKBlockNInteger {
+     auto tmpscales = tmp;
+     auto tmpzeropoints = reinterpret_cast<int8_t*>(tmpscales + N * blks);
+     if (scales) {
+-      for (size_t i = 0; i < N * blks; i += 2) {
++      for (size_t i = 0; i < N * blks; i ++) {
+         tmpscales[i] = scales[i] / 16;
+-        tmpscales[i + 1] = scales[i + 1] / 16;
+       }
+     }
+     if (zero_points) {
+diff --git a/bestla/bestla/kernel_avx512f.h b/bestla/bestla/kernel_avx512f.h
+index 6783ee8..59822e5 100644
+--- a/bestla/bestla/kernel_avx512f.h
++++ b/bestla/bestla/kernel_avx512f.h
+@@ -673,8 +673,8 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8
+     zmm1 = _mm512_sllv_epi32(zmm1, zmm_shift);  // int3_clip => int8
+     zmm2 = _mm512_sllv_epi32(zmm2, zmm_shift);  // int3_clip => int8
+
+-    _mm512_storeu_epi8((__m512i*)dst, zmm1);
+-    _mm512_storeu_epi8((__m512i*)(dst + 64), zmm2);
++    _mm512_storeu_si512((__m512i*)dst, zmm1);
++    _mm512_storeu_si512((__m512i*)(dst + 64), zmm2);
+   };
+
+   assert(head_ignore_num % 8 == 0);
diff --git a/cmake/wcos_rules_override.cmake b/cmake/wcos_rules_override.cmake
index f3d8093629a42..ec2303b073d5e 100644
--- a/cmake/wcos_rules_override.cmake
+++ b/cmake/wcos_rules_override.cmake
@@ -1,2 +1,2 @@
-set(CMAKE_C_STANDARD_LIBRARIES_INIT onecoreuap_apiset.lib)
-set(CMAKE_CXX_STANDARD_LIBRARIES_INIT onecoreuap_apiset.lib)
+set(CMAKE_C_STANDARD_LIBRARIES_INIT onecoreuap.lib)
+set(CMAKE_CXX_STANDARD_LIBRARIES_INIT onecoreuap.lib)
diff --git a/cmake/winml.cmake b/cmake/winml.cmake
index 268ee3960e75a..d74250b962628 100644
--- a/cmake/winml.cmake
+++ b/cmake/winml.cmake
@@ -836,6 +836,13 @@ if (winml_is_inbox)
     target_include_directories(${new_target} PRIVATE ${include_directories})
     target_link_libraries(${new_target} PRIVATE ${link_libraries})
     target_link_options(${new_target} PRIVATE ${link_options})
+
+    # Attempt to copy linker flags
+    get_target_property(link_flags ${target} LINK_FLAGS)
+
+    if (NOT link_flags MATCHES ".*NOTFOUND")
+      set_property(TARGET ${new_target} PROPERTY LINK_FLAGS "${link_flags}")
+    endif()
   endfunction()
 
   if (WAI_ARCH STREQUAL x64 OR WAI_ARCH STREQUAL arm64)
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
index 4128524b30483..8a8426a0b3054 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
@@ -362,6 +362,7 @@ static NativeMethods()
             OrtDisableMemPattern = (DOrtDisableMemPattern)Marshal.GetDelegateForFunctionPointer(api_.DisableMemPattern, typeof(DOrtDisableMemPattern));
             OrtEnableCpuMemArena = (DOrtEnableCpuMemArena)Marshal.GetDelegateForFunctionPointer(api_.EnableCpuMemArena, typeof(DOrtEnableCpuMemArena));
             OrtDisableCpuMemArena = (DOrtDisableCpuMemArena)Marshal.GetDelegateForFunctionPointer(api_.DisableCpuMemArena, typeof(DOrtDisableCpuMemArena));
+            OrtDisablePerSessionThreads = (DOrtDisablePerSessionThreads)Marshal.GetDelegateForFunctionPointer(api_.DisablePerSessionThreads, typeof(DOrtDisablePerSessionThreads));
             OrtSetSessionLogId = (DOrtSetSessionLogId)Marshal.GetDelegateForFunctionPointer(api_.SetSessionLogId, typeof(DOrtSetSessionLogId));
             OrtSetSessionLogVerbosityLevel = (DOrtSetSessionLogVerbosityLevel)Marshal.GetDelegateForFunctionPointer(api_.SetSessionLogVerbosityLevel, typeof(DOrtSetSessionLogVerbosityLevel));
             OrtSetSessionLogSeverityLevel = (DOrtSetSessionLogSeverityLevel)Marshal.GetDelegateForFunctionPointer(api_.SetSessionLogSeverityLevel, typeof(DOrtSetSessionLogSeverityLevel));
@@ -992,6 +993,10 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
         public delegate IntPtr /*(OrtStatus*)*/ DOrtDisableCpuMemArena(IntPtr /* OrtSessionOptions* */ options);
         public static DOrtDisableCpuMemArena OrtDisableCpuMemArena;
 
+        [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+        public delegate IntPtr /*(OrtStatus*)*/ DOrtDisablePerSessionThreads(IntPtr /* OrtSessionOptions* */ options);
+        public static DOrtDisablePerSessionThreads OrtDisablePerSessionThreads;
+
         [UnmanagedFunctionPointer(CallingConvention.Winapi)]
         public delegate IntPtr /*(OrtStatus*)*/ DOrtSetSessionLogId(IntPtr /* OrtSessionOptions* */ options, byte[] /* const char* */ logId);
         public static DOrtSetSessionLogId OrtSetSessionLogId;
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
index 7a68246c9b67a..30d005b3c4236 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
@@ -696,6 +696,15 @@ public bool EnableCpuMemArena
         }
         private bool _enableCpuMemArena = true;
 
+        /// <summary>
+        /// Disables the per-session thread pools (they are enabled by default).
+        /// This makes all sessions in the process use the global thread pool.
+        /// </summary>
+        public void DisablePerSessionThreads()
+        {
+            NativeApiStatus.VerifySuccess(NativeMethods.OrtDisablePerSessionThreads(handle));
+        }
+
         /// <summary>
         /// Log Id to be used for the session. Default is empty string.
         /// </summary>
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs
index fd8feda359f90..d6a6b9627f418 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs
@@ -55,6 +55,9 @@ public void TestSessionOptions()
                 Assert.Equal(0, opt.InterOpNumThreads);
                 Assert.Equal(GraphOptimizationLevel.ORT_ENABLE_ALL, opt.GraphOptimizationLevel);
 
+                // No get, so no verify
+                opt.DisablePerSessionThreads();
+
                 // try setting options
                 opt.ExecutionMode = ExecutionMode.ORT_PARALLEL;
                 Assert.Equal(ExecutionMode.ORT_PARALLEL, opt.ExecutionMode);
@@ -98,7 +101,7 @@ public void TestSessionOptions()
                 Assert.Contains("[ErrorCode:InvalidArgument] Config key is empty", ex.Message);
 
                 // SessionOptions.RegisterOrtExtensions can be manually tested by referencing the
-                // Microsoft.ML.OnnxRuntime.Extensions nuget package. After that is done, this should not throw.                
+                // Microsoft.ML.OnnxRuntime.Extensions nuget package. After that is done, this should not throw.
                 ex = Assert.Throws<OnnxRuntimeException>(() => { opt.RegisterOrtExtensions(); });
                 Assert.Contains("Microsoft.ML.OnnxRuntime.Extensions NuGet package must be referenced", ex.Message);
 
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs
index 715aed7e1d64f..7f3d5d6624b07 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs
@@ -145,7 +145,7 @@ private void TestCUDAProviderOptions()
         private void CanRunInferenceOnAModelWithTensorRT()
         {
             string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx");
-            
+
             int deviceId = 0;
             string deviceIdStr = System.Environment.GetEnvironmentVariable("ONNXRUNTIME_TEST_GPU_DEVICE_ID");
             if (!string.IsNullOrEmpty(deviceIdStr) && int.TryParse(deviceIdStr, out int parsedValue) && parsedValue >= 0)
diff --git a/csharp/tools/MauiModelTester/MauiModelTester.csproj b/csharp/tools/MauiModelTester/MauiModelTester.csproj
index a374c2933ce8f..39e688ce6c1b8 100644
--- a/csharp/tools/MauiModelTester/MauiModelTester.csproj
+++ b/csharp/tools/MauiModelTester/MauiModelTester.csproj
@@ -1,8 +1,8 @@
 <Project Sdk="Microsoft.NET.Sdk">
 
 	<PropertyGroup>
-		<TargetFrameworks>net6.0-android;net6.0-ios</TargetFrameworks>
-		<TargetFrameworks Condition="$([MSBuild]::IsOSPlatform('windows'))">$(TargetFrameworks);net6.0-windows10.0.19041.0</TargetFrameworks>
+		<TargetFrameworks>net8.0-ios;net8.0-android34.0</TargetFrameworks>
+		<TargetFrameworks Condition="$([MSBuild]::IsOSPlatform('windows'))">$(TargetFrameworks);net8.0-windows10.0.19041.0</TargetFrameworks>
 		<OutputType>Exe</OutputType>
 		<RootNamespace>MauiModelTester</RootNamespace>
 		<UseMaui>true</UseMaui>
@@ -21,7 +21,7 @@
 		<ApplicationVersion>1</ApplicationVersion>
 
 		<SupportedOSPlatformVersion Condition="$([MSBuild]::GetTargetPlatformIdentifier('$(TargetFramework)')) == 'ios'">12.0</SupportedOSPlatformVersion>
-		<SupportedOSPlatformVersion Condition="$([MSBuild]::GetTargetPlatformIdentifier('$(TargetFramework)')) == 'android'">21.0</SupportedOSPlatformVersion>
+		<SupportedOSPlatformVersion Condition="$([MSBuild]::GetTargetPlatformIdentifier('$(TargetFramework)')) == 'android'">29.0</SupportedOSPlatformVersion>
 		<SupportedOSPlatformVersion Condition="$([MSBuild]::GetTargetPlatformIdentifier('$(TargetFramework)')) == 'windows'">10.0.17763.0</SupportedOSPlatformVersion>
 		<TargetPlatformMinVersion Condition="$([MSBuild]::GetTargetPlatformIdentifier('$(TargetFramework)')) == 'windows'">10.0.17763.0</TargetPlatformMinVersion>
 		<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
diff --git a/csharp/tools/MauiModelTester/Platforms/Android/AndroidManifest.xml b/csharp/tools/MauiModelTester/Platforms/Android/AndroidManifest.xml
index cc320dab474a0..2ef2296d7441f 100644
--- a/csharp/tools/MauiModelTester/Platforms/Android/AndroidManifest.xml
+++ b/csharp/tools/MauiModelTester/Platforms/Android/AndroidManifest.xml
@@ -4,5 +4,5 @@
 	<uses-permission android:name="android.permission.ACCESS_NETWORK_STATE" />
 	<uses-permission android:name="android.permission.INTERNET" />
 	<uses-permission android:name="android.permission.DIAGNOSTIC" />
-	<uses-sdk android:minSdkVersion="21" android:targetSdkVersion="31" />
+	<uses-sdk android:minSdkVersion="29" android:targetSdkVersion="34" />
 </manifest>
\ No newline at end of file
diff --git a/dockerfiles/Dockerfile.migraphx b/dockerfiles/Dockerfile.migraphx
index bc513a8e8ba6d..c3541a8bd3425 100644
--- a/dockerfiles/Dockerfile.migraphx
+++ b/dockerfiles/Dockerfile.migraphx
@@ -5,57 +5,22 @@
 # Dockerfile to run ONNXRuntime with MIGraphX integration
 #--------------------------------------------------------------------------
 
-FROM ubuntu:20.04
+FROM rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1
 
 ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
 ARG ONNXRUNTIME_BRANCH=main
-ARG ROCM_VERSION=5.4
-# MIGraphX version should be the same as ROCm version
-ARG MIGRAPHX_VERSION=rocm-5.4.0
-ENV DEBIAN_FRONTEND noninteractive
-ENV MIGRAPHX_DISABLE_FAST_GELU=1
 
-RUN apt-get clean && apt-get update && apt-get install -y locales
-RUN locale-gen en_US.UTF-8
-RUN update-locale LANG=en_US.UTF-8
-ENV LC_ALL C.UTF-8
-ENV LANG C.UTF-8
+ENV PATH /code/cmake-3.27.3-linux-x86_64/bin:${PATH}
 
-# Install rocm
-RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl && \
-  curl -sL http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \
-  sh -c 'echo deb [arch=amd64] http://repo.radeon.com/rocm/apt/${ROCM_VERSION}/ ubuntu main > /etc/apt/sources.list.d/rocm.list'
-
-RUN apt-get update &&\
-    apt-get install -y sudo git bash build-essential rocm-dev python3-dev python3-pip miopen-hip \
-    rocblas half aria2 libnuma-dev pkg-config
-
-RUN aria2c -q -d /tmp -o cmake-3.27.3-linux-x86_64.tar.gz \
-https://github.com/Kitware/CMake/releases/download/v3.27.3/cmake-3.27.3-linux-x86_64.tar.gz &&\
-tar -zxf /tmp/cmake-3.27.3-linux-x86_64.tar.gz --strip=1 -C /usr
-
-# Install rbuild
-RUN pip3 install https://github.com/RadeonOpenCompute/rbuild/archive/master.tar.gz numpy yapf==0.28.0
-
-ENV PATH /opt/miniconda/bin:/code/cmake-3.27.3-linux-x86_64/bin:${PATH}
-
-# Install MIGraphX from source
-RUN mkdir -p /migraphx
-RUN cd /migraphx && git clone --depth=1 --branch ${MIGRAPHX_VERSION} https://github.com/ROCmSoftwarePlatform/AMDMIGraphX src
-RUN cd /migraphx && rbuild package --cxx /opt/rocm/llvm/bin/clang++ -d /migraphx/deps -B /migraphx/build -S /migraphx/src/ -DPYTHON_EXECUTABLE=/usr/bin/python3
-RUN dpkg -i /migraphx/build/*.deb
-RUN rm -rf /migraphx
-
-# Install rocm ep dependencies
 RUN apt-get update &&\
-    apt-get install -y rocrand rccl hipsparse hipfft hipcub hipblas rocthrust
+    apt-get install -y migraphx
 
 WORKDIR /code
 
 # Prepare onnxruntime repository & build onnxruntime
 RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\
     /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh &&\
-    cd onnxruntime  &&\
+    cd onnxruntime  && pip install --upgrade pip &&\
     /bin/sh ./build.sh --allow_running_as_root --cmake_extra_defines ONNXRUNTIME_VERSION=`cat ./VERSION_NUMBER` --config Release --parallel \
             --skip_tests --build_wheel --use_rocm --rocm_version=${ROCM_VERSION} --rocm_home /opt/rocm --use_migraphx &&\
     pip install /code/onnxruntime/build/Linux/Release/dist/*.whl
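
Once the image above is built, the wheel installed in the container exposes the MIGraphX execution provider. A minimal usage sketch, not part of this change (the model path, input name, and shape are placeholders):

```python
import numpy as np
import onnxruntime as ort

# Confirm the EP registered by the wheel built above, then run with MIGraphX,
# falling back to CPU. "model.onnx" and the input shape are placeholders.
print(ort.get_available_providers())

sess = ort.InferenceSession(
    "model.onnx",
    providers=["MIGraphXExecutionProvider", "CPUExecutionProvider"],
)
x = np.random.rand(1, 3, 224, 224).astype(np.float32)
outputs = sess.run(None, {sess.get_inputs()[0].name: x})
print([o.shape for o in outputs])
```
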
diff --git a/dockerfiles/Dockerfile.openvino b/dockerfiles/Dockerfile.openvino
index 78d04a51ba162..049916fac92f1 100644
--- a/dockerfiles/Dockerfile.openvino
+++ b/dockerfiles/Dockerfile.openvino
@@ -1,9 +1,9 @@
 #-------------------------------------------------------------------------
-# Copyright(C) 2021-2023 Intel Corporation.
+# Copyright(C) 2021-2024 Intel Corporation.
 # SPDX-License-Identifier: MIT
 #--------------------------------------------------------------------------
 
-ARG OPENVINO_VERSION=2023.0.0
+ARG OPENVINO_VERSION=2024.0.0
 
 
 # Build stage
@@ -17,7 +17,7 @@ ARG DEVICE=CPU_FP32
 ARG ONNXRUNTIME_REPO=https://github.com/microsoft/onnxruntime.git
 ARG ONNXRUNTIME_BRANCH=main
 
-ENV InferenceEngine_DIR=${INTEL_OPENVINO_DIR}/runtime/cmake
+ENV OpenVINO_DIR=${INTEL_OPENVINO_DIR}/runtime/cmake
 
 USER root
 RUN apt update; apt install -y git protobuf-compiler libprotobuf-dev
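
With the switch to `OpenVINO_DIR` and OpenVINO 2024.0.0, the resulting build is still consumed the same way from Python. A hedged sketch of selecting the device through provider options (the accepted `device_type` values depend on the OpenVINO EP version, and the model path is a placeholder):

```python
import onnxruntime as ort

# OpenVINO EP with an explicit device choice; CPU EP kept as fallback.
sess = ort.InferenceSession(
    "model.onnx",  # placeholder
    providers=["OpenVINOExecutionProvider", "CPUExecutionProvider"],
    provider_options=[{"device_type": "CPU_FP32"}, {}],  # value mirrors ARG DEVICE above
)
print(sess.get_providers())
```
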
diff --git a/dockerfiles/Dockerfile.openvino-centos7 b/dockerfiles/Dockerfile.openvino-centos7
deleted file mode 100755
index 697db44801e3b..0000000000000
--- a/dockerfiles/Dockerfile.openvino-centos7
+++ /dev/null
@@ -1,105 +0,0 @@
-#-------------------------------------------------------------------------
-# Copyright(C) 2021 Intel Corporation.
-# SPDX-License-Identifier: MIT
-#--------------------------------------------------------------------------
-
-FROM centos:7.8.2003
-
-WORKDIR /code
-
-ARG MY_ROOT=/code
-ARG YUM_OV_PACKAGE=intel-openvino-runtime-centos7-2021.4.752.x86_64
-ARG DEVICE=CPU_FP32
-ARG ONNXRUNTIME_REPO=https://github.com/microsoft/onnxruntime
-ARG ONNXRUNTIME_BRANCH=main
-
-ENV INTEL_OPENVINO_DIR=/opt/intel/openvino_2021.4.752
-ENV InferenceEngine_DIR=${INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/share
-ENV IE_PLUGINS_PATH=${INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/lib/intel64
-ENV ngraph_DIR=${INTEL_OPENVINO_DIR}/deployment_tools/ngraph/cmake
-ENV LD_LIBRARY_PATH=/opt/intel/opencl:${INTEL_OPENVINO_DIR}/inference_engine/external/gna/lib:${INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/external/mkltiny_lnx/lib:$INTEL_OPENVINO_DIR/deployment_tools/ngraph/lib:${INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/external/omp/lib:${INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/external/tbb/lib:${IE_PLUGINS_PATH}:${LD_LIBRARY_PATH}
-ENV OpenCV_DIR=${INTEL_OPENVINO_DIR}/opencv/share/OpenCV
-ENV LD_LIBRARY_PATH=${INTEL_OPENVINO_DIR}/opencv/lib:${INTEL_OPENVINO_DIR}/opencv/share/OpenCV/3rdparty/lib:${LD_LIBRARY_PATH}
-ENV HDDL_INSTALL_DIR=${INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/external/hddl
-ENV LD_LIBRARY_PATH=${INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/external/hddl/lib:$LD_LIBRARY_PATH
-ENV LD_LIBRARY_PATH=/usr/local/lib:/usr/lib:/usr/local/lib64:/usr/lib64:/lib64:$LD_LIBRARY_PATH
-
-# Install packages
-RUN yum update -y && \
-    yum groupinstall "Development Tools" -y && \
-    yum install -y yum-utils autoconf automake libtool unzip udev wget zlib-devel libffi-devel openssl-devel boost-devel-1.53.0 && \
-    yum clean packages &&  yum clean all && rm -rf /var/cache/yum && \
-# Install cmake
-    cd $MY_ROOT && \
-    wget https://github.com/Kitware/CMake/releases/download/v3.27.3/cmake-3.27.3.tar.gz && \
-    tar -zxvf cmake-3.27.3.tar.gz && rm -rf cmake-3.27.3.tar.gz && \
-    cd cmake-3.27.3 && \
-    ./bootstrap && \
-    make && \
-    make install && \
-    cd $MY_ROOT && \
-# libusb1.0.22
-    cd /opt/ && wget https://github.com/libusb/libusb/archive/v1.0.22.zip && \
-    unzip v1.0.22.zip && rm -rf v1.0.22.zip && cd  /opt/libusb-1.0.22 && \
-# bootstrap steps
-    ./bootstrap.sh && \
-    ./configure --disable-udev --enable-shared && \
-    make -j4 && \
-# configure libusb1.0.22
-    cd /opt/libusb-1.0.22/libusb && \
-    /bin/mkdir -p '/usr/local/lib' && \
-    /bin/bash ../libtool   --mode=install /usr/bin/install -c   libusb-1.0.la '/usr/local/lib' && \
-    /bin/mkdir -p '/usr/local/include/libusb-1.0' && \
-    /usr/bin/install -c -m 644 libusb.h '/usr/local/include/libusb-1.0' && \
-    /bin/mkdir -p '/usr/local/lib/pkgconfig' && \
-# Install openvino
-    yum-config-manager --add-repo https://yum.repos.intel.com/openvino/2021/setup/intel-openvino-2021.repo && \
-    rpm --import https://yum.repos.intel.com/openvino/2021/setup/RPM-GPG-KEY-INTEL-OPENVINO-2021 && \
-    yum update -y && yum list intel-openvino* && \
-    yum install -y $YUM_OV_PACKAGE && \
-    cd ${INTEL_OPENVINO_DIR}/install_dependencies/ && ./install_openvino_dependencies.sh -y && \
-    printf "\nexport LD_LIBRARY_PATH=\${LD_LIBRARY_PATH}:/usr/local/lib\n" >> /opt/intel/openvino_2021.4.752/bin/setupvars.sh && \
-    cd /opt/libusb-1.0.22 && \
-    /usr/bin/install -c -m 644 libusb-1.0.pc '/usr/local/lib/pkgconfig' && \
-    cp /opt/intel/openvino_2021/deployment_tools/inference_engine/external/97-myriad-usbboot.rules /etc/udev/rules.d/ && \
-    ldconfig && \
-# Install GPU runtime and drivers
-    cd ${MY_ROOT} && \
-    mkdir /tmp/opencl && \
-    cd /tmp/opencl && \
-    yum install -y epel-release && \
-    yum install -y ocl-icd ocl-icd-devel && \
-    wget -O intel-igc-core-1.0.2597-1.el7.x86_64.rpm https://sourceforge.net/projects/intel-compute-runtime/files/19.41.14441/centos-7/intel-igc-core-1.0.2597-1.el7.x86_64.rpm/download && \
-    wget -O intel-opencl-19.41.14441-1.el7.x86_64.rpm https://sourceforge.net/projects/intel-compute-runtime/files/19.41.14441/centos-7/intel-opencl-19.41.14441-1.el7.x86_64.rpm/download && \
-    wget -O intel-igc-opencl-devel-1.0.2597-1.el7.x86_64.rpm https://sourceforge.net/projects/intel-compute-runtime/files/19.41.14441/centos-7/intel-igc-opencl-devel-1.0.2597-1.el7.x86_64.rpm/download && \
-    wget -O intel-igc-opencl-1.0.2597-1.el7.x86_64.rpm https://sourceforge.net/projects/intel-compute-runtime/files/19.41.14441/centos-7/intel-igc-opencl-1.0.2597-1.el7.x86_64.rpm/download && \
-    wget -O intel-gmmlib-19.3.2-1.el7.x86_64.rpm https://sourceforge.net/projects/intel-compute-runtime/files/19.41.14441/centos-7/intel-gmmlib-19.3.2-1.el7.x86_64.rpm/download && \
-    wget -O intel-gmmlib-devel-19.3.2-1.el7.x86_64.rpm https://sourceforge.net/projects/intel-compute-runtime/files/19.41.14441/centos-7/intel-gmmlib-devel-19.3.2-1.el7.x86_64.rpm/download && \
-    rpm -i /tmp/opencl/*.rpm  && \
-    ldconfig  && \
-    rm -rf /tmp/opencl && \
-# Installing gcc-10
-    yum install -y centos-release-scl && \
-    yum install -y devtoolset-10-gcc* && \
-    echo 'source scl_source enable devtoolset-10' >> ~/.bashrc && \
-# python installation
-    source scl_source enable devtoolset-10 && \
-    cd /code/ && \
-    wget https://www.python.org/ftp/python/3.8.3/Python-3.8.3.tgz && tar xvf Python-3.8.3.tgz && \
-    cd Python-3.8*/ && ./configure && make && make install && \
-    cd ../ &&  mkdir -p /usr/bin/Python38 && ln -s Python-3.8.3/ /usr/bin/Python38 && \
-# installing dependancies
-    yum install -y python3-lxml python3-six libusb.x86_64 && \
-    yum clean packages &&  yum clean all && rm -rf /var/cache/yum && \
-# Build onnxruntime
-    cd $MY_ROOT && \
-    pip3 install numpy wheel setuptools cython && \
-    git clone --recursive -b ${ONNXRUNTIME_BRANCH} ${ONNXRUNTIME_REPO} && \
-    pip3 install onnx && \
-    cd /code/onnxruntime && ./build.sh --allow_running_as_root --config Release --update --build --parallel --use_openvino ${DEVICE} --build_shared_lib --build_wheel && \
-    pip3 install /code/onnxruntime/build/Linux/Release/dist/*-linux_x86_64.whl && \
-# Clean up
-    cd  $MY_ROOT && rm -rf onnxruntime Python-3* && \
-    cd ${MY_ROOT}/ && rm -rf cmake* && \
-    cd /usr/share/ && rm -rf gcc* && cd /usr/lib/ && rm -rf gcc cd && rm -rf .cache && \
-    cd ${INTEL_OPENVINO_DIR}/ && rm -rf documentation data_processing && cd deployment_tools/ && rm -rf tools
diff --git a/dockerfiles/Dockerfile.openvino-csharp b/dockerfiles/Dockerfile.openvino-csharp
deleted file mode 100644
index 2529ef4b73209..0000000000000
--- a/dockerfiles/Dockerfile.openvino-csharp
+++ /dev/null
@@ -1,90 +0,0 @@
-#-------------------------------------------------------------------------
-# Copyright(C) 2021-2023 Intel Corporation.
-# SPDX-License-Identifier: MIT
-#--------------------------------------------------------------------------
-
-ARG OPENVINO_VERSION=2023.0.0
-
-# Build stage
-FROM openvino/ubuntu20_runtime:${OPENVINO_VERSION} AS base
-
-ENV WORKDIR_PATH=/home/openvino
-WORKDIR $WORKDIR_PATH
-ENV DEBIAN_FRONTEND noninteractive
-
-USER root
-RUN apt update; apt install -y --no-install-recommends wget gnupg && \
-    rm -rf /var/lib/apt/lists/*
-
-# Install Mono
-RUN wget http://download.mono-project.com/repo/xamarin.gpg && apt-key add xamarin.gpg && rm xamarin.gpg && \
-    echo "deb https://download.mono-project.com/repo/ubuntu stable-bionic main" | tee /etc/apt/sources.list.d/mono-official-stable.list && \
-    apt update -y && \
-    apt install -y mono-devel
-
-# Install nuget.exe
-RUN wget https://dist.nuget.org/win-x86-commandline/latest/nuget.exe && \
-    mv nuget.exe /usr/local/bin/nuget.exe && \
-    echo 'mono /usr/local/bin/nuget.exe $@' > /usr/local/bin/nuget && \
-    chmod a+x /usr/local/bin/nuget
-
-# Install .NET core
-RUN wget https://packages.microsoft.com/config/ubuntu/20.04/packages-microsoft-prod.deb -O packages-microsoft-prod.deb && \
-    dpkg -i packages-microsoft-prod.deb && \
-    apt-get update -y &&\
-    apt-get install -y apt-transport-https && \
-    apt-get update -y && \
-    apt-get install -y dotnet-sdk-5.0
-
-# Build stage
-FROM base AS builder
-
-ENV WORKDIR_PATH=/home/openvino
-WORKDIR $WORKDIR_PATH
-ENV DEBIAN_FRONTEND noninteractive
-
-ARG DEVICE=CPU_FP32
-ARG ONNXRUNTIME_REPO=https://github.com/microsoft/onnxruntime.git
-ARG ONNXRUNTIME_BRANCH=main
-
-ENV InferenceEngine_DIR=${INTEL_OPENVINO_DIR}/runtime/cmake
-ENV LANG en_US.UTF-8
-
-USER root
-RUN apt update; apt install -y --no-install-recommends git protobuf-compiler libprotobuf-dev ca-certificates unattended-upgrades && \
-    unattended-upgrade && \
-    rm -rf /var/lib/apt/lists/*
-
-RUN git clone --recursive -b ${ONNXRUNTIME_BRANCH} ${ONNXRUNTIME_REPO}
-RUN /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh
-RUN ln -s cmake-* cmake-dir
-RUN python3 -m pip install wheel
-ENV PATH=${WORKDIR_PATH}/cmake-dir/bin:$PATH
-RUN pip3 install onnx
-RUN ln -s /usr/bin/python3 /usr/bin/python
-RUN apt install locales && \
-    locale-gen en_US en_US.UTF-8 && \
-    dpkg-reconfigure locales
-RUN cd onnxruntime && ./build.sh --allow_running_as_root --config Release --update --build --parallel --use_openvino ${DEVICE} --build_nuget --build_shared_lib
-RUN cp /home/openvino/onnxruntime/build/Linux/Release/Microsoft.ML.OnnxRuntime.Managed* /home/openvino/onnxruntime/build/Linux/Release/nuget-artifacts
-
-# Deploy stage
-FROM base
-
-ENV DEBIAN_FRONTEND noninteractive
-USER root
-
-RUN apt update; apt install -y unattended-upgrades fonts-freefont-ttf && \
-    unattended-upgrade
-ARG BUILD_UID=1001
-ARG BUILD_USER=onnxruntimedev
-RUN adduser --uid $BUILD_UID $BUILD_USER
-RUN usermod -a -G video,users ${BUILD_USER}
-ENV WORKDIR_PATH /home/${BUILD_USER}
-WORKDIR ${WORKDIR_PATH}
-COPY --from=builder /home/openvino/onnxruntime/build/Linux/Release/nuget-artifacts ${WORKDIR_PATH}/nuget-artifacts
-
-USER ${BUILD_USER}
-ENV PATH=${WORKDIR_PATH}/miniconda/bin:${WORKDIR_PATH}/cmake-dir/bin:$PATH
-ENV IE_PLUGINS_PATH=${INTEL_OPENVINO_DIR}/runtime/lib/intel64
-ENV LD_LIBRARY_PATH=/opt/intel/opencl:${INTEL_OPENVINO_DIR}/runtime/3rdparty/tbb/lib:${IE_PLUGINS_PATH}:${LD_LIBRARY_PATH}
diff --git a/dockerfiles/Dockerfile.openvino-rhel8 b/dockerfiles/Dockerfile.openvino-rhel8
deleted file mode 100644
index 5c504cfa553a1..0000000000000
--- a/dockerfiles/Dockerfile.openvino-rhel8
+++ /dev/null
@@ -1,87 +0,0 @@
-# Build stage
-FROM registry.access.redhat.com/ubi8/ubi:8.4
-
-WORKDIR /code
-
-ARG MY_ROOT=/code
-ARG DEVICE=CPU_FP32
-ARG ONNXRUNTIME_REPO=https://github.com/microsoft/onnxruntime
-ARG ONNXRUNTIME_BRANCH=main
-
-ENV INTEL_OPENVINO_DIR=/opt/intel/openvino_2022.3.0
-
-ENV InferenceEngine_DIR=${INTEL_OPENVINO_DIR}/runtime/cmake
-ENV IE_PLUGINS_PATH=${INTEL_OPENVINO_DIR}/runtime/lib/intel64/
-ENV ngraph_DIR=${INTEL_OPENVINO_DIR}/runtime/cmake
-ENV LD_LIBRARY_PATH=${INTEL_OPENVINO_DIR}/runtime/3rdparty/tbb/lib/:${IE_PLUGINS_PATH}:${LD_LIBRARY_PATH}
-ENV OpenCV_DIR=${INTEL_OPENVINO_DIR}/extras/opencv/cmake
-ENV LD_LIBRARY_PATH=${INTEL_OPENVINO_DIR}/extras/opencv/lib:${LD_LIBRARY_PATH}
-ENV LD_LIBRARY_PATH=/usr/local/lib:/usr/lib:/usr/local/lib64:/usr/lib64:/lib64:${LD_LIBRARY_PATH}
-ENV PATH=${MY_ROOT}/cmake-dir/bin:$PATH
-
-# Install packages
-RUN yum install -y yum-utils autoconf automake libtool unzip udev wget zlib-devel libffi-devel openssl-devel git make gcc && \
-    yum clean packages &&  yum clean all && rm -rf /var/cache/yum && \
-# Install python 3.8
-    cd $MY_ROOT && \
-    wget https://www.python.org/ftp/python/3.8.9/Python-3.8.9.tgz && tar xvf Python-3.8.9.tgz && rm -rf Python-3.8.9.tgz && \
-    cd Python-3.8*/ && ./configure && make && make install && \
-    cd ../ &&  mkdir -p /usr/bin/Python38 && ln -s Python-3.8.9/ /usr/bin/Python38 && ln -s /usr/bin/pip3 /usr/bin/pip && \
-# libusb1.0.22
-    cd /opt/ && wget https://github.com/libusb/libusb/archive/v1.0.22.zip && \
-    unzip v1.0.22.zip && rm -rf v1.0.22.zip && cd  /opt/libusb-1.0.22 && \
-# bootstrap steps
-    ./bootstrap.sh && \
-    ./configure --disable-udev --enable-shared && \
-    make -j4 && \
-# configure libusb1.0.22
-    cd /opt/libusb-1.0.22/libusb && \
-    /bin/mkdir -p '/usr/local/lib' && \
-    /bin/bash ../libtool   --mode=install /usr/bin/install -c   libusb-1.0.la '/usr/local/lib' && \
-    /bin/mkdir -p '/usr/local/include/libusb-1.0' && \
-    /usr/bin/install -c -m 644 libusb.h '/usr/local/include/libusb-1.0' && \
-    /bin/mkdir -p '/usr/local/lib/pkgconfig' && \
-# Install openvino
-    cd /opt/ && mkdir intel/ && cd intel && \
-    wget https://storage.openvinotoolkit.org/repositories/openvino/packages/2022.3/linux/l_openvino_toolkit_rhel8_2022.3.0.9052.9752fafe8eb_x86_64.tgz  && \
-    tar xvf l_openvino_toolkit_rhel8_2022.3.0.9052.9752fafe8eb_x86_64.tgz && \
-    rm -rf l_openvino_toolkit_rhel8_2022.3.0.9052.9752fafe8eb_x86_64.tgz && \
-    mv l_openvino_toolkit_rhel8_2022.3.0.9052.9752fafe8eb_x86_64 openvino_2022.3.0 && \
-    cd ${INTEL_OPENVINO_DIR}/install_dependencies/ && ./install_openvino_dependencies.sh -y && ./install_NEO_OCL_driver.sh -y && \
-    printf "\nexport LD_LIBRARY_PATH=\${LD_LIBRARY_PATH}:/usr/local/lib\n" >> /opt/intel/openvino_2022.3.0/setupvars.sh && \
-    cd /opt/libusb-1.0.22 && \
-    /usr/bin/install -c -m 644 libusb-1.0.pc '/usr/local/lib/pkgconfig' && \
-    # MYRIAD plugins are not available for openvino 2022.3.0 release
-    #cp /opt/intel/openvino_2022.3.0/install_dependencies/97-myriad-usbboot.rules /etc/udev/rules.d/ && \
-    ldconfig && \
-#Install protobuf
-    cd $MY_ROOT && \
-    git clone https://github.com/protocolbuffers/protobuf.git && \
-    cd protobuf && \
-    git checkout v3.16.0 && \
-    git submodule update --init --recursive && \
-    mkdir build_source && cd build_source && \
-    cmake ../cmake  -DCMAKE_INSTALL_LIBDIR=lib64 -Dprotobuf_BUILD_SHARED_LIBS=OFF -DCMAKE_INSTALL_PREFIX=/usr -DCMAKE_INSTALL_SYSCONFDIR=/etc -DCMAKE_POSITION_INDEPENDENT_CODE=ON -Dprotobuf_BUILD_TESTS=OFF -DCMAKE_BUILD_TYPE=Release && \
-    make -j$(nproc) && \
-    make install && \
-# Build onnxruntime
-    cd $MY_ROOT && \
-    pip3 install numpy wheel setuptools cython onnx && \
-    git clone --recursive -b ${ONNXRUNTIME_BRANCH} ${ONNXRUNTIME_REPO} && \
-    bash onnxruntime/dockerfiles/scripts/install_common_deps.sh && \
-    ln -s cmake-* cmake-dir && \
-    source /opt/intel/openvino_2022.3.0/setupvars.sh && \
-    cd /code/onnxruntime && ./build.sh --allow_running_as_root --config Release --update --build --parallel --use_openvino ${DEVICE} --build_shared_lib --build_wheel && \
-    pip3 install /code/onnxruntime/build/Linux/Release/dist/*-linux_x86_64.whl && \
-# Clean up
-    cd ${MY_ROOT} && rm -rf onnxruntime && rm -rf Python-3.8.9 && rm -rf protobuf
-
-# Deploy stage
-ARG BUILD_UID=1001
-ARG BUILD_USER=onnxruntimedev
-RUN adduser --uid $BUILD_UID $BUILD_USER
-RUN usermod -a -G video,users,render ${BUILD_USER}
-ENV WORKDIR_PATH /home/${BUILD_USER}
-
-WORKDIR ${WORKDIR_PATH}
-USER ${BUILD_USER}
diff --git a/dockerfiles/Dockerfile.rocm b/dockerfiles/Dockerfile.rocm
index 35a676383337b..c242933f677f0 100644
--- a/dockerfiles/Dockerfile.rocm
+++ b/dockerfiles/Dockerfile.rocm
@@ -5,14 +5,14 @@
 # Dockerfile to run ONNXRuntime with ROCm integration
 #--------------------------------------------------------------------------
 
-FROM rocm/pytorch:rocm5.4_ubuntu20.04_py3.7_pytorch_1.12.1
+FROM rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1
 
 ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
 ARG ONNXRUNTIME_BRANCH=main
 
 WORKDIR /code
 
-ENV PATH /opt/miniconda/bin:/code/cmake-3.27.3-linux-x86_64/bin:${PATH}
+ENV PATH /code/cmake-3.27.3-linux-x86_64/bin:${PATH}
 
 # Prepare onnxruntime repository & build onnxruntime
 RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\
diff --git a/dockerfiles/README.md b/dockerfiles/README.md
index f226ebfe8b193..a2e99d66d4654 100644
--- a/dockerfiles/README.md
+++ b/dockerfiles/README.md
@@ -277,7 +277,7 @@ Nothing else from ONNX Runtime source tree will be copied/installed to the image
 Note: When running the container you built in Docker, please either use 'nvidia-docker' command instead of 'docker', or use Docker command-line options to make sure NVIDIA runtime will be used and appropiate files mounted from host. Otherwise, CUDA libraries won't be found. You can also [set NVIDIA runtime as default in Docker](https://github.com/dusty-nv/jetson-containers#docker-default-runtime).
 
 ## MIGraphX
-**Ubuntu 20.04, ROCm5.4, AMDMIGraphX v1.2**
+**Ubuntu 20.04, ROCm6.0, MIGraphX**
 
 1. Build the docker image from the Dockerfile in this repository.
   ```
@@ -291,7 +291,7 @@ Note: When running the container you built in Docker, please either use 'nvidia-
   ```
 
    ## ROCm
-**Ubuntu 20.04, ROCm5.4**
+**Ubuntu 20.04, ROCm6.0**
 
 1. Build the docker image from the Dockerfile in this repository.
   ```
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index e7b537d6894c8..32a4ca16b7824 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -461,7 +461,7 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dt><tt>repetition_penalty</tt> (optional) : T</dt>
 <dd>The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)</dd>
 <dt><tt>vocab_mask</tt> (optional) : M</dt>
-<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)</dd>
+<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)</dd>
 <dt><tt>prefix_vocab_mask</tt> (optional) : M</dt>
 <dd>Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)</dd>
 <dt><tt>attention_mask</tt> (optional) : I</dt>
@@ -2252,7 +2252,7 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dt><tt>repetition_penalty</tt> (optional) : T</dt>
 <dd>The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)</dd>
 <dt><tt>vocab_mask</tt> (optional) : I</dt>
-<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)</dd>
+<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)</dd>
 <dt><tt>prefix_vocab_mask</tt> (optional) : I</dt>
 <dd>Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)</dd>
 <dt><tt>attention_mask</tt> (optional) : I</dt>
@@ -2795,7 +2795,7 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>Constrain input A data type to 8-bit integer tensor.</dd>
 <dt><tt>T2</tt> : tensor(int8), tensor(uint8)</dt>
 <dd>Constrain input B data type to 8-bit integer tensor.</dd>
-<dt><tt>T3</tt> : tensor(float)</dt>
+<dt><tt>T3</tt> : tensor(float), tensor(float16)</dt>
 <dd>Constrain input a_scale, b_scale and output Y data type as float tensor.</dd>
 </dl>
 
@@ -2808,22 +2808,23 @@ This version of the operator has been available since version 1 of the 'com.micr
        And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,..
     3. Input B's scale and zero point are specified by input scales and zero_points.
   
-  Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
-  - n_blocks_per_col = (K + block_size - 1) / block_size
-  - blob_size = block_size / 8 * bits
+    Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
+    - n_blocks_per_col = (K + block_size - 1) / block_size
+    - blob_size = CeilDiv(block_size * bits, 8), where 8 is the number of bits in a uint8_t
+    For all bit widths from 2 to 8, a row of data is stored compactly and represented by uint8_t.
+      - For 2, 4 and 8 bits, 4x2bit, 2x4bit and 1x8bit values are stored in one uint8_t.
+          4bit example:
+          |.|.|.|.| .|.|.|.| = uint8_t (2x4bit)
+      - For 3, 5, 6 and 7 bits, 32x3bit, 32x5bit, 16x6bit and 32x7bit values are stored in 12, 20, 12 and 28 uint8_t respectively. No bits are wasted.
+          3bit example:
+          |.|.|. |.|.|. |.|.|. = 9 bits, which span 2 uint8_t; the highest bit of the second uint8_t is used.
+    The last uint8_t may have some bits unused.
   
-    For a block blob. It is stored in format:
-    struct Blob {
-      uint8 one_bits[(bits & 0x1) * 1 * block_size / 8];  // highest 1 bit for 3, 5, 7 bits quantization
-      uint8 two_bits[(bits & 0x2) * 2 * block_size / 8];  // high 2 bits for 2, 6, 7 bits quantization
-      uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization
-    }
   
   Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col]
-  Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is:
-    - [(N * n_blocks_per_col + 1) / 2] if bits <=4
-    - [N * n_blocks_per_col] if bits > 4
-  
+  Input zero_points is stored as uint8_t or in the same type as A. It uses the same packing method as input B, and its packed shape is:
+    - [CeilDiv((N * n_blocks_per_col + 1) * bits, 8)]
+    If zero_points has the same type as A, it is not packed and has the same shape as scales.
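
To make the blob layout above concrete, here is a small NumPy sketch of the 4-bit case using the formulas above; the nibble order within each byte is an illustrative assumption, not taken from this spec:

```python
import numpy as np

def pack_4bit(q, block_size):
    """Pack quantized 4-bit values q of shape (N, K) into [N][n_blocks_per_col][blob_size]."""
    bits = 4
    N, K = q.shape
    n_blocks_per_col = (K + block_size - 1) // block_size
    blob_size = (block_size * bits + 7) // 8  # CeilDiv(block_size * bits, 8)
    packed = np.zeros((N, n_blocks_per_col, blob_size), dtype=np.uint8)
    for n in range(N):
        for b in range(n_blocks_per_col):
            block = q[n, b * block_size:(b + 1) * block_size]
            for i, v in enumerate(block):
                byte, shift = divmod(i * bits, 8)            # two 4-bit values share one uint8_t
                packed[n, b, byte] |= (int(v) & 0xF) << shift  # low-nibble-first (assumed order)
    return packed
```
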
 
 #### Version
 
@@ -2844,17 +2845,19 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.</dd>
 </dl>
 
-#### Inputs (3 - 4)
+#### Inputs (3 - 5)
 
 <dl>
 <dt><tt>A</tt> : T1</dt>
 <dd>The input tensor, not quantized</dd>
 <dt><tt>B</tt> : T2</dt>
-<dd>1-dimensional data blob</dd>
+<dd>1 or 2 dimensional data blob</dd>
 <dt><tt>scales</tt> : T1</dt>
 <dd>quantization scale</dd>
-<dt><tt>zero_points</tt> (optional) : T2</dt>
+<dt><tt>zero_points</tt> (optional) : T3</dt>
 <dd>quantization zero points</dd>
+<dt><tt>g_idx</tt> (optional) : T4</dt>
+<dd>group_idx</dd>
 </dl>
 
 #### Outputs
@@ -2869,8 +2872,12 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dl>
 <dt><tt>T1</tt> : tensor(float), tensor(float16)</dt>
 <dd>Constrain input and output types to float/half_float tensors.</dd>
-<dt><tt>T2</tt> : tensor(uint8)</dt>
-<dd>Constrain quantized weight types to uint8.</dd>
+<dt><tt>T2</tt> : tensor(uint8), tensor(int32)</dt>
+<dd>Constrain quantized weight types to uint8/int32.</dd>
+<dt><tt>T3</tt> : tensor(uint8), tensor(int32), tensor(float16), tensor(float)</dt>
+<dd>Constrain quantized zero point types to uint8/int32/float16/float.</dd>
+<dt><tt>T4</tt> : tensor(int32)</dt>
+<dd>the index tensor.</dd>
 </dl>
 
 
@@ -2924,8 +2931,8 @@ This version of the operator has been available since version 1 of the 'com.micr
 ### <a name="com.microsoft.MoE"></a><a name="com.microsoft.moe">**com.microsoft.MoE**</a>
 
   Mixture of experts. Examples: Switch transformer(https://arxiv.org/pdf/2101.03961.pdf) use top 1,
-        GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, and Vision MOE(https://arxiv.org/pdf/2106.05974.pdf)
-        usually uses top 32 experts.
+        GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, Vision MOE(https://arxiv.org/pdf/2106.05974.pdf)
+        usually uses top 32 experts, and Mixtral(https://huggingface.co/blog/mixtral) uses top 2 experts.
         
 
 #### Version
@@ -2939,9 +2946,11 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>Activation function to use. Choose from relu, gelu, silu and identity. Default is relu</dd>
 <dt><tt>k</tt> : int</dt>
 <dd>Number of top experts to select from expert pool</dd>
+<dt><tt>normalize_routing_weights</tt> : int</dt>
+<dd>Whether to normalize routing weights</dd>
 </dl>
 
-#### Inputs (4 - 6)
+#### Inputs (5 - 8)
 
 <dl>
 <dt><tt>input</tt> : T</dt>
@@ -2950,12 +2959,16 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>2D input tensor with shape (num_rows, num_experts)</dd>
 <dt><tt>fc1_experts_weights</tt> : T</dt>
 <dd>3D input tensor with shape (num_experts, hidden_size, inter_size)</dd>
-<dt><tt>fc2_experts_weights</tt> : T</dt>
-<dd>3D input tensor with shape (num_experts, inter_size, hidden_size)</dd>
 <dt><tt>fc1_experts_bias</tt> (optional) : T</dt>
 <dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
+<dt><tt>fc2_experts_weights</tt> : T</dt>
+<dd>3D input tensor with shape (num_experts, inter_size, hidden_size)</dd>
 <dt><tt>fc2_experts_bias</tt> (optional) : T</dt>
 <dd>2D optional input tensor with shape (num_experts, hidden_size)</dd>
+<dt><tt>fc3_experts_weights</tt> (optional) : T</dt>
+<dd>3D optional input tensor with shape (num_experts, hidden_size, inter_size)</dd>
+<dt><tt>fc3_experts_bias</tt> (optional) : T</dt>
+<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
 </dl>
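
For intuition on the `k` and `normalize_routing_weights` attributes above, a toy NumPy sketch of top-k routing (illustrative only; the actual kernel fuses this with the expert FFNs):

```python
import numpy as np

def route_top_k(router_probs, k, normalize_routing_weights=True):
    """router_probs: (num_rows, num_experts) scores; returns chosen experts and weights."""
    top_idx = np.argsort(-router_probs, axis=-1)[:, :k]        # k highest-scoring experts per row
    top_w = np.take_along_axis(router_probs, top_idx, axis=-1)
    if normalize_routing_weights:
        top_w = top_w / top_w.sum(axis=-1, keepdims=True)      # renormalize so weights sum to 1
    return top_idx, top_w

idx, w = route_top_k(np.random.rand(4, 8), k=2)
```
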
 
 #### Outputs
@@ -5154,7 +5167,7 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dt><tt>repetition_penalty</tt> (optional) : T</dt>
 <dd>The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)</dd>
 <dt><tt>vocab_mask</tt> (optional) : I</dt>
-<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)</dd>
+<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)</dd>
 <dt><tt>prefix_vocab_mask</tt> (optional) : I</dt>
 <dd>Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)</dd>
 <dt><tt>attention_mask</tt> (optional) : I</dt>
@@ -5743,12 +5756,14 @@ This version of the operator has been available since version 1 of the 'com.micr
 #### Attributes
 
 <dl>
+<dt><tt>beginning_timestamp_token_id</tt> : int</dt>
+<dd>The id of the first timestamp</dd>
 <dt><tt>decoder</tt> : graph (required)</dt>
 <dd>Decoder subgraph to execute in a loop.</dd>
 <dt><tt>decoder_output_cross_qk</tt> : int</dt>
 <dd>If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.</dd>
 <dt><tt>decoder_start_token_id</tt> : int</dt>
-<dd>The id of the token that indicates decoding starts.</dd>
+<dd>The id of the token that indicates decoding starts (i.e. the start of transcription token id)</dd>
 <dt><tt>early_stopping</tt> : int</dt>
 <dd>early stop or not</dd>
 <dt><tt>encoder</tt> : graph</dt>
@@ -5761,10 +5776,18 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>Must be 2 for whisper</dd>
 <dt><tt>no_repeat_ngram_size</tt> : int</dt>
 <dd>no repeat ngrams size</dd>
-<dt><tt>no_speech_token</tt> : int</dt>
+<dt><tt>no_speech_token_id</tt> : int</dt>
 <dd>The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.</dd>
+<dt><tt>no_timestamps_token_id</tt> : int</dt>
+<dd>The id of the token that indicates no timestamps</dd>
 <dt><tt>pad_token_id</tt> : int (required)</dt>
 <dd>The id of the padding token</dd>
+<dt><tt>start_of_lm_token_id</tt> : int</dt>
+<dd>The id of the token that indicates LM starts</dd>
+<dt><tt>transcribe_token_id</tt> : int</dt>
+<dd>The id of the transcribe task</dd>
+<dt><tt>translate_token_id</tt> : int</dt>
+<dd>The id of the translate task</dd>
 <dt><tt>vocab_size</tt> : int</dt>
 <dd>Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape</dd>
 </dl>
@@ -5783,11 +5806,11 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dt><tt>num_return_sequences</tt> : I</dt>
 <dd>The number of returned sequences in the batch. Shape is (1)</dd>
 <dt><tt>length_penalty</tt> (optional) : T</dt>
+<dd>Exponential penalty to the length. Default value 1.0 means no penalty. Value > 1.0 encourages longer sequences, while values < 1.0 produce shorter sequences. Shape is (1,)</dd>
+<dd>Exponential penalty to the length. Default value 1.0 means no penalty. Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences. Shape is (1,)</dd>
 <dt><tt>repetition_penalty</tt> (optional) : T</dt>
 <dd>The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)</dd>
 <dt><tt>vocab_mask</tt> (optional) : M</dt>
-<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)</dd>
+<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)</dd>
 <dt><tt>prefix_vocab_mask</tt> (optional) : M</dt>
 <dd>Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)</dd>
 <dt><tt>attention_mask</tt> (optional) : I</dt>
@@ -5797,7 +5820,7 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dt><tt>logits_processor</tt> (optional) : I</dt>
 <dd>Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)</dd>
 <dt><tt>cross_qk_layer_head</tt> (optional) : I</dt>
-<dd>Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect allits shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]</dd>
+<dd>Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. The default is to collect all. Its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]</dd>
 <dt><tt>extra_decoding_ids</tt> (optional) : I</dt>
 <dd>Part of the decoder_input_ids that we need cross qk for it. it is of shape  (batch_size, extra_decoding_ids_len).In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) are treated as stop of the extra_decoding_ids for corresponding batch.</dd>
 <dt><tt>temperature</tt> (optional) : T</dt>
@@ -5812,11 +5835,11 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dt><tt>sequences_scores</tt> (optional) : T</dt>
 <dd>Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)</dd>
 <dt><tt>scores</tt> (optional) : T</dt>
-<dd>Processed beam scores for each vocabulary token at each generation step.Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam.Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)</dd>
+<dd>Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam. Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)</dd>
 <dt><tt>cross_qk</tt> (optional) : V</dt>
-<dd>Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers,B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F].If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]</dd>
+<dd>Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers, B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]. If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]</dd>
 <dt><tt>non_speech_probs</tt> (optional) : T</dt>
-<dd>For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token.Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph.The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]</dd>
+<dd>For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token_id. The shape of non_speech_probs is [B]</dd>
 </dl>
 
 #### Type Constraints
diff --git a/docs/Memory_Optimizer.md b/docs/Memory_Optimizer.md
index 97f7e7ff2c14b..eaa48c9da0609 100644
--- a/docs/Memory_Optimizer.md
+++ b/docs/Memory_Optimizer.md
@@ -51,6 +51,7 @@ There are two modes to enable the memory optimizations:
 	- Plan 8            :  OFF  :  Cast+:2:-1                                           1     2,048              2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
 	```
 3. As shown above, `Config` is a string representative for a re-computable subgraph. All are enabled for recompute in this case.
+4. With `export ORTMODULE_MEMORY_OPT_LEVEL=2`, all plans, including compromised recomputable subgraphs, will also be enabled.
 
 
 ### Mode 2 -  Advanced Usage (User Selected Subgraph Recompute)
diff --git a/docs/ORTModule_Convergence_Notes.md b/docs/ORTModule_Convergence_Notes.md
index 791b6c32c9b48..2374e7b7c538a 100644
--- a/docs/ORTModule_Convergence_Notes.md
+++ b/docs/ORTModule_Convergence_Notes.md
@@ -89,7 +89,7 @@ The limitation of `GlobalSubscriberManager` is, only 'nn.Module's forward output
 dump the intermediate tensors in a `nn.Module`'s forward function, refer to the following example:
 
 ```diff
-+   from onnxruntime.training.utils import inspect_activation
++   from onnxruntime.training.utils.hooks import inspect_activation
 class BloomForCausalLM(BloomPreTrainedModel):
   def __init__(self, config: BloomConfig):
     ...
diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md
index 91057d3dfb120..54137937ad56d 100644
--- a/docs/ORTModule_Training_Guidelines.md
+++ b/docs/ORTModule_Training_Guidelines.md
@@ -246,7 +246,7 @@ to standard outputs.
 #### ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER
 
 - **Feature Area**: *ORTMODULE/Optimizations*
-- **Description**: By default, this is disabled. This env var can be used for enabling or disabling the embedding input
+- **Description**: By default, this is enabled. This env var can be used for enabling or disabling the embedding input
 data sparsity based performance optimizations.
 
 	```bash
@@ -287,13 +287,16 @@ A classical usage of disabling the deep copy: when the deep copy before module e
 #### ORTMODULE_MEMORY_OPT_LEVEL
 
 - **Feature Area**: *ORTMODULE/Optimizations*
-- **Description**: By default, the level is 0. This env var can be used for enabling recomputation for reducing memory peak requirement. Setting the level to be 0 means all detected subgraphs with each transformer-based model layer generating stashed activations will be recomputed. This is conceptually equivalent to PyTorch's gradient checkpoint. When level is not 0, check Check [Memory Optimizer for ONNX Runtime Training](Memory_Optimizer.md) for more details.
+- **Description**: By default, the level is 0. This env var can be used to enable recomputation for reducing the peak memory requirement.
+   - Setting the level to 1 means all detected recomputable subgraphs (NOT including compromised recomputable graphs) with each transformer-based model layer generating stashed activations will be recomputed. This is conceptually equivalent to PyTorch's gradient checkpoint.
+   - Setting the level to 2 means all detected recomputable subgraphs (including compromised recomputable graphs) with each transformer-based model layer generating stashed activations will be recomputed. This is conceptually equivalent to PyTorch's gradient checkpoint.
+   - When the level is 0, check [Memory Optimizer for ONNX Runtime Training](Memory_Optimizer.md) for more details.
 
     ```bash
     export ORTMODULE_MEMORY_OPT_LEVEL=0
     ```
 
-### ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT
+#### ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT
 
 - **Feature Area**: *ORTMODULE/Optimizations*
 - **Description**: By default, the memory-efficient gradient management is turned off. The gradient after it is computed in ONNX Runtime, will trigger the corresponding parameter's backward function through `PythonOpGrad` operator. This would help release the gradient buffer managed in ONNX Runtime, which originally is released once all backward computation finishes.
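
These ORTMODULE_* flags are read from the environment when ORTModule initializes, so a common pattern is to set them before importing and wrapping the model. A minimal sketch (the model is a placeholder and the values shown are examples, not recommendations):

```python
import os

# Set before constructing ORTModule so the flags take effect.
os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = "1"                   # layer-wise recompute
os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] = "0"  # opt out of the default-on optimization

import torch
from onnxruntime.training.ortmodule import ORTModule

model = ORTModule(torch.nn.Linear(1024, 1024))  # placeholder model
```
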
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 2ea557b7d61fe..bca8e17b3dfd4 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -127,6 +127,7 @@ Do not modify directly.*
 |GatherND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **indices** = tensor(int64)|
 |||12|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **indices** = tensor(int64)|
 |||11|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **indices** = tensor(int64)|
+|Gelu|*in* X:**T**<br> *out* Y:**T**|20+|**T** = tensor(float)|
 |Gemm|*in* A:**T**<br> *in* B:**T**<br> *in* C:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
 |||[11, 12]|**T** = tensor(double), tensor(float)|
 |||[9, 10]|**T** = tensor(double), tensor(float)|
@@ -159,9 +160,9 @@ Do not modify directly.*
 |||[1, 10]|**B** = tensor(bool)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |ImageScaler|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(float)|
 |InstanceNormalization|*in* input:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *out* output:**T**|6+|**T** = tensor(float)|
-|IsInf|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
+|IsInf|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
 |||[10, 19]|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(bool)|
-|IsNaN|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
+|IsNaN|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
 |||[13, 19]|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
 |||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
 |LRN|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(float)|
@@ -469,7 +470,7 @@ Do not modify directly.*
 |MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
 |MatMulInteger16|*in* A:**T1**<br> *in* B:**T2**<br> *out* Y:**T3**|1+|**T1** = tensor(int16)<br/> **T2** = tensor(int16)<br/> **T3** = tensor(int32)|
 |MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)|
-|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
+|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(float), tensor(uint8)<br/> **T4** = tensor(int32)|
 |MaxpoolWithMask|*in* X:**T**<br> *in* M:**tensor(int32)**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float)|
 |MurmurHash3|*in* X:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(int32), tensor(uint32)|
@@ -606,6 +607,7 @@ Do not modify directly.*
 |GatherND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)<br/> **indices** = tensor(int64)|
 |||12|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)<br/> **indices** = tensor(int64)|
 |||11|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)<br/> **indices** = tensor(int64)|
+|Gelu|*in* X:**T**<br> *out* Y:**T**|20+|**T** = tensor(double), tensor(float), tensor(float16)|
 |Gemm|*in* A:**T**<br> *in* B:**T**<br> *in* C:**T**<br> *out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
 |||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
 |||[9, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
@@ -617,6 +619,7 @@ Do not modify directly.*
 |||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
 |GreaterOrEqual|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T1** = tensor(bool)|
 |||[12, 15]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T1** = tensor(bool)|
+|GridSample|*in* X:**T1**<br> *in* grid:**T2**<br> *out* Y:**T1**|16+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
 |HardSigmoid|*in* X:**T**<br> *out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)|
 |Identity|*in* input:**T**<br> *out* output:**T**<br><br>or<br><br>*in* input:**V**<br> *out* output:**V**|19+|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |||[14, 18]|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@@ -628,6 +631,11 @@ Do not modify directly.*
 |||[1, 10]|**B** = tensor(bool)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |ImageScaler|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |InstanceNormalization|*in* input:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *out* output:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)|
+|IsInf|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
+|||[10, 19]|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(bool)|
+|IsNaN|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
+|||[13, 19]|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
+|||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
 |LRN|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
 |||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
 |LSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(int32)|
@@ -731,7 +739,8 @@ Do not modify directly.*
 |||13|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **shape** = tensor(int64)|
 |||[5, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **shape** = tensor(int64)|
 |||[1, 4]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|Resize|*in* X:**T**<br> *in* scales:**tensor(float)**<br> *out* Y:**T**<br><br>or<br><br>*in* X:**T1**<br> *in* roi:**T2**<br> *in* scales:**tensor(float)**<br> *in* sizes:**tensor(int64)**<br> *out* Y:**T1**|13+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
+|Resize|*in* X:**T**<br> *in* scales:**tensor(float)**<br> *out* Y:**T**<br><br>or<br><br>*in* X:**T1**<br> *in* roi:**T2**<br> *in* scales:**tensor(float)**<br> *in* sizes:**tensor(int64)**<br> *out* Y:**T1**|18+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
+|||[13, 17]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
 |||[11, 12]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
 |||10|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
 |ReverseSequence|*in* input:**T**<br> *in* sequence_lens:**tensor(int64)**<br> *out* Y:**T**|10+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@@ -765,7 +774,7 @@ Do not modify directly.*
 |Sigmoid|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
 |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
 |Sign|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|SimplifiedLayerNormalization|*in* X:**T**<br> *in* scale:**V**<br> *out* Y:**V**<br> *out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float), tensor(float16)<br/> **U** = tensor(double), tensor(float)<br/> **V** = tensor(double), tensor(float), tensor(float16)|
+|SimplifiedLayerNormalization|*in* X:**T**<br> *in* scale:**V**<br> *out* Y:**V**<br> *out* inv_std_var:**U**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **U** = tensor(double), tensor(float)<br/> **V** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
 |Sin|*in* input:**T**<br> *out* output:**T**|7+|**T** = tensor(double), tensor(float), tensor(float16)|
 |Size|*in* data:**T**<br> *out* size:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
 |||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
@@ -784,7 +793,7 @@ Do not modify directly.*
 |||[13, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |||[2, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|Sqrt|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Sqrt|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
 |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
 |Squeeze|*in* data:**T**<br> *in* axes:**tensor(int64)**<br> *out* squeezed:**T**<br><br>or<br><br>*in* data:**T**<br> *out* squeezed:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@@ -851,8 +860,8 @@ Do not modify directly.*
 |Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
 |MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
-|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
-|MoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T**<br> *in* fc2_experts_weights:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
+|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
+|MoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
 |MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
 |NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
 |NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -1259,13 +1268,16 @@ Do not modify directly.*
 |BiasSplitGelu|*in* X:**T**<br> *in* bias:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int32), tensor(int8), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
+|DynamicQuantizeMatMul|*in* A:**T1**<br> *in* B:**T2**<br> *in* b_scale:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
 |EmbedLayerNormalization|*in* input_ids:**T1**<br> *in* segment_ids:**T1**<br> *in* word_embedding:**T**<br> *in* position_embedding:**T**<br> *in* segment_embedding:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* mask:**T1**<br> *in* position_ids:**T1**<br> *out* output:**T**<br> *out* mask_index:**T1**<br> *out* embedding_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
 |FusedMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |FusedMatMulActivation|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**M** = tensor(float), tensor(float16)<br/> **T** = tensor(float), tensor(float16)|
+|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float), tensor(float16)|
 |MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
 |NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
+|QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float), tensor(float16)<br/> **T4** = tensor(int32)|
 |QLinearAdd|*in* A:**T**<br> *in* A_scale:**tensor(float)**<br> *in* A_zero_point:**T**<br> *in* B:**T**<br> *in* B_scale:**tensor(float)**<br> *in* B_zero_point:**T**<br> *in* C_scale:**tensor(float)**<br> *in* C_zero_point:**T**<br> *out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
 |QLinearAveragePool|*in* X:**T**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
 |QLinearConcat|*in* Y_scale:**TF**<br> *in* Y_zero_point:**T8**<br> *in* inputs:**TV**<br> *out* Y:**T8**|1+|**T8** = tensor(int8), tensor(uint8)<br/> **TF** = tensor(float)<br/> **TV** = tensor(float), tensor(int8), tensor(uint8)|
diff --git a/docs/python/examples/plot_train_convert_predict.py b/docs/python/examples/plot_train_convert_predict.py
index dcbc84b20767a..44b6bb74c29df 100644
--- a/docs/python/examples/plot_train_convert_predict.py
+++ b/docs/python/examples/plot_train_convert_predict.py
@@ -134,7 +134,7 @@ def loop(X_test, fct, n=None):
     nrow = X_test.shape[0]
     if n is None:
         n = nrow
-    for i in range(0, n):
+    for i in range(n):
         im = i % nrow
         fct(X_test[im : im + 1])
 
diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h
index fbeee8a2aedc5..3a3b5cb6888f2 100644
--- a/include/onnxruntime/core/framework/data_types_internal.h
+++ b/include/onnxruntime/core/framework/data_types_internal.h
@@ -305,7 +305,7 @@ class CallableDispatchableHelper {
     return 0;
   }
 
-  void CheckCalledOnce() {
+  void CheckCalledOnce() const {
     ORT_ENFORCE(called_ == 1, "Unsupported data type: ", dt_type_);
   }
 };
diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h
index 31c988f500779..40ca96a19aef1 100644
--- a/include/onnxruntime/core/framework/execution_provider.h
+++ b/include/onnxruntime/core/framework/execution_provider.h
@@ -33,6 +33,8 @@ class Node;
 #include "core/framework/stream_handles.h"
 #include "core/framework/tuning_context.h"
 
+struct OrtRunOptions;
+
 namespace onnxruntime {
 
 /**
@@ -51,6 +53,8 @@ struct NodeComputeInfo {
   DestroyFunctionStateFunc release_state_func;
 };
 
+using RunOptions = OrtRunOptions;
+
 enum class DataLayout {
   NCHW,
   NHWC,
@@ -184,7 +188,7 @@ class IExecutionProvider {
      Run may not be finished on device This function should be regarded as the
      point after which a new Run would start to submit commands from CPU
   */
-  virtual common::Status OnRunStart() { return Status::OK(); }
+  virtual common::Status OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { return Status::OK(); }
 
   /**
      Called when InferenceSession::Run ended
@@ -192,25 +196,27 @@ class IExecutionProvider {
      may not be finished on device This function should be regarded as the point
      that all commands of current Run has been submmited by CPU
   */
-  virtual common::Status OnRunEnd(bool /*sync_stream*/) { return Status::OK(); }
+  virtual common::Status OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) {
+    return Status::OK();
+  }
 
   /**
      Indicate whether the graph capturing mode (e.g., cuda graph) is enabled for
-     the provider. Currently only CUDA execution provider supports it.
+     the provider.
    */
   virtual bool IsGraphCaptureEnabled() const { return false; }
 
   /**
-     Indicate whether the graph has been captured and instantiated. Currently
-     only CUDA execution provider supports it.
+     Indicate whether the graph has been captured and instantiated.
    */
-  virtual bool IsGraphCaptured() const { return false; }
+  virtual bool IsGraphCaptured(int /*graph_annotation_id*/) const { return false; }
 
   /**
-     Run the instantiated graph. Currently only CUDA execution provider supports
-     it.
+     Run the instantiated graph.
    */
-  virtual common::Status ReplayGraph() { return Status::OK(); }
+  virtual common::Status ReplayGraph(int /*graph_annotation_id*/) {
+    return Status::OK();
+  }
 
   /**
      Called when session creation is complete
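These interface changes ripple into every execution provider implementation. Below is a minimal sketch, not taken from this PR, of how a hypothetical in-tree provider would adopt the new signatures; the class name is invented and the base-class constructor arguments are simplified.

```cpp
// Minimal sketch (hypothetical in-tree EP, simplified constructor) showing the new virtuals.
#include "core/framework/execution_provider.h"

namespace onnxruntime {

class MyExecutionProvider : public IExecutionProvider {
 public:
  MyExecutionProvider() : IExecutionProvider("MyExecutionProvider") {}

  // Per-run options are now forwarded, so the EP can react to run-level config entries.
  common::Status OnRunStart(const RunOptions& /*run_options*/) override { return Status::OK(); }
  common::Status OnRunEnd(bool /*sync_stream*/, const RunOptions& /*run_options*/) override {
    return Status::OK();
  }

  // Graph capture/replay is now keyed by an annotation id, so a single session
  // can hold more than one captured graph.
  bool IsGraphCaptured(int /*graph_annotation_id*/) const override { return false; }
  common::Status ReplayGraph(int /*graph_annotation_id*/) override { return Status::OK(); }
};

}  // namespace onnxruntime
```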
diff --git a/include/onnxruntime/core/framework/stream_handles.h b/include/onnxruntime/core/framework/stream_handles.h
index c235ee904762e..26d78133b52fc 100644
--- a/include/onnxruntime/core/framework/stream_handles.h
+++ b/include/onnxruntime/core/framework/stream_handles.h
@@ -100,6 +100,8 @@ class Stream {
     return nullptr;
   }
 
+  virtual WaitNotificationFn GetWaitNotificationFn() const { return nullptr; }
+
  private:
   StreamHandle handle_;
   const OrtDevice& device_;
diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h
index 9b26ba914c7dd..8e04050d089a0 100644
--- a/include/onnxruntime/core/graph/constants.h
+++ b/include/onnxruntime/core/graph/constants.h
@@ -31,6 +31,7 @@ constexpr size_t kMaxExecutionProviderNameLen = 30;
 
 constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider";
 constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider";
+constexpr const char* kCudaNHWCExecutionProvider = "CUDANHWCExecutionProvider";
 constexpr const char* kDnnlExecutionProvider = "DnnlExecutionProvider";
 constexpr const char* kOpenVINOExecutionProvider = "OpenVINOExecutionProvider";
 constexpr const char* kVitisAIExecutionProvider = "VitisAIExecutionProvider";
diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h
index b9b8a25286b7b..b16d52dbdab68 100644
--- a/include/onnxruntime/core/graph/graph.h
+++ b/include/onnxruntime/core/graph/graph.h
@@ -21,7 +21,7 @@
 #pragma warning(pop)
 #endif
 
-#include "flatbuffers/flatbuffers.h"
+#include "core/common/flatbuffers.h"
 
 #include "core/common/gsl.h"
 
diff --git a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h
index 03715eb5b78b2..55abb90b981f5 100644
--- a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h
+++ b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h
@@ -28,9 +28,12 @@ enum COREMLFlags {
   // dynamic shapes. However, the performance may be negatively impacted if inputs have dynamic shapes.
   COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES = 0x008,
 
+  // Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or later.
+  COREML_FLAG_CREATE_MLPROGRAM = 0x010,
+
   // Keep COREML_FLAG_LAST at the end of the enum definition
   // And assign the last COREMLFlag to it
-  COREML_FLAG_LAST = COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES,
+  COREML_FLAG_LAST = COREML_FLAG_CREATE_MLPROGRAM,
 };
 
 #ifdef __cplusplus
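As a usage illustration (not part of this diff), the new flag can be passed when the CoreML EP is registered through the C API. The model path is a placeholder, the snippet assumes a build with the CoreML EP enabled, and the include paths may differ depending on how ONNX Runtime is installed.

```cpp
// Sketch: request an MLProgram-based CoreML model (needs Core ML 5+, i.e. macOS 12+ / iOS 15+).
#include <onnxruntime_cxx_api.h>
#include <coreml_provider_factory.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "coreml-mlprogram");
  Ort::SessionOptions session_options;

  uint32_t coreml_flags = COREML_FLAG_CREATE_MLPROGRAM;
  Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, coreml_flags));

  Ort::Session session(env, "model.onnx", session_options);  // placeholder model path
  return 0;
}
```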
diff --git a/include/onnxruntime/core/providers/cuda/cuda_context.h b/include/onnxruntime/core/providers/cuda/cuda_context.h
index 108173474db46..7104e70c3a8a9 100644
--- a/include/onnxruntime/core/providers/cuda/cuda_context.h
+++ b/include/onnxruntime/core/providers/cuda/cuda_context.h
@@ -58,7 +58,7 @@ struct CudaContext : public CustomOpContext {
 
   template <typename T>
   T FetchResource(const OrtKernelContext& kernel_ctx, CudaResource resource_type) {
-    if (sizeof(T) > sizeof(void*)) {
+    if constexpr (sizeof(T) > sizeof(void*)) {
       ORT_CXX_API_THROW("void* is not large enough to hold resource type: " + std::to_string(resource_type), OrtErrorCode::ORT_INVALID_ARGUMENT);
     }
     const auto& ort_api = Ort::GetApi();
diff --git a/include/onnxruntime/core/providers/cuda/cuda_resource.h b/include/onnxruntime/core/providers/cuda/cuda_resource.h
index 1fef077860be3..00e7dec5727d1 100644
--- a/include/onnxruntime/core/providers/cuda/cuda_resource.h
+++ b/include/onnxruntime/core/providers/cuda/cuda_resource.h
@@ -19,4 +19,4 @@ enum CudaResource : int {
   enable_skip_layer_norm_strict_mode_t,
   prefer_nhwc_t,
   use_tf32_t,
-};
\ No newline at end of file
+};
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 5577c840c5379..41b034e9c1dcc 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -1837,14 +1837,28 @@ struct OrtApi {
 
   /** \brief Used for custom operators, get an input of a kernel
    *
-   * \see ::OrtCustomOp
+   * The function attempts to fetch the input of the kernel. If the input is optional
+   * and not present, the function returns success and out is set to nullptr.
+   *
+   * \param[in] context ::OrtKernelContext instance
+   * \param[in] index The input index. See KernelContext_GetInputCount for the valid range.
+   * \param[out] out Set to a pointer to the OrtValue if the input is present, or nullptr otherwise.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
    */
   ORT_API2_STATUS(KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index,
                   _Out_ const OrtValue** out);
 
   /** \brief Used for custom operators, get an output of a kernel
    *
-   * \see ::OrtCustomOp
+   * The function attempts to fetch the output of the kernel. If the output is optional
+   * and not present, the function returns success and out is set to nullptr.
+   *
+   * \param[in] context ::OrtKernelContext instance
+   * \param[in] index The output index. See KernelContext_GetOutputCount for the valid range.
+   * \param[out] out Set to a pointer to the OrtValue if the output is present, or nullptr otherwise.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
    */
   ORT_API2_STATUS(KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index,
                   _In_ const int64_t* dim_values, size_t dim_count, _Outptr_ OrtValue** out);
@@ -3619,6 +3633,10 @@ struct OrtApi {
    *     - "73"
    *     - "75"
    *   "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device).
+       "enable_htp_fp16_precision": Only used for float32 model.
+       Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision.
+         - "0": Default. With fp32 precision.
+         - "1": With fp16 precision.
    *
    * SNPE supported keys:
    *   "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16",
@@ -4586,6 +4604,26 @@ struct OrtApi {
                   _In_reads_(num_keys) const char* const* provider_options_keys,
                   _In_reads_(num_keys) const char* const* provider_options_values,
                   _In_ size_t num_keys);
+
+  /** \brief Get a scratch buffer from the allocator that corresponds to the specified OrtMemoryInfo object.
+   *         NOTE: callers are responsible for releasing this scratch buffer through the corresponding allocator
+   *  \param[in] context OrtKernelContext instance
+   *  \param[in] mem_info OrtMemoryInfo instance
+   *  \param[in] count_or_bytes The size of the scratch buffer, in bytes
+   *  \param[out] out A pointer to the scratch buffer
+   *  \snippet{doc} snippets.dox OrtStatus Return Value
+   */
+  ORT_API2_STATUS(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out);
+
+  /** \brief Get the allocator from KernelInfo for a specific memory type. Please use the C API ReleaseAllocator to release the returned object
+   *
+   * \param[in] info OrtKernelInfo instance
+   * \param[in] mem_type OrtMemType object
+   * \param[out] out A pointer to OrtAllocator
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   */
+  ORT_API2_STATUS(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out);
 };
 
 /*
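To make the two additions above concrete, here is a hedged sketch of a custom operator compute callback that tolerates a missing optional input and requests a scratch buffer. Status checks are trimmed, the input index and buffer size are illustrative, and the buffer is released through the matching allocator (obtained here with KernelContext_GetAllocator).

```cpp
// Sketch (status checks trimmed): optional input + scratch buffer inside a custom op kernel.
#include <onnxruntime_c_api.h>

static void ORT_API_CALL MyKernelCompute(void* /*op_kernel*/, OrtKernelContext* context) {
  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);

  // Optional input at index 1: on success, `value` stays nullptr when the input is absent.
  const OrtValue* value = nullptr;
  api->KernelContext_GetInput(context, 1, &value);
  const bool has_optional_input = (value != nullptr);

  // Scratch buffer bound to a CPU OrtMemoryInfo; it must be released through the matching allocator.
  OrtMemoryInfo* mem_info = nullptr;
  api->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &mem_info);
  void* scratch = nullptr;
  api->KernelContext_GetScratchBuffer(context, mem_info, 1024 /*bytes*/, &scratch);

  // ... kernel work using `scratch` (and the optional input when has_optional_input) ...
  (void)has_optional_input;

  OrtAllocator* allocator = nullptr;
  api->KernelContext_GetAllocator(context, mem_info, &allocator);
  allocator->Free(allocator, scratch);
  api->ReleaseAllocator(allocator);
  api->ReleaseMemoryInfo(mem_info);
}
```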
@@ -4683,6 +4721,13 @@ struct OrtCustomOp {
   // Get start range
   int(ORT_API_CALL* GetStartVersion)(_In_ const struct OrtCustomOp* op);
   int(ORT_API_CALL* GetEndVersion)(_In_ const struct OrtCustomOp* op);
+
+  // Get the inplace_map that defines which output can reuse which input.
+  // Callers provide the addresses of 2 raw int* pointers, and this function fills the 2 arrays they point to.
+  // On return, output (*output_index)[i] may reuse input (*input_index)[i].
+  // The return value is the size of these 2 arrays.
+  // Callers are responsible for deleting these 2 arrays after use.
+  size_t(ORT_API_CALL* GetMayInplace)(_Out_ int** input_index, _Out_ int** output_index);
 };
 
 /*
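A hypothetical sketch of wiring the new callback on a hand-written OrtCustomOp, declaring that output 0 may reuse input 0; per the comment above, the caller of the hook deletes the returned arrays.

```cpp
// Sketch: declare that output 0 may alias input 0 via the new GetMayInplace hook.
#include <onnxruntime_c_api.h>

static size_t ORT_API_CALL MyOpGetMayInplace(int** input_index, int** output_index) {
  *input_index = new int[1]{0};   // input 0 ...
  *output_index = new int[1]{0};  // ... may be reused by output 0
  return 1;                       // both arrays hold one entry
}

// During custom op setup (other mandatory callbacks omitted):
//   OrtCustomOp op{};
//   op.version = ORT_API_VERSION;
//   op.GetMayInplace = MyOpGetMayInplace;
```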
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index ae4c4bef90c64..60540514fbfa6 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -2055,7 +2055,11 @@ struct KernelContext {
   explicit KernelContext(OrtKernelContext* context);
   size_t GetInputCount() const;
   size_t GetOutputCount() const;
+  // If the input is optional and not present, the method returns an empty ConstValue
+  // which can be compared to nullptr.
   ConstValue GetInput(size_t index) const;
+  // If the output is optional and not present, the method returns an empty UnownedValue
+  // which can be compared to nullptr.
   UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const;
   UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const;
   void* GetGPUComputeStream() const;
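A short sketch of the behavior documented above, written against the C++ wrapper; the input index is illustrative and assumed to be optional in the hypothetical kernel.

```cpp
// Sketch: optional inputs come back as empty values that compare equal to nullptr.
#include <onnxruntime_cxx_api.h>

void Compute(OrtKernelContext* raw_context) {
  Ort::KernelContext context(raw_context);

  Ort::ConstValue maybe_bias = context.GetInput(2);  // index 2 assumed to be an optional input
  if (maybe_bias == nullptr) {
    // optional input not supplied; take the default path
  } else {
    const float* bias = maybe_bias.GetTensorData<float>();
    (void)bias;
  }
}
```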
diff --git a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h
index 1f5fcd50e185c..c80b8c0c164b6 100644
--- a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h
+++ b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h
@@ -30,3 +30,22 @@ static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memor
 // Per default it will be set to '0'
 // Taking CUDA EP as an example, it omit triggering cudaStreamSynchronize on the compute stream.
 static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers";
+
+// Set HTP performance mode for QNN HTP backend before session run.
+// options for HTP performance mode: "burst", "balanced", "default", "high_performance",
+// "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver",
+// "sustained_high_performance". Default to "default".
+static const char* const kOrtRunOptionsConfigQnnPerfMode = "qnn.htp_perf_mode";
+
+// Set HTP performance mode for QNN HTP backend post session run.
+static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_mode_post_run";
+
+// Set RPC control latency for QNN HTP backend
+static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency";
+
+// Set graph annotation id for CUDA EP. Use with enable_cuda_graph=true.
+// The value should be an integer. If the value is not set, the default value is 0 and
+// the ORT session only captures one cuda graph before another capture is requested.
+// If the value is set to -1, cuda graph capture/replay is disabled in that run.
+// Users are not expected to set the value to 0 as it is reserved for internal use.
+static const char* const kOrtRunOptionsConfigCudaGraphAnnotation = "gpu_graph_id";
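To make the new keys concrete, a hedged sketch of setting them per run through the C++ API. The annotation id and perf mode values are examples only; the two entries target different EPs (CUDA and QNN) and would normally not be combined, and the CUDA graph entry assumes the session was created with enable_cuda_graph.

```cpp
// Sketch: per-run config entries using the keys defined above.
#include <onnxruntime_cxx_api.h>
#include <onnxruntime_run_options_config_keys.h>
#include <vector>

std::vector<Ort::Value> RunOnce(Ort::Session& session,
                                const char* const* input_names, const Ort::Value* inputs, size_t input_count,
                                const char* const* output_names, size_t output_count) {
  Ort::RunOptions run_options;
  // Capture/replay the CUDA graph tagged with annotation id 1 (CUDA EP, requires enable_cuda_graph).
  run_options.AddConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation, "1");
  // Ask the QNN HTP backend for burst performance for this run only.
  run_options.AddConfigEntry(kOrtRunOptionsConfigQnnPerfMode, "burst");

  return session.Run(run_options, input_names, inputs, input_count, output_names, output_count);
}
```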
diff --git a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java
index eb124decf75f3..cec3fadf446ca 100644
--- a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java
+++ b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved.
  * Licensed under the MIT License.
  */
 package ai.onnxruntime.providers;
@@ -14,7 +14,18 @@ public enum CoreMLFlags implements OrtFlags {
   /** Enables CoreML on subgraphs. */
   ENABLE_ON_SUBGRAPH(2), // COREML_FLAG_ENABLE_ON_SUBGRAPH(0x002)
   /** Only enable usage of CoreML if the device has an Apple Neural Engine. */
-  ONLY_ENABLE_DEVICE_WITH_ANE(4); // COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE(0x004),
+  ONLY_ENABLE_DEVICE_WITH_ANE(4), // COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE(0x004)
+  /**
+   * Only allow CoreML EP to take nodes with inputs with static shapes. By default it will also
+   * allow inputs with dynamic shapes. However, the performance may be negatively impacted if inputs
+   * have dynamic shapes.
+   */
+  ONLY_ALLOW_STATIC_INPUT_SHAPES(8), // COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES(0x008)
+  /**
+   * Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or
+   * later.
+   */
+  CREATE_MLPROGRAM(16); // COREML_FLAG_CREATE_MLPROGRAM(0x010)
 
   /** The native value of the enum. */
   public final int value;
diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java
index 7fef2dc784b7b..9925197e4507c 100644
--- a/java/src/test/java/ai/onnxruntime/InferenceTest.java
+++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java
@@ -673,7 +673,7 @@ private void runProvider(OrtProvider provider) throws OrtException {
           // CoreML gives slightly different answers on a 2020 13" M1 MBP
           assertArrayEquals(expectedOutput, resultArray, 1e-2f);
         } else {
-          assertArrayEquals(expectedOutput, resultArray, 1e-6f);
+          assertArrayEquals(expectedOutput, resultArray, 1e-5f);
         }
       } catch (OrtException e) {
         throw new IllegalStateException("Failed to execute a scoring operation", e);
diff --git a/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java b/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java
index 1ed883ace36e5..0e3bc15ba9c70 100644
--- a/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java
+++ b/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java
@@ -96,7 +96,7 @@ private static void runProvider(OrtProvider provider, OrtSession.SessionOptions
         OnnxValue resultTensor = result.get(0);
         float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue());
         assertEquals(expectedOutput.length, resultArray.length);
-        assertArrayEquals(expectedOutput, resultArray, 1e-6f);
+        assertArrayEquals(expectedOutput, resultArray, 1e-5f);
       } catch (OrtException e) {
         throw new IllegalStateException("Failed to execute a scoring operation", e);
       }
diff --git a/js/common/lib/backend-impl.ts b/js/common/lib/backend-impl.ts
index 3e1e833addb91..e90efd7b97c29 100644
--- a/js/common/lib/backend-impl.ts
+++ b/js/common/lib/backend-impl.ts
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
 
 import {Backend} from './backend.js';
+import {InferenceSession} from './inference-session.js';
 
 interface BackendInfo {
   backend: Backend;
@@ -10,6 +11,7 @@ interface BackendInfo {
   initPromise?: Promise<void>;
   initialized?: boolean;
   aborted?: boolean;
+  error?: string;
 }
 
 const backends: Map<string, BackendInfo> = new Map();
@@ -60,43 +62,100 @@ export const registerBackend = (name: string, backend: Backend, priority: number
 };
 
 /**
- * Resolve backend by specified hints.
+ * Try to resolve and initialize a backend.
  *
- * @param backendHints - a list of execution provider names to lookup. If omitted use registered backends as list.
- * @returns a promise that resolves to the backend.
+ * @param backendName - the name of the backend.
+ * @returns the backend instance if resolved and initialized successfully, or an error message if failed.
+ */
+const tryResolveAndInitializeBackend = async(backendName: string): Promise<Backend|string> => {
+  const backendInfo = backends.get(backendName);
+  if (!backendInfo) {
+    return 'backend not found.';
+  }
+
+  if (backendInfo.initialized) {
+    return backendInfo.backend;
+  } else if (backendInfo.aborted) {
+    return backendInfo.error!;
+  } else {
+    const isInitializing = !!backendInfo.initPromise;
+    try {
+      if (!isInitializing) {
+        backendInfo.initPromise = backendInfo.backend.init(backendName);
+      }
+      await backendInfo.initPromise;
+      backendInfo.initialized = true;
+      return backendInfo.backend;
+    } catch (e) {
+      if (!isInitializing) {
+        backendInfo.error = `${e}`;
+        backendInfo.aborted = true;
+      }
+      return backendInfo.error!;
+    } finally {
+      delete backendInfo.initPromise;
+    }
+  }
+};
+
+/**
+ * Resolve execution providers from the specific session options.
+ *
+ * @param options - the session options object.
+ * @returns a promise that resolves to a tuple of an initialized backend instance and a session options object with
+ * filtered EP list.
  *
  * @ignore
  */
-export const resolveBackend = async(backendHints: readonly string[]): Promise<Backend> => {
-  const backendNames = backendHints.length === 0 ? backendsSortedByPriority : backendHints;
-  const errors = [];
-  for (const backendName of backendNames) {
-    const backendInfo = backends.get(backendName);
-    if (backendInfo) {
-      if (backendInfo.initialized) {
-        return backendInfo.backend;
-      } else if (backendInfo.aborted) {
-        continue;  // current backend is unavailable; try next
-      }
+export const resolveBackendAndExecutionProviders = async(options: InferenceSession.SessionOptions):
+    Promise<[backend: Backend, options: InferenceSession.SessionOptions]> => {
+      // extract backend hints from session options
+      const eps = options.executionProviders || [];
+      const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
+      const backendNames = backendHints.length === 0 ? backendsSortedByPriority : backendHints;
 
-      const isInitializing = !!backendInfo.initPromise;
-      try {
-        if (!isInitializing) {
-          backendInfo.initPromise = backendInfo.backend.init(backendName);
+      // try to resolve and initialize all requested backends
+      let backend: Backend|undefined;
+      const errors = [];
+      const availableBackendNames = new Set<string>();
+      for (const backendName of backendNames) {
+        const resolveResult = await tryResolveAndInitializeBackend(backendName);
+        if (typeof resolveResult === 'string') {
+          errors.push({name: backendName, err: resolveResult});
+        } else {
+          if (!backend) {
+            backend = resolveResult;
+          }
+          if (backend === resolveResult) {
+            availableBackendNames.add(backendName);
+          }
         }
-        await backendInfo.initPromise;
-        backendInfo.initialized = true;
-        return backendInfo.backend;
-      } catch (e) {
-        if (!isInitializing) {
-          errors.push({name: backendName, err: e});
+      }
+
+      // if no backend is available, throw error.
+      if (!backend) {
+        throw new Error(`no available backend found. ERR: ${errors.map(e => `[${e.name}] ${e.err}`).join(', ')}`);
+      }
+
+      // for each explicitly requested backend, if it's not available, output warning message.
+      for (const {name, err} of errors) {
+        if (backendHints.includes(name)) {
+          // eslint-disable-next-line no-console
+          console.warn(`removing requested execution provider "${
+              name}" from session options because it is not available: ${err}`);
         }
-        backendInfo.aborted = true;
-      } finally {
-        delete backendInfo.initPromise;
       }
-    }
-  }
 
-  throw new Error(`no available backend found. ERR: ${errors.map(e => `[${e.name}] ${e.err}`).join(', ')}`);
-};
+      const filteredEps = eps.filter(i => availableBackendNames.has(typeof i === 'string' ? i : i.name));
+
+      return [
+        backend, new Proxy(options, {
+          get: (target, prop) => {
+            if (prop === 'executionProviders') {
+              return filteredEps;
+            }
+            return Reflect.get(target, prop);
+          }
+        })
+      ];
+    };
diff --git a/js/common/lib/backend.ts b/js/common/lib/backend.ts
index 9bfcb12206057..8c07bdd5c5c4a 100644
--- a/js/common/lib/backend.ts
+++ b/js/common/lib/backend.ts
@@ -58,7 +58,7 @@ export interface TrainingSessionHandler extends SessionHandler {
       options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType>;
 
   getParametersSize(trainableOnly: boolean): Promise<number>;
-  loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void>;
+  loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise<void>;
   getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;
 }
 
@@ -77,8 +77,8 @@ export interface Backend {
       Promise<InferenceSessionHandler>;
 
   createTrainingSessionHandler?
-      (checkpointStateUriOrBuffer: TrainingSession.URIorBuffer, trainModelUriOrBuffer: TrainingSession.URIorBuffer,
-       evalModelUriOrBuffer: TrainingSession.URIorBuffer, optimizerModelUriOrBuffer: TrainingSession.URIorBuffer,
+      (checkpointStateUriOrBuffer: TrainingSession.UriOrBuffer, trainModelUriOrBuffer: TrainingSession.UriOrBuffer,
+       evalModelUriOrBuffer: TrainingSession.UriOrBuffer, optimizerModelUriOrBuffer: TrainingSession.UriOrBuffer,
        options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler>;
 }
 
diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts
index 6299c26159400..c8df1613b3268 100644
--- a/js/common/lib/env.ts
+++ b/js/common/lib/env.ts
@@ -36,6 +36,7 @@ export declare namespace Env {
     /**
      * set or get a boolean value indicating whether to enable trace.
      *
+     * @deprecated Use `env.trace` instead. If `env.trace` is set, this property will be ignored.
      * @defaultValue `false`
      */
     trace?: boolean;
@@ -142,13 +143,52 @@ export declare namespace Env {
        */
       ondata?: (data: WebGpuProfilingData) => void;
     };
+    /**
+     * Set or get the power preference.
+     *
+     * Setting this property only has effect before the first WebGPU inference session is created. The value will be
+     * used as options for `navigator.gpu.requestAdapter()`.
+     *
+     * See {@link https://gpuweb.github.io/gpuweb/#dictdef-gpurequestadapteroptions} for more details.
+     *
+     * @defaultValue `undefined`
+     */
+    powerPreference?: 'low-power'|'high-performance';
+    /**
+     * Set or get the force fallback adapter flag.
+     *
+     * Setting this property only has effect before the first WebGPU inference session is created. The value will be
+     * used as options for `navigator.gpu.requestAdapter()`.
+     *
+     * See {@link https://gpuweb.github.io/gpuweb/#dictdef-gpurequestadapteroptions} for more details.
+     *
+     * @defaultValue `undefined`
+     */
+    forceFallbackAdapter?: boolean;
+    /**
+     * Set or get the adapter for WebGPU.
+     *
+     * Setting this property only has effect before the first WebGPU inference session is created. The value will be
+     * used as the GPU adapter for the underlying WebGPU backend to create GPU device.
+     *
+     * If this property is not set, it will be available to read after the first WebGPU inference session is created. The
+     * value will be the GPU adapter created by the underlying WebGPU backend.
+     *
+     * When used with TypeScript, the type of this property is `GPUAdapter` defined in "@webgpu/types".
+     * Use `const adapter = env.webgpu.adapter as GPUAdapter;` in TypeScript to access this property with correct type.
+     *
+     * see comments on {@link Tensor.GpuBufferType}
+     */
+    adapter: unknown;
     /**
      * Get the device for WebGPU.
      *
+     * This property is only available after the first WebGPU inference session is created.
+     *
      * When use with TypeScript, the type of this property is `GPUDevice` defined in "@webgpu/types".
      * Use `const device = env.webgpu.device as GPUDevice;` in TypeScript to access this property with correct type.
      *
-     * see comments on {@link GpuBufferType} for more details about why not use types defined in "@webgpu/types".
+     * see comments on {@link Tensor.GpuBufferType} for more details about why not use types defined in "@webgpu/types".
      */
     readonly device: unknown;
     /**
@@ -167,6 +207,7 @@ export interface Env {
    * @defaultValue `'warning'`
    */
   logLevel?: 'verbose'|'info'|'warning'|'error'|'fatal';
+
   /**
    * Indicate whether run in debug mode.
    *
@@ -174,6 +215,13 @@ export interface Env {
    */
   debug?: boolean;
 
+  /**
+   * set or get a boolean value indicating whether to enable trace.
+   *
+   * @defaultValue `false`
+   */
+  trace?: boolean;
+
   /**
    * Get version of the current package.
    */
diff --git a/js/common/lib/index.ts b/js/common/lib/index.ts
index d7c98380f3fa4..3ed56b3c2e812 100644
--- a/js/common/lib/index.ts
+++ b/js/common/lib/index.ts
@@ -11,7 +11,7 @@
  * - [onnxruntime-react-native](https://www.npmjs.com/package/onnxruntime-react-native)
  *
  * See also:
- * - [Get Started](https://onnxruntime.ai/docs/get-started/with-javascript.html)
+ * - [Get Started](https://onnxruntime.ai/docs/get-started/with-javascript/)
  * - [Inference examples](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/js)
  *
  * @packageDocumentation
@@ -21,6 +21,9 @@ export * from './backend.js';
 export * from './env.js';
 export * from './inference-session.js';
 export * from './tensor.js';
+export * from './tensor-conversion.js';
+export * from './tensor-factory.js';
 export * from './trace.js';
+export * from './onnx-model.js';
 export * from './onnx-value.js';
 export * from './training-session.js';
diff --git a/js/common/lib/inference-session-impl.ts b/js/common/lib/inference-session-impl.ts
index 55f40c8907a89..ab4c6a3e0c46b 100644
--- a/js/common/lib/inference-session-impl.ts
+++ b/js/common/lib/inference-session-impl.ts
@@ -1,7 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-import {resolveBackend} from './backend-impl.js';
+import {resolveBackendAndExecutionProviders} from './backend-impl.js';
 import {InferenceSessionHandler} from './backend.js';
 import {InferenceSession as InferenceSessionInterface} from './inference-session.js';
 import {OnnxValue} from './onnx-value.js';
@@ -195,11 +195,9 @@ export class InferenceSession implements InferenceSessionInterface {
       throw new TypeError('Unexpected argument[0]: must be \'path\' or \'buffer\'.');
     }
 
-    // get backend hints
-    const eps = options.executionProviders || [];
-    const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
-    const backend = await resolveBackend(backendHints);
-    const handler = await backend.createInferenceSessionHandler(filePathOrUint8Array, options);
+    // resolve backend, update session options with validated EPs, and create session handler
+    const [backend, optionsWithValidatedEPs] = await resolveBackendAndExecutionProviders(options);
+    const handler = await backend.createInferenceSessionHandler(filePathOrUint8Array, optionsWithValidatedEPs);
     TRACE_FUNC_END();
     return new InferenceSession(handler);
   }
diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts
index 4f85c3b46e253..4f7fbdcdcf0ca 100644
--- a/js/common/lib/inference-session.ts
+++ b/js/common/lib/inference-session.ts
@@ -186,22 +186,22 @@ export declare namespace InferenceSession {
   // #region execution providers
 
   // Currently, we have the following backends to support execution providers:
-  // Backend Node.js binding: supports 'cpu' and 'cuda'.
+  // Backend Node.js binding: supports 'cpu', 'dml' (win32), 'coreml' (macOS) and 'cuda' (linux).
   // Backend WebAssembly: supports 'cpu', 'wasm', 'webgpu' and 'webnn'.
   // Backend ONNX.js: supports 'webgl'.
   // Backend React Native: supports 'cpu', 'xnnpack', 'coreml' (iOS), 'nnapi' (Android).
   interface ExecutionProviderOptionMap {
+    coreml: CoreMLExecutionProviderOption;
     cpu: CpuExecutionProviderOption;
-    coreml: CoreMlExecutionProviderOption;
     cuda: CudaExecutionProviderOption;
     dml: DmlExecutionProviderOption;
+    nnapi: NnapiExecutionProviderOption;
     tensorrt: TensorRtExecutionProviderOption;
     wasm: WebAssemblyExecutionProviderOption;
     webgl: WebGLExecutionProviderOption;
-    xnnpack: XnnpackExecutionProviderOption;
     webgpu: WebGpuExecutionProviderOption;
     webnn: WebNNExecutionProviderOption;
-    nnapi: NnapiExecutionProviderOption;
+    xnnpack: XnnpackExecutionProviderOption;
   }
 
   type ExecutionProviderName = keyof ExecutionProviderOptionMap;
@@ -219,10 +219,6 @@ export declare namespace InferenceSession {
     readonly name: 'cuda';
     deviceId?: number;
   }
-  export interface CoreMlExecutionProviderOption extends ExecutionProviderOption {
-    readonly name: 'coreml';
-    coreMlFlags?: number;
-  }
   export interface DmlExecutionProviderOption extends ExecutionProviderOption {
     readonly name: 'dml';
     deviceId?: number;
@@ -253,8 +249,39 @@ export declare namespace InferenceSession {
   }
   export interface CoreMLExecutionProviderOption extends ExecutionProviderOption {
     readonly name: 'coreml';
+    /**
+     * The bit flags for CoreML execution provider.
+     *
+     * ```
+     * COREML_FLAG_USE_CPU_ONLY = 0x001
+     * COREML_FLAG_ENABLE_ON_SUBGRAPH = 0x002
+     * COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE = 0x004
+     * COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES = 0x008
+     * COREML_FLAG_CREATE_MLPROGRAM = 0x010
+     * ```
+     *
+     * See include/onnxruntime/core/providers/coreml/coreml_provider_factory.h for more details.
+     *
+     * This flag is available only in ONNXRuntime (Node.js binding).
+     */
+    coreMlFlags?: number;
+    /**
+     * Specify whether to use CPU only in CoreML EP.
+     *
+     * This setting is available only in ONNXRuntime (react-native).
+     */
     useCPUOnly?: boolean;
+    /**
+     * Specify whether to enable CoreML EP on subgraph.
+     *
+     * This setting is available only in ONNXRuntime (react-native).
+     */
     enableOnSubgraph?: boolean;
+    /**
+     * Specify whether to only enable CoreML EP for Apple devices with ANE (Apple Neural Engine).
+     *
+     * This setting is available only in ONNXRuntime (react-native).
+     */
     onlyEnableDeviceWithANE?: boolean;
   }
   export interface NnapiExecutionProviderOption extends ExecutionProviderOption {
diff --git a/js/common/lib/onnx-value.ts b/js/common/lib/onnx-value.ts
index a16a30d25d839..72369ce8b4209 100644
--- a/js/common/lib/onnx-value.ts
+++ b/js/common/lib/onnx-value.ts
@@ -3,7 +3,7 @@
 
 import {Tensor} from './tensor.js';
 
-type NonTensorType = never;
+export type NonTensorType = never;
 
 /**
  * Type OnnxValue Represents both tensors and non-tensors value for model's inputs/outputs.
diff --git a/js/common/lib/tensor-factory.ts b/js/common/lib/tensor-factory.ts
index 6e19d7fb898a3..431de4c3635c2 100644
--- a/js/common/lib/tensor-factory.ts
+++ b/js/common/lib/tensor-factory.ts
@@ -253,7 +253,7 @@ export interface TensorFactory {
   /**
    * create a tensor from an ImageBitmap object
    *
-   * @param bitMap - the ImageBitmap object to create tensor from
+   * @param bitmap - the ImageBitmap object to create tensor from
    * @param options - An optional object representing options for creating tensor from URL.
    *
    * The following default settings will be applied:
diff --git a/js/common/lib/tensor-impl-type-mapping.ts b/js/common/lib/tensor-impl-type-mapping.ts
index c4a43ea27fea1..b29cb8cbd6d35 100644
--- a/js/common/lib/tensor-impl-type-mapping.ts
+++ b/js/common/lib/tensor-impl-type-mapping.ts
@@ -14,7 +14,6 @@ export const NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP = new Map<string, SupportedTy
   ['uint8', Uint8Array],
   ['int8', Int8Array],
   ['uint16', Uint16Array],
-  ['float16', Uint16Array],
   ['int16', Int16Array],
   ['int32', Int32Array],
   ['bool', Uint8Array],
@@ -34,16 +33,22 @@ export const NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP = new Map<SupportedTypedArray
   [Uint32Array, 'uint32'],
 ]);
 
-// the following code allows delaying execution of BigInt checking. This allows lazy initialization for
-// NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP and NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, which allows BigInt polyfill
-// if available.
-let isBigIntChecked = false;
-export const checkBigInt = () => {
-  if (!isBigIntChecked) {
-    isBigIntChecked = true;
-    const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && typeof BigInt64Array.from === 'function';
-    const isBigUint64ArrayAvailable =
-        typeof BigUint64Array !== 'undefined' && typeof BigUint64Array.from === 'function';
+// a dummy type declaration for Float16Array in case any polyfill is available.
+declare global {
+  // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any
+  const Float16Array: any;
+}
+
+// the following code allows delaying execution of BigInt/Float16Array checking. This allows lazy initialization for
+// NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP and NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, which allows BigInt/Float16Array
+// polyfill if available.
+let isTypedArrayChecked = false;
+export const checkTypedArray = () => {
+  if (!isTypedArrayChecked) {
+    isTypedArrayChecked = true;
+    const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && BigInt64Array.from;
+    const isBigUint64ArrayAvailable = typeof BigUint64Array !== 'undefined' && BigUint64Array.from;
+    const isFloat16ArrayAvailable = typeof Float16Array !== 'undefined' && Float16Array.from;
 
     if (isBigInt64ArrayAvailable) {
       NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('int64', BigInt64Array);
@@ -53,5 +58,12 @@ export const checkBigInt = () => {
       NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('uint64', BigUint64Array);
       NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(BigUint64Array, 'uint64');
     }
+    if (isFloat16ArrayAvailable) {
+      NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Float16Array);
+      NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(Float16Array, 'float16');
+    } else {
+      // if Float16Array is not available, use 'Uint16Array' to store the data.
+      NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Uint16Array);
+    }
   }
 };
diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts
index e3e2b9c728556..56682ef98e117 100644
--- a/js/common/lib/tensor-impl.ts
+++ b/js/common/lib/tensor-impl.ts
@@ -5,7 +5,7 @@ import {tensorToDataURL, tensorToImageData} from './tensor-conversion-impl.js';
 import {TensorToDataUrlOptions, TensorToImageDataOptions} from './tensor-conversion.js';
 import {tensorFromGpuBuffer, tensorFromImage, tensorFromPinnedBuffer, tensorFromTexture} from './tensor-factory-impl.js';
 import {CpuPinnedConstructorParameters, GpuBufferConstructorParameters, TensorFromGpuBufferOptions, TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, TensorFromTextureOptions, TensorFromUrlOptions, TextureConstructorParameters} from './tensor-factory.js';
-import {checkBigInt, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js';
+import {checkTypedArray, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js';
 import {calculateSize, tensorReshape} from './tensor-utils-impl.js';
 import {Tensor as TensorInterface} from './tensor.js';
 
@@ -67,8 +67,8 @@ export class Tensor implements TensorInterface {
       arg0: TensorType|TensorDataType|readonly string[]|readonly boolean[]|CpuPinnedConstructorParameters|
       TextureConstructorParameters|GpuBufferConstructorParameters,
       arg1?: TensorDataType|readonly number[]|readonly string[]|readonly boolean[], arg2?: readonly number[]) {
-    // perform one-time check for BigInt support
-    checkBigInt();
+    // perform one-time check for BigInt/Float16Array support
+    checkTypedArray();
 
     let type: TensorType;
     let dims: readonly number[];
@@ -103,7 +103,7 @@ export class Tensor implements TensorInterface {
         }
         case 'gpu-buffer': {
           if ((type !== 'float32' && type !== 'float16' && type !== 'int32' && type !== 'int64' && type !== 'uint32' &&
-               type !== 'bool')) {
+               type !== 'uint8' && type !== 'bool')) {
             throw new TypeError(`unsupported type "${type}" to create tensor from gpu buffer`);
           }
           this.gpuBufferData = arg0.gpuBuffer;
@@ -142,7 +142,9 @@ export class Tensor implements TensorInterface {
             throw new TypeError(`Unsupported tensor type: ${arg0}.`);
           }
           if (Array.isArray(arg1)) {
-            if (arg0 === 'float16') {
+            if (arg0 === 'float16' && typedArrayConstructor === Uint16Array) {
+              // When no Float16Array polyfill is used, we cannot create a 'float16' tensor from a number array.
+              //
               // Throw error here because when user try to use number array as data,
               // e.g. new Tensor('float16', [1, 2, 3, 4], dims)), it will actually call
               // Uint16Array.from(arg1) which generates wrong data.
diff --git a/js/common/lib/tensor.ts b/js/common/lib/tensor.ts
index 6c08d1fe8e057..20319ebb800c2 100644
--- a/js/common/lib/tensor.ts
+++ b/js/common/lib/tensor.ts
@@ -135,7 +135,7 @@ export declare namespace Tensor {
   /**
    * supported data types for constructing a tensor from a WebGPU buffer
    */
-  export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'bool';
+  export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'uint8'|'bool';
 
   /**
    * represent where the tensor data is stored
@@ -160,7 +160,7 @@ export interface Tensor extends TypedTensorBase<Tensor.Type>, TypedTensorUtils<T
 /**
  * type TensorConstructor defines the constructors of 'Tensor' to create CPU tensor instances.
  */
-export interface TensorConstructor {
+export interface TensorConstructor extends TensorFactory {
   // #region CPU tensor - specify element type
   /**
    * Construct a new string tensor object from the given type, data and dims.
@@ -326,4 +326,4 @@ export interface TensorConstructor {
 }
 
 // eslint-disable-next-line @typescript-eslint/naming-convention
-export const Tensor = TensorImpl as (TensorConstructor & TensorFactory);
+export const Tensor = TensorImpl as TensorConstructor;
diff --git a/js/common/lib/trace.ts b/js/common/lib/trace.ts
index 404f7ef8089af..44ad6cacb4bb4 100644
--- a/js/common/lib/trace.ts
+++ b/js/common/lib/trace.ts
@@ -3,8 +3,11 @@
 
 import {env} from './env-impl.js';
 
+/**
+ * @ignore
+ */
 export const TRACE = (deviceType: string, label: string) => {
-  if (!env.wasm.trace) {
+  if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) {
     return;
   }
   // eslint-disable-next-line no-console
@@ -29,15 +32,21 @@ const TRACE_FUNC = (msg: string, extraMsg?: string) => {
   }
 };
 
+/**
+ * @ignore
+ */
 export const TRACE_FUNC_BEGIN = (extraMsg?: string) => {
-  if (!env.wasm.trace) {
+  if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) {
     return;
   }
   TRACE_FUNC('BEGIN', extraMsg);
 };
 
+/**
+ * @ignore
+ */
 export const TRACE_FUNC_END = (extraMsg?: string) => {
-  if (!env.wasm.trace) {
+  if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) {
     return;
   }
   TRACE_FUNC('END', extraMsg);
diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts
index 23bd4421ae672..bae38b0dfda5a 100644
--- a/js/common/lib/training-session-impl.ts
+++ b/js/common/lib/training-session-impl.ts
@@ -1,7 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-import {resolveBackend} from './backend-impl.js';
+import {resolveBackendAndExecutionProviders} from './backend-impl.js';
 import {SessionHandler, TrainingSessionHandler} from './backend.js';
 import {InferenceSession as InferenceSession} from './inference-session.js';
 import {OnnxValue} from './onnx-value.js';
@@ -55,13 +55,12 @@ export class TrainingSession implements TrainingSessionInterface {
     const optimizerModel: string|Uint8Array = trainingOptions.optimizerModel || '';
     const options: SessionOptions = sessionOptions || {};
 
-    // get backend hints
-    const eps = options.executionProviders || [];
-    const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
-    const backend = await resolveBackend(backendHints);
+    // resolve backend, update session options with validated EPs, and create session handler
+    const [backend, optionsWithValidatedEPs] = await resolveBackendAndExecutionProviders(options);
     if (backend.createTrainingSessionHandler) {
       const handler = await backend.createTrainingSessionHandler(
-          trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options);
+          trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel,
+          optionsWithValidatedEPs);
       return new TrainingSession(handler, !!trainingOptions.optimizerModel, !!trainingOptions.evalModel);
     } else {
       throw new Error(noBackendErrMsg);
diff --git a/js/common/lib/training-session.ts b/js/common/lib/training-session.ts
index e54aed90e702c..f9de77e3ac7d0 100644
--- a/js/common/lib/training-session.ts
+++ b/js/common/lib/training-session.ts
@@ -11,7 +11,7 @@ export declare namespace TrainingSession {
   /**
    * Either URI file path (string) or Uint8Array containing model or checkpoint information.
    */
-  type URIorBuffer = string|Uint8Array;
+  type UriOrBuffer = string|Uint8Array;
 }
 
 /**
@@ -98,13 +98,13 @@ export interface TrainingSession {
   getParametersSize(trainableOnly: boolean): Promise<number>;
 
   /**
-   * Copies parameter values from the given array to the training state. Currently, only supporting models with
+   * Copies parameter values from the given buffer to the training state. Currently, only supporting models with
    * parameters of type Float32.
    *
-   * @param buffer - Float32 buffer containing parameters converted to a Uint8Array.
+   * @param buffer - A Uint8Array representation of Float32 parameters.
    * @param trainableOnly - True if trainable parameters only to be modified, false otherwise. Default value is true.
    */
-  loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void>;
+  loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise<void>;
 
   /**
    * Copies the model parameters to a contiguous buffer. Usually used in the context of Federated Learning.
@@ -157,19 +157,19 @@ export interface TrainingSessionCreateOptions {
   /**
    * URI or buffer for a .ckpt file that contains the checkpoint for the training model.
    */
-  checkpointState: TrainingSession.URIorBuffer;
+  checkpointState: TrainingSession.UriOrBuffer;
   /**
    * URI or buffer for the .onnx training file.
    */
-  trainModel: TrainingSession.URIorBuffer;
+  trainModel: TrainingSession.UriOrBuffer;
   /**
    * Optional. URI or buffer for the .onnx optimizer model file.
    */
-  optimizerModel?: TrainingSession.URIorBuffer;
+  optimizerModel?: TrainingSession.UriOrBuffer;
   /**
    * Optional. URI or buffer for the .onnx eval model file.
    */
-  evalModel?: TrainingSession.URIorBuffer;
+  evalModel?: TrainingSession.UriOrBuffer;
 }
 
 /**
diff --git a/js/common/package-lock.json b/js/common/package-lock.json
index a5ada877b916a..3988ac80707e0 100644
--- a/js/common/package-lock.json
+++ b/js/common/package-lock.json
@@ -9,13 +9,13 @@
       "version": "1.18.0",
       "license": "MIT",
       "devDependencies": {
-        "typedoc": "^0.23.22"
+        "typedoc": "^0.25.7"
       }
     },
     "node_modules/ansi-sequence-parser": {
-      "version": "1.1.0",
-      "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.0.tgz",
-      "integrity": "sha512-lEm8mt52to2fT8GhciPCGeCXACSz2UwIN4X2e2LJSnZ5uAbn2/dsYdOmUXq0AtWS5cpAupysIneExOgH0Vd2TQ==",
+      "version": "1.1.1",
+      "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.1.tgz",
+      "integrity": "sha512-vJXt3yiaUL4UU546s3rPXlsry/RnM730G1+HkpKE012AN0sx1eOrxSu95oKDIonskeLTijMgqWZ3uDEe3NFvyg==",
       "dev": true
     },
     "node_modules/balanced-match": {
@@ -34,9 +34,9 @@
       }
     },
     "node_modules/jsonc-parser": {
-      "version": "3.2.0",
-      "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.0.tgz",
-      "integrity": "sha512-gfFQZrcTc8CnKXp6Y4/CBT3fTc0OVuDofpre4aEeEpSBPV5X5v4+Vmx+8snU7RLPrNHPKSgLxGo9YuQzz20o+w==",
+      "version": "3.2.1",
+      "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.1.tgz",
+      "integrity": "sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==",
       "dev": true
     },
     "node_modules/lunr": {
@@ -46,9 +46,9 @@
       "dev": true
     },
     "node_modules/marked": {
-      "version": "4.2.12",
-      "resolved": "https://registry.npmjs.org/marked/-/marked-4.2.12.tgz",
-      "integrity": "sha512-yr8hSKa3Fv4D3jdZmtMMPghgVt6TWbk86WQaWhDloQjRSQhMMYCAro7jP7VDJrjjdV8pxVxMssXS8B8Y5DZ5aw==",
+      "version": "4.3.0",
+      "resolved": "https://registry.npmjs.org/marked/-/marked-4.3.0.tgz",
+      "integrity": "sha512-PRsaiG84bK+AMvxziE/lCFss8juXjNaWzVbN5tXAm4XjeaS9NAHhop+PjQxz2A9h8Q4M/xGmzP8vqNwy6JeK0A==",
       "dev": true,
       "bin": {
         "marked": "bin/marked.js"
@@ -58,24 +58,24 @@
       }
     },
     "node_modules/minimatch": {
-      "version": "7.4.2",
-      "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-7.4.2.tgz",
-      "integrity": "sha512-xy4q7wou3vUoC9k1xGTXc+awNdGaGVHtFUaey8tiX4H1QRc04DZ/rmDFwNm2EBsuYEhAZ6SgMmYf3InGY6OauA==",
+      "version": "9.0.3",
+      "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.3.tgz",
+      "integrity": "sha512-RHiac9mvaRw0x3AYRgDC1CxAP7HTcNrrECeA8YYJeWnpo+2Q5CegtZjaotWTWxDG3UeGA1coE05iH1mPjT/2mg==",
       "dev": true,
       "dependencies": {
         "brace-expansion": "^2.0.1"
       },
       "engines": {
-        "node": ">=10"
+        "node": ">=16 || 14 >=14.17"
       },
       "funding": {
         "url": "https://github.com/sponsors/isaacs"
       }
     },
     "node_modules/shiki": {
-      "version": "0.14.1",
-      "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.1.tgz",
-      "integrity": "sha512-+Jz4nBkCBe0mEDqo1eKRcCdjRtrCjozmcbTUjbPTX7OOJfEbTZzlUWlZtGe3Gb5oV1/jnojhG//YZc3rs9zSEw==",
+      "version": "0.14.7",
+      "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.7.tgz",
+      "integrity": "sha512-dNPAPrxSc87ua2sKJ3H5dQ/6ZaY8RNnaAqK+t0eG7p0Soi2ydiqbGOTaZCqaYvA/uZYfS1LJnemt3Q+mSfcPCg==",
       "dev": true,
       "dependencies": {
         "ansi-sequence-parser": "^1.1.0",
@@ -85,30 +85,30 @@
       }
     },
     "node_modules/typedoc": {
-      "version": "0.23.26",
-      "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.23.26.tgz",
-      "integrity": "sha512-5m4KwR5tOLnk0OtMaRn9IdbeRM32uPemN9kur7YK9wFqx8U0CYrvO9aVq6ysdZSV1c824BTm+BuQl2Ze/k1HtA==",
+      "version": "0.25.7",
+      "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.25.7.tgz",
+      "integrity": "sha512-m6A6JjQRg39p2ZVRIN3NKXgrN8vzlHhOS+r9ymUYtcUP/TIQPvWSq7YgE5ZjASfv5Vd5BW5xrir6Gm2XNNcOow==",
       "dev": true,
       "dependencies": {
         "lunr": "^2.3.9",
-        "marked": "^4.2.12",
-        "minimatch": "^7.1.3",
-        "shiki": "^0.14.1"
+        "marked": "^4.3.0",
+        "minimatch": "^9.0.3",
+        "shiki": "^0.14.7"
       },
       "bin": {
         "typedoc": "bin/typedoc"
       },
       "engines": {
-        "node": ">= 14.14"
+        "node": ">= 16"
       },
       "peerDependencies": {
-        "typescript": "4.6.x || 4.7.x || 4.8.x || 4.9.x"
+        "typescript": "4.6.x || 4.7.x || 4.8.x || 4.9.x || 5.0.x || 5.1.x || 5.2.x || 5.3.x"
       }
     },
     "node_modules/typescript": {
-      "version": "4.9.5",
-      "resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz",
-      "integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==",
+      "version": "5.2.2",
+      "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.2.2.tgz",
+      "integrity": "sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==",
       "dev": true,
       "peer": true,
       "bin": {
@@ -116,7 +116,7 @@
         "tsserver": "bin/tsserver"
       },
       "engines": {
-        "node": ">=4.2.0"
+        "node": ">=14.17"
       }
     },
     "node_modules/vscode-oniguruma": {
@@ -134,9 +134,9 @@
   },
   "dependencies": {
     "ansi-sequence-parser": {
-      "version": "1.1.0",
-      "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.0.tgz",
-      "integrity": "sha512-lEm8mt52to2fT8GhciPCGeCXACSz2UwIN4X2e2LJSnZ5uAbn2/dsYdOmUXq0AtWS5cpAupysIneExOgH0Vd2TQ==",
+      "version": "1.1.1",
+      "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.1.tgz",
+      "integrity": "sha512-vJXt3yiaUL4UU546s3rPXlsry/RnM730G1+HkpKE012AN0sx1eOrxSu95oKDIonskeLTijMgqWZ3uDEe3NFvyg==",
       "dev": true
     },
     "balanced-match": {
@@ -155,9 +155,9 @@
       }
     },
     "jsonc-parser": {
-      "version": "3.2.0",
-      "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.0.tgz",
-      "integrity": "sha512-gfFQZrcTc8CnKXp6Y4/CBT3fTc0OVuDofpre4aEeEpSBPV5X5v4+Vmx+8snU7RLPrNHPKSgLxGo9YuQzz20o+w==",
+      "version": "3.2.1",
+      "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.1.tgz",
+      "integrity": "sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==",
       "dev": true
     },
     "lunr": {
@@ -167,24 +167,24 @@
       "dev": true
     },
     "marked": {
-      "version": "4.2.12",
-      "resolved": "https://registry.npmjs.org/marked/-/marked-4.2.12.tgz",
-      "integrity": "sha512-yr8hSKa3Fv4D3jdZmtMMPghgVt6TWbk86WQaWhDloQjRSQhMMYCAro7jP7VDJrjjdV8pxVxMssXS8B8Y5DZ5aw==",
+      "version": "4.3.0",
+      "resolved": "https://registry.npmjs.org/marked/-/marked-4.3.0.tgz",
+      "integrity": "sha512-PRsaiG84bK+AMvxziE/lCFss8juXjNaWzVbN5tXAm4XjeaS9NAHhop+PjQxz2A9h8Q4M/xGmzP8vqNwy6JeK0A==",
       "dev": true
     },
     "minimatch": {
-      "version": "7.4.2",
-      "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-7.4.2.tgz",
-      "integrity": "sha512-xy4q7wou3vUoC9k1xGTXc+awNdGaGVHtFUaey8tiX4H1QRc04DZ/rmDFwNm2EBsuYEhAZ6SgMmYf3InGY6OauA==",
+      "version": "9.0.3",
+      "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.3.tgz",
+      "integrity": "sha512-RHiac9mvaRw0x3AYRgDC1CxAP7HTcNrrECeA8YYJeWnpo+2Q5CegtZjaotWTWxDG3UeGA1coE05iH1mPjT/2mg==",
       "dev": true,
       "requires": {
         "brace-expansion": "^2.0.1"
       }
     },
     "shiki": {
-      "version": "0.14.1",
-      "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.1.tgz",
-      "integrity": "sha512-+Jz4nBkCBe0mEDqo1eKRcCdjRtrCjozmcbTUjbPTX7OOJfEbTZzlUWlZtGe3Gb5oV1/jnojhG//YZc3rs9zSEw==",
+      "version": "0.14.7",
+      "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.7.tgz",
+      "integrity": "sha512-dNPAPrxSc87ua2sKJ3H5dQ/6ZaY8RNnaAqK+t0eG7p0Soi2ydiqbGOTaZCqaYvA/uZYfS1LJnemt3Q+mSfcPCg==",
       "dev": true,
       "requires": {
         "ansi-sequence-parser": "^1.1.0",
@@ -194,21 +194,21 @@
       }
     },
     "typedoc": {
-      "version": "0.23.26",
-      "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.23.26.tgz",
-      "integrity": "sha512-5m4KwR5tOLnk0OtMaRn9IdbeRM32uPemN9kur7YK9wFqx8U0CYrvO9aVq6ysdZSV1c824BTm+BuQl2Ze/k1HtA==",
+      "version": "0.25.7",
+      "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.25.7.tgz",
+      "integrity": "sha512-m6A6JjQRg39p2ZVRIN3NKXgrN8vzlHhOS+r9ymUYtcUP/TIQPvWSq7YgE5ZjASfv5Vd5BW5xrir6Gm2XNNcOow==",
       "dev": true,
       "requires": {
         "lunr": "^2.3.9",
-        "marked": "^4.2.12",
-        "minimatch": "^7.1.3",
-        "shiki": "^0.14.1"
+        "marked": "^4.3.0",
+        "minimatch": "^9.0.3",
+        "shiki": "^0.14.7"
       }
     },
     "typescript": {
-      "version": "4.9.5",
-      "resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz",
-      "integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==",
+      "version": "5.2.2",
+      "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.2.2.tgz",
+      "integrity": "sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==",
       "dev": true,
       "peer": true
     },
diff --git a/js/common/package.json b/js/common/package.json
index 64ab2736adbe3..cd2612aab4984 100644
--- a/js/common/package.json
+++ b/js/common/package.json
@@ -9,7 +9,7 @@
   },
   "author": "fs-eire",
   "scripts": {
-    "build:cjs": "tsc --module commonjs --outDir ./dist/cjs",
+    "build:cjs": "tsc --module commonjs --moduleResolution node10 --outDir ./dist/cjs",
     "build:esm": "tsc",
     "build:bundles": "webpack",
     "build": "node ./build.js",
@@ -18,7 +18,7 @@
     "test": "mocha ./test/**/*.js --timeout 30000"
   },
   "devDependencies": {
-    "typedoc": "^0.23.22"
+    "typedoc": "^0.25.7"
   },
   "main": "dist/cjs/index.js",
   "exports": {
diff --git a/js/common/test/tsconfig.json b/js/common/test/tsconfig.json
index 2e4927ac3b325..e9068ad837a81 100644
--- a/js/common/test/tsconfig.json
+++ b/js/common/test/tsconfig.json
@@ -2,7 +2,7 @@
   "extends": "../../tsconfig.tools.json",
   "exclude": ["type-tests/**/*.ts"],
   "compilerOptions": {
-    "module": "ES2022",
+    "module": "Node16",
     "sourceMap": true
   }
 }
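
Note on the two build changes above: the `build:cjs` script is pinned to the classic `node10` resolution while the test tsconfig moves to `Node16` module semantics, so the CommonJS and ESM outputs are produced by deliberately different compiler settings. As a rough, hypothetical illustration (not taken from this repository) of why the `--module` flag has to differ between the two builds, here is a minimal sketch using the TypeScript compiler API; the source string and option values are invented for the example.

import * as ts from 'typescript';

// Same source, two emits: the CJS build rewrites the export into `exports.*`
// assignments, while the ESM build keeps the `export` statement intact.
const source = `export const backendName = 'wasm';`;

const cjs = ts.transpileModule(source, {
  compilerOptions: {module: ts.ModuleKind.CommonJS, target: ts.ScriptTarget.ES2020},
});
const esm = ts.transpileModule(source, {
  compilerOptions: {module: ts.ModuleKind.ESNext, target: ts.ScriptTarget.ES2020},
});

console.log(cjs.outputText);  // uses `exports.backendName = ...`
console.log(esm.outputText);  // keeps `export const backendName = ...`
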
diff --git a/js/node/lib/backend.ts b/js/node/lib/backend.ts
index e8eb0e9babf5a..927953b4f1dd6 100644
--- a/js/node/lib/backend.ts
+++ b/js/node/lib/backend.ts
@@ -36,7 +36,7 @@ class OnnxruntimeSessionHandler implements InferenceSessionHandler {
   async run(feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions):
       Promise<SessionHandler.ReturnType> {
     return new Promise((resolve, reject) => {
-      process.nextTick(() => {
+      setImmediate(() => {
         try {
           resolve(this.#inferenceSession.run(feeds, fetches, options));
         } catch (e) {
@@ -56,7 +56,7 @@ class OnnxruntimeBackend implements Backend {
   async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions):
       Promise<InferenceSessionHandler> {
     return new Promise((resolve, reject) => {
-      process.nextTick(() => {
+      setImmediate(() => {
         try {
           resolve(new OnnxruntimeSessionHandler(pathOrBuffer, options || {}));
         } catch (e) {
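
The backend.ts change above swaps process.nextTick for setImmediate, so the synchronous native call is scheduled after pending I/O callbacks rather than ahead of them. Below is a minimal sketch of that deferral pattern, assuming a Node.js environment; heavySyncWork() is a hypothetical stand-in for the native session call.

// Hypothetical stand-in for the synchronous native call made by the session handler.
const heavySyncWork = (): number => {
  let acc = 0;
  for (let i = 0; i < 1e7; i++) {
    acc += i;
  }
  return acc;
};

// Defer the call with setImmediate so pending I/O callbacks get a chance to run first;
// process.nextTick would run it before any I/O is processed.
const runDeferred = (): Promise<number> => new Promise((resolve, reject) => {
  setImmediate(() => {
    try {
      resolve(heavySyncWork());
    } catch (e) {
      reject(e);
    }
  });
});

runDeferred().then((result) => console.log('done:', result));
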
diff --git a/js/node/package-lock.json b/js/node/package-lock.json
index 2d7c39c86097f..62b47698a1438 100644
--- a/js/node/package-lock.json
+++ b/js/node/package-lock.json
@@ -30,7 +30,7 @@
       "version": "1.18.0",
       "license": "MIT",
       "devDependencies": {
-        "typedoc": "^0.23.22"
+        "typedoc": "^0.25.7"
       }
     },
     "node_modules/@protobufjs/aspromise": {
@@ -336,9 +336,9 @@
       "dev": true
     },
     "node_modules/follow-redirects": {
-      "version": "1.15.4",
-      "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.4.tgz",
-      "integrity": "sha512-Cr4D/5wlrb0z9dgERpUL3LrmPKVDsETIJhaCMeDfuFYcqa5bldGV6wBsAN6X/vxlXQtFBMrXdXxdL8CbDTGniw==",
+      "version": "1.15.6",
+      "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz",
+      "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==",
       "dev": true,
       "funding": [
         {
@@ -1242,9 +1242,9 @@
       "dev": true
     },
     "follow-redirects": {
-      "version": "1.15.4",
-      "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.4.tgz",
-      "integrity": "sha512-Cr4D/5wlrb0z9dgERpUL3LrmPKVDsETIJhaCMeDfuFYcqa5bldGV6wBsAN6X/vxlXQtFBMrXdXxdL8CbDTGniw==",
+      "version": "1.15.6",
+      "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz",
+      "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==",
       "dev": true
     },
     "form-data": {
@@ -1503,7 +1503,7 @@
     "onnxruntime-common": {
       "version": "file:../common",
       "requires": {
-        "typedoc": "^0.23.22"
+        "typedoc": "^0.25.7"
       }
     },
     "parse-json": {
diff --git a/js/react_native/e2e/yarn.lock b/js/react_native/e2e/yarn.lock
index 9e20a286c4e27..6f05faf046098 100644
--- a/js/react_native/e2e/yarn.lock
+++ b/js/react_native/e2e/yarn.lock
@@ -3351,9 +3351,9 @@ invariant@^2.2.4:
     loose-envify "^1.0.0"
 
 ip@^1.1.5:
-  version "1.1.8"
-  resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.8.tgz#ae05948f6b075435ed3307acce04629da8cdbf48"
-  integrity sha512-PuExPYUiu6qMBQb4l06ecm6T6ujzhmh+MeJcW9wa89PoAz5pvd4zPgN5WJV104mb6S2T1AwNIAaB70JNrLQWhg==
+  version "1.1.9"
+  resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.9.tgz#8dfbcc99a754d07f425310b86a99546b1151e396"
+  integrity sha512-cyRxvOEpNHNtchU3Ln9KC/auJgup87llfQpQ+t5ghoC/UhL16SWzbueiCsdTnWmqAWl7LadfuwhlqmtOaqMHdQ==
 
 is-accessor-descriptor@^0.1.6:
   version "0.1.6"
diff --git a/js/react_native/yarn.lock b/js/react_native/yarn.lock
index 4dca90d7415cf..bbb0c4f3d1e22 100644
--- a/js/react_native/yarn.lock
+++ b/js/react_native/yarn.lock
@@ -3701,9 +3701,9 @@ invariant@^2.2.4:
     loose-envify "^1.0.0"
 
 ip@^1.1.5:
-  version "1.1.8"
-  resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.8.tgz#ae05948f6b075435ed3307acce04629da8cdbf48"
-  integrity sha512-PuExPYUiu6qMBQb4l06ecm6T6ujzhmh+MeJcW9wa89PoAz5pvd4zPgN5WJV104mb6S2T1AwNIAaB70JNrLQWhg==
+  version "1.1.9"
+  resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.9.tgz#8dfbcc99a754d07f425310b86a99546b1151e396"
+  integrity sha512-cyRxvOEpNHNtchU3Ln9KC/auJgup87llfQpQ+t5ghoC/UhL16SWzbueiCsdTnWmqAWl7LadfuwhlqmtOaqMHdQ==
 
 is-absolute@^1.0.0:
   version "1.0.0"
diff --git a/js/web/README.md b/js/web/README.md
index c75a40ad6da28..906c78a1b7ec4 100644
--- a/js/web/README.md
+++ b/js/web/README.md
@@ -12,7 +12,7 @@ The [Open Neural Network Exchange](http://onnx.ai/) (ONNX) is an open standard f
 
 With ONNX Runtime Web, web developers can score models directly on browsers with various benefits including reducing server-client communication and protecting user privacy, as well as offering install-free and cross-platform in-browser ML experience.
 
-ONNX Runtime Web can run on both CPU and GPU. On CPU side, [WebAssembly](https://developer.mozilla.org/en-US/docs/WebAssembly) is adopted to execute the model at near-native speed. ONNX Runtime Web complies the native ONNX Runtime CPU engine into WebAssembly backend by using Emscripten, so it supports most functionalities native ONNX Runtime offers, including full ONNX operator coverage, multi-threading, [ONNX Runtime Quantization](https://www.onnxruntime.ai/docs/how-to/quantization.html) as well as [ONNX Runtime Mobile](https://onnxruntime.ai/docs/tutorials/mobile/). For performance acceleration with GPUs, ONNX Runtime Web leverages WebGL, a popular standard for accessing GPU capabilities. We are keeping improving op coverage and optimizing performance in WebGL backend.
+ONNX Runtime Web can run on both CPU and GPU. On the CPU side, [WebAssembly](https://developer.mozilla.org/en-US/docs/WebAssembly) is adopted to execute the model at near-native speed. ONNX Runtime Web compiles the native ONNX Runtime CPU engine into a WebAssembly backend by using Emscripten, so it supports most functionalities the native ONNX Runtime offers, including full ONNX operator coverage, multi-threading, [ONNX Runtime Quantization](https://www.onnxruntime.ai/docs/how-to/quantization.html) as well as [ONNX Runtime Mobile](https://onnxruntime.ai/docs/tutorials/mobile/). For performance acceleration with GPUs, ONNX Runtime Web leverages WebGL, a popular standard for accessing GPU capabilities. We keep improving op coverage and optimizing performance in the WebGL backend.
 
 See [Compatibility](#Compatibility) and [Operators Supported](#Operators) for a list of platforms and operators ONNX Runtime Web currently supports.
 
@@ -22,7 +22,7 @@ Refer to [ONNX Runtime JavaScript examples](https://github.com/microsoft/onnxrun
 
 ## Documents
 
-### Developement
+### Development
 
 Refer to the following links for development information:
 
diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md
index b21af8e715db3..4a8c92bb97bfd 100644
--- a/js/web/docs/webgpu-operators.md
+++ b/js/web/docs/webgpu-operators.md
@@ -62,6 +62,7 @@ Do not modify directly.*
 | LessOrEqual | ai.onnx(12-15,16+) |  |
 | Log | ai.onnx(6-12,13+) |  |
 | MatMul | ai.onnx(1-12,13+) |  |
+| MatMulNBits | com.microsoft(1+) |  |
 | MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(1-7,8-9,10,11,12+) | need perf optimization; need implementing activation |
 | MemcpyFromHost | ai.onnx(1+) |  |
 | MemcpyToHost | ai.onnx(1+) |  |
diff --git a/js/web/karma.conf.js b/js/web/karma.conf.js
index 8fce79843f617..9e44d9c0d9652 100644
--- a/js/web/karma.conf.js
+++ b/js/web/karma.conf.js
@@ -86,11 +86,11 @@ module.exports = function(config) {
     hostname,
     listenAddress,
     customLaunchers: {
-      // the following flags are used to make sure Edge on CI agents to initialize WebGPU correctly.
+      // Chromium-based browsers
       EdgeTest: {base: 'Edge', flags: chromiumFlags},
       ChromeTest: {base: 'Chrome', flags: chromiumFlags},
-      ChromeTestHeadless: {base: 'ChromeHeadless', flags: chromiumFlags},
       ChromeCanaryTest: {base: 'ChromeCanary', flags: chromiumFlags},
+
       //
       // ==== BrowserStack browsers ====
       //
diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts
index 5dd715191c830..56925b728e9a3 100644
--- a/js/web/lib/wasm/binding/ort-wasm.d.ts
+++ b/js/web/lib/wasm/binding/ort-wasm.d.ts
@@ -16,20 +16,97 @@ export declare namespace JSEP {
   type CaptureBeginFunction = () => void;
   type CaptureEndFunction = () => void;
   type ReplayFunction = () => void;
-}
 
-export interface OrtWasmModule extends EmscriptenModule {
-  // #region emscripten functions
-  stackSave(): number;
-  stackRestore(stack: number): void;
-  stackAlloc(size: number): number;
-
-  UTF8ToString(offset: number, maxBytesToRead?: number): string;
-  lengthBytesUTF8(str: string): number;
-  stringToUTF8(str: string, offset: number, maxBytes: number): void;
-  // #endregion
+  export interface Module extends WebGpuModule {
+    /**
+     * Mount the external data file to an internal map, which will be used during session initialization.
+     *
+     * @param externalDataFilePath - specify the relative path of the external data file.
+     * @param externalDataFileData - specify the content data.
+     */
+    mountExternalData(externalDataFilePath: string, externalDataFileData: Uint8Array): void;
+    /**
+     * Unmount all external data files from the internal map.
+     */
+    unmountExternalData(): void;
+
+    /**
+     * This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime per
+     * backend. This function initializes Asyncify support. If name is 'webgpu', also initializes WebGPU backend and
+     * registers a few callbacks that will be called in C++ code.
+     */
+    jsepInit(name: 'webgpu', initParams: [
+      backend: BackendType, alloc: AllocFunction, free: FreeFunction, upload: UploadFunction,
+      download: DownloadFunction, createKernel: CreateKernelFunction, releaseKernel: ReleaseKernelFunction,
+      run: RunFunction, captureBegin: CaptureBeginFunction, captureEnd: CaptureEndFunction, replay: ReplayFunction
+    ]): void;
+    jsepInit(name: 'webnn', initParams?: never): void;
+  }
+
+  export interface WebGpuModule {
+    /**
+     * [exported from wasm] Specify a kernel's output when running OpKernel::Compute().
+     *
+     * @param context - specify the kernel context pointer.
+     * @param index - specify the index of the output.
+     * @param data - specify the pointer to encoded data of type and dims.
+     */
+    _JsepOutput(context: number, index: number, data: number): number;
+    /**
+     * [exported from wasm] Get name of an operator node.
+     *
+     * @param kernel - specify the kernel pointer.
+     * @returns the pointer to a C-style UTF8 encoded string representing the node name.
+     */
+    _JsepGetNodeName(kernel: number): number;
+
+    /**
+     * [exported from js_internal_api.js] Register a user GPU buffer for usage of a session's input or output.
+     *
+     * @param sessionId - specify the session ID.
+     * @param index - specify an integer to represent which input/output it is registering for. For input, it is the
+     *     input_index corresponding to the session's inputNames. For output, it is the inputCount + output_index
+     *     corresponding to the session's outputNames.
+     * @param buffer - specify the GPU buffer to register.
+     * @param size - specify the original data size in byte.
+     * @returns the GPU data ID for the registered GPU buffer.
+     */
+    jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number;
+    /**
+     * [exported from js_internal_api.js] Get the GPU buffer by GPU data ID.
+     *
+     * @param dataId - specify the GPU data ID
+     * @returns the GPU buffer.
+     */
+    jsepGetBuffer: (dataId: number) => GPUBuffer;
+    /**
+     * [exported from js_internal_api.js] Create a function to be used to create a GPU Tensor.
+     *
+     * @param gpuBuffer - specify the GPU buffer
+     * @param size - specify the original data size in byte.
+     * @param type - specify the tensor type.
+     * @returns the generated downloader function.
+     */
+    jsepCreateDownloader:
+        (gpuBuffer: GPUBuffer, size: number,
+         type: Tensor.GpuBufferDataTypes) => () => Promise<Tensor.DataTypeMap[Tensor.GpuBufferDataTypes]>;
+    /**
+     *  [exported from js_internal_api.js] Called when InferenceSession.run starts. This function will be called before
+     * _OrtRun[WithBinding]() is called.
+     * @param sessionId - specify the session ID.
+     */
+    jsepOnRunStart: (sessionId: number) => void;
+    /**
+     * [exported from js_internal_api.js] Release a session. This function will be called before _OrtReleaseSession() is
+     * called.
+     * @param sessionId - specify the session ID.
+     * @returns
+     */
+    jsepOnReleaseSession: (sessionId: number) => void;
+  }
+}
 
-  // #region ORT APIs
+export interface OrtInferenceAPIs {
   _OrtInit(numThreads: number, loggingLevel: number): number;
 
   _OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): void;
@@ -74,126 +151,61 @@ export interface OrtWasmModule extends EmscriptenModule {
   _OrtReleaseRunOptions(runOptionsHandle: number): void;
 
   _OrtEndProfiling(sessionHandle: number): number;
-  // #endregion
+}
+
+export interface OrtTrainingAPIs {
+  _OrtTrainingLoadCheckpoint(dataOffset: number, dataLength: number): number;
 
-  // #region ORT Training APIs
-  _OrtTrainingLoadCheckpoint?(dataOffset: number, dataLength: number): number;
+  _OrtTrainingReleaseCheckpoint(checkpointHandle: number): void;
 
-  _OrtTrainingReleaseCheckpoint?(checkpointHandle: number): void;
+  _OrtTrainingCreateSession(
+      sessionOptionsHandle: number, checkpointHandle: number, trainOffset: number, trainLength: number,
+      evalOffset: number, evalLength: number, optimizerOffset: number, optimizerLength: number): number;
 
-  _OrtTrainingCreateSession?
-      (sessionOptionsHandle: number, checkpointHandle: number, trainOffset: number, trainLength: number,
-       evalOffset: number, evalLength: number, optimizerOffset: number, optimizerLength: number): number;
+  _OrtTrainingLazyResetGrad(trainingHandle: number): number;
 
-  _OrtTrainingLazyResetGrad?(trainingHandle: number): number;
+  _OrtTrainingRunTrainStep(
+      trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number,
+      runOptionsHandle: number): number;
 
-  _OrtTrainingRunTrainStep?
-      (trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number,
-       runOptionsHandle: number): number;
+  _OrtTrainingOptimizerStep(trainingHandle: number, runOptionsHandle: number): number;
 
-  _OrtTrainingOptimizerStep?(trainingHandle: number, runOptionsHandle: number): number;
+  _OrtTrainingEvalStep(
+      trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number,
+      runOptionsHandle: number): number;
 
-  _OrtTrainingEvalStep?
-      (trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number,
-       runOptionsHandle: number): number;
+  _OrtTrainingGetParametersSize(trainingHandle: number, paramSizeT: number, trainableOnly: boolean): number;
 
-  _OrtTrainingGetParametersSize?(trainingHandle: number, paramSizeT: number, trainableOnly: boolean): number;
+  _OrtTrainingCopyParametersToBuffer(
+      trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number;
 
-  _OrtTrainingCopyParametersToBuffer?
-      (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number;
+  _OrtTrainingCopyParametersFromBuffer(
+      trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number;
 
-  _OrtTrainingCopyParametersFromBuffer?
-      (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number;
+  _OrtTrainingGetModelInputOutputCount(
+      trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number;
+  _OrtTrainingGetModelInputOutputName(trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean):
+      number;
+
+  _OrtTrainingReleaseSession(trainingHandle: number): void;
+}
 
-  _OrtTrainingGetModelInputOutputCount?
-      (trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number;
-  _OrtTrainingGetModelInputOutputName?
-      (trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number;
+export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Partial<OrtTrainingAPIs>,
+                                       Partial<JSEP.Module> {
+  // #region emscripten functions
+  stackSave(): number;
+  stackRestore(stack: number): void;
+  stackAlloc(size: number): number;
 
-  _OrtTrainingReleaseSession?(trainingHandle: number): void;
+  UTF8ToString(offset: number, maxBytesToRead?: number): string;
+  lengthBytesUTF8(str: string): number;
+  stringToUTF8(str: string, offset: number, maxBytes: number): void;
   // #endregion
 
   // #region config
   numThreads?: number;
   mainScriptUrlOrBlob?: string|Blob;
   // #endregion
-
-  // #region external data API
-  mountExternalData?(externalDataFilePath: string, externalDataFileData: Uint8Array): void;
-  unmountExternalData?(): void;
-  // #endregion
-
-  // #region JSEP
-  /**
-   * This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime.
-   * This function initializes WebGPU backend and registers a few callbacks that will be called in C++ code.
-   */
-  jsepInit?
-      (backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction,
-       download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction,
-       releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction, captureBegin: JSEP.CaptureBeginFunction,
-       captureEnd: JSEP.CaptureEndFunction, replay: JSEP.ReplayFunction): void;
-
-  /**
-   * [exported from wasm] Specify a kernel's output when running OpKernel::Compute().
-   *
-   * @param context - specify the kernel context pointer.
-   * @param index - specify the index of the output.
-   * @param data - specify the pointer to encoded data of type and dims.
-   */
-  _JsepOutput(context: number, index: number, data: number): number;
-  /**
-   * [exported from wasm] Get name of an operator node.
-   *
-   * @param kernel - specify the kernel pointer.
-   * @returns the pointer to a C-style UTF8 encoded string representing the node name.
-   */
-  _JsepGetNodeName(kernel: number): number;
-
-  /**
-   * [exported from js_internal_api.js] Register a user GPU buffer for usage of a session's input or output.
-   *
-   * @param sessionId - specify the session ID.
-   * @param index - specify an integer to represent which input/output it is registering for. For input, it is the
-   *     input_index corresponding to the session's inputNames. For output, it is the inputCount + output_index
-   *     corresponding to the session's ouputNames.
-   * @param buffer - specify the GPU buffer to register.
-   * @param size - specify the original data size in byte.
-   * @returns the GPU data ID for the registered GPU buffer.
-   */
-  jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number;
-  /**
-   * [exported from js_internal_api.js] Get the GPU buffer by GPU data ID.
-   *
-   * @param dataId - specify the GPU data ID
-   * @returns the GPU buffer.
-   */
-  jsepGetBuffer: (dataId: number) => GPUBuffer;
-  /**
-   * [exported from js_internal_api.js] Create a function to be used to create a GPU Tensor.
-   *
-   * @param gpuBuffer - specify the GPU buffer
-   * @param size - specify the original data size in byte.
-   * @param type - specify the tensor type.
-   * @returns the generated downloader function.
-   */
-  jsepCreateDownloader:
-      (gpuBuffer: GPUBuffer, size: number,
-       type: Tensor.GpuBufferDataTypes) => () => Promise<Tensor.DataTypeMap[Tensor.GpuBufferDataTypes]>;
-  /**
-   *  [exported from js_internal_api.js] Called when InferenceSession.run started. This function will be called before
-   * _OrtRun[WithBinding]() is called.
-   * @param sessionId - specify the session ID.
-   */
-  jsepOnRunStart: (sessionId: number) => void;
-  /**
-   * [exported from js_internal_api.js] Release a session. This function will be called before _OrtReleaseSession() is
-   * called.
-   * @param sessionId - specify the session ID.
-   * @returns
-   */
-  jsepOnReleaseSession: (sessionId: number) => void;
-  // #endregion
 }
 
 declare const moduleFactory: EmscriptenModuleFactory<OrtWasmModule>;
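
The reworked ort-wasm.d.ts above splits the module surface into the required inference APIs plus optional training and JSEP parts, then recombines them with Partial<>. The sketch below shows the same composition pattern with invented interface and member names; it is not the real ORT surface, only a shape-level illustration.

// Required APIs are plain members; optional feature sets are folded in via Partial<>.
interface CoreApis {
  init(numThreads: number): number;
}
interface TrainingApis {
  trainStep(handle: number): number;
}
interface GpuApis {
  registerBuffer(sessionId: number, index: number): number;
}

interface WasmModuleSketch extends CoreApis, Partial<TrainingApis>, Partial<GpuApis> {
  numThreads?: number;
}

// Callers must narrow the optional members before use.
const runTrainStep = (m: WasmModuleSketch, handle: number): number => {
  if (!m.trainStep) {
    throw new Error('this build does not include training support');
  }
  return m.trainStep(handle);
};
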
diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts
index 98990a6fe477b..b36dc73330d46 100644
--- a/js/web/lib/wasm/jsep/backend-webgpu.ts
+++ b/js/web/lib/wasm/jsep/backend-webgpu.ts
@@ -10,7 +10,7 @@ import {createView, TensorView} from './tensor-view';
 import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager';
 import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules';
 import {ProgramManager} from './webgpu/program-manager';
-import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, SessionState, TimestampQuery} from './webgpu/types';
+import {AdapterInfo, ComputeContext, GpuArchitecture, GpuData, GpuVendor, ProgramInfo, ProgramInputTensorInfoDependency, SessionState, TimestampQuery} from './webgpu/types';
 
 interface CommandInfo {
   readonly kernelId: number;
@@ -94,11 +94,32 @@ const getProgramInfoUniqueKey =
       return key;
     };
 
+class AdapterInfoImpl implements AdapterInfo {
+  readonly architecture?: string;
+  readonly vendor?: string;
+
+  constructor(adapterInfo: GPUAdapterInfo) {
+    if (adapterInfo) {
+      this.architecture = adapterInfo.architecture;
+      this.vendor = adapterInfo.vendor;
+    }
+  }
+
+  isArchitecture(architecture: GpuArchitecture): boolean {
+    return this.architecture === architecture;
+  }
+
+  isVendor(vendor: GpuVendor): boolean {
+    return this.vendor === vendor;
+  }
+}
+
 /**
  * this class is designed to store status and being used as a singleton for JSEP. It will be passed to jsepInit() as
  * the first parameter so that it is stored for future use.
  */
 export class WebGpuBackend {
+  adapterInfo: AdapterInfoImpl;
   device: GPUDevice;
   /**
    * an instance of GpuDataManager to manage a GpuDataId -> GpuBuffer mapping
@@ -212,6 +233,7 @@ export class WebGpuBackend {
     }
 
     this.device = await adapter.requestDevice(deviceDescriptor);
+    this.adapterInfo = new AdapterInfoImpl(await adapter.requestAdapterInfo());
     this.gpuDataManager = createGpuDataManager(this);
     this.programManager = new ProgramManager(this);
     this.kernels = new Map();
@@ -230,7 +252,10 @@ export class WebGpuBackend {
       }
     };
 
-    Object.defineProperty(this.env.webgpu, 'device', {value: this.device});
+    Object.defineProperty(
+        this.env.webgpu, 'device', {value: this.device, writable: false, enumerable: true, configurable: false});
+    Object.defineProperty(
+        this.env.webgpu, 'adapter', {value: adapter, writable: false, enumerable: true, configurable: false});
 
     // init queryType, which is necessary for InferenceSession.create
     this.setQueryType();
@@ -385,11 +410,16 @@ export class WebGpuBackend {
     // create info for inputs
     const inputDatas: GpuData[] = [];
     for (let i = 0; i < inputTensorViews.length; ++i) {
-      const gpuData = this.gpuDataManager.get(inputTensorViews[i].data);
+      const data = inputTensorViews[i].data;
+      // if tensor view data is 0, it means the input is a zero-sized tensor, and there is no GPU data for it.
+      if (data === 0) {
+        continue;
+      }
+      const gpuData = this.gpuDataManager.get(data);
       if (!gpuData) {
-        throw new Error(`no GPU data for input: ${inputTensorViews[i].data}`);
+        throw new Error(`no GPU data for input: ${data}`);
       }
-      inputDatas[i] = gpuData;
+      inputDatas.push(gpuData);
     }
 
     const {outputs, dispatchGroup, programUniforms} = program.getRunData(inputTensorViews);
@@ -419,6 +449,11 @@ export class WebGpuBackend {
       const tensorView = (isTemporary || isPersistent) ?
           createIntermediateOutput(outputs[i].dataType, outputs[i].dims) :
           createKernelOutput(validatedOutputIndices[i], outputs[i].dataType, outputs[i].dims);
+      outputTensorViews.push(tensorView);
+      // if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it.
+      if (tensorView.data === 0) {
+        continue;
+      }
       const gpuData = this.gpuDataManager.get(tensorView.data);
       if (!gpuData) {
         throw new Error(`no GPU data for output: ${tensorView.data}`);
@@ -434,10 +469,24 @@ export class WebGpuBackend {
         }
         persistentData.push(gpuData);
       }
-      outputTensorViews.push(tensorView);
       outputDatas.push(gpuData);
     }
 
+    // when there are any zero-sized tensors in the inputs or outputs, we should report an error unless all outputs
+    // are zero-sized tensors.
+    if (inputDatas.length !== inputTensorViews.length || outputDatas.length !== outputTensorViews.length) {
+      // if all outputs are zero-sized tensors, there is no need to run the program.
+      if (outputDatas.length === 0) {
+        TRACE_FUNC_END(program.name);
+        return outputTensorViews;
+      }
+      // if some outputs are zero-sized tensors, report an error.
+      //
+      // TODO: so far we haven't seen any use case where outputs include both zero-sized tensors and non-zero-sized tensors.
+      // If we see such use case, we need to make a change here to support it.
+      throw new Error(
+          `Program ${program.name} has zero-sized tensor(s) in inputs or outputs. This is not supported now.`);
+    }
 
     // load uniforms
     // TODO: add cache for uniform (is it necessary?)
@@ -686,7 +735,8 @@ export class WebGpuBackend {
   }
   setQueryType(): void {
     this.queryType = 'none';
-    if (this.env.webgpu.profiling?.mode === 'default' || this.env.wasm.trace) {
+    if (this.env.webgpu.profiling?.mode === 'default' ||
+        (typeof this.env.trace === 'undefined' ? this.env.wasm.trace : this.env.trace)) {
       if (this.device.features.has('chromium-experimental-timestamp-query-inside-passes')) {
         this.queryType = 'inside-passes';
       } else if (this.device.features.has('timestamp-query')) {
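
backend-webgpu.ts above now captures GPUAdapterInfo in an AdapterInfoImpl wrapper so kernels can branch on GPU architecture or vendor. Here is a minimal sketch of that wrapper; the local AdapterLike type stands in for GPUAdapterInfo so the example compiles without WebGPU typings, and the adapter values are made up.

// Local stand-in for GPUAdapterInfo so the sketch compiles without WebGPU typings.
type AdapterLike = {architecture?: string; vendor?: string};

class AdapterInfoSketch {
  readonly architecture?: string;
  readonly vendor?: string;

  constructor(info?: AdapterLike) {
    this.architecture = info?.architecture;
    this.vendor = info?.vendor;
  }

  isArchitecture(architecture: string): boolean {
    return this.architecture === architecture;
  }

  isVendor(vendor: string): boolean {
    return this.vendor === vendor;
  }
}

// Kernels can then gate known-bad code paths by architecture, as conv.ts does further
// below for the grouped-conv vectorized path.
const info = new AdapterInfoSketch({architecture: 'ampere', vendor: 'nvidia'});
const enableVectorizedPath = !info.isArchitecture('ampere');
console.log(enableVectorizedPath);  // false for this (made-up) adapter
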
diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts
index 786ae41646554..adcaa145cdca8 100644
--- a/js/web/lib/wasm/jsep/init.ts
+++ b/js/web/lib/wasm/jsep/init.ts
@@ -10,7 +10,7 @@ import {WebGpuBackend} from './backend-webgpu';
 import {LOG_DEBUG} from './log';
 import {TensorView} from './tensor-view';
 import {ShapeUtil} from './util';
-import {ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo} from './webgpu/types';
+import {AdapterInfo, ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo} from './webgpu/types';
 
 /* eslint-disable no-bitwise */
 
@@ -54,6 +54,7 @@ class TensorViewImpl implements TensorView {
 }
 
 class ComputeContextImpl implements ComputeContext {
+  readonly adapterInfo: AdapterInfo;
   readonly opKernelContext: number;
   readonly inputs: readonly TensorView[];
   readonly outputCount: number;
@@ -66,6 +67,7 @@ class ComputeContextImpl implements ComputeContext {
   private customDataOffset = 0;
   private customDataSize = 0;
   constructor(private module: OrtWasmModule, private backend: WebGpuBackend, contextDataOffset: number) {
+    this.adapterInfo = backend.adapterInfo;
     const heapU32 = module.HEAPU32;
 
     // extract context data
@@ -104,7 +106,8 @@ class ComputeContextImpl implements ComputeContext {
         throw new Error(`Unsupported data type: ${dataType}`);
       }
       const bufferSize = elementSize * ShapeUtil.size(dims);
-      return new TensorViewImpl(this.module, dataType, this.backend.gpuDataManager.create(bufferSize).id, dims);
+      const gpuDataId = bufferSize > 0 ? this.backend.gpuDataManager.create(bufferSize).id : 0;
+      return new TensorViewImpl(this.module, dataType, gpuDataId, dims);
     };
     return this.backend.run(program, mappedInputs, outputIndices, createKernelOutput, createTemporaryOutput);
   }
@@ -118,7 +121,7 @@ class ComputeContextImpl implements ComputeContext {
       for (let i = 0; i < dims.length; i++) {
         this.module.HEAPU32[offset++] = dims[i];
       }
-      return this.module._JsepOutput(this.opKernelContext, index, data);
+      return this.module._JsepOutput!(this.opKernelContext, index, data);
     } catch (e) {
       throw new Error(
           `Failed to generate kernel's output[${index}] with dims [${dims}]. ` +
@@ -133,27 +136,39 @@ class ComputeContextImpl implements ComputeContext {
 /**
  * Initialize JSEP with WebGPU backend.
  *
- * This function will be called only once after the WebAssembly module is loaded and initialized ("_OrtInit" is called).
- * This function expects:
+ * This function will be called after the WebAssembly module is loaded and initialized ("_OrtInit" is called), once for
+ * each of the following EPs if they are specified:
+ * - "webgpu"
+ * - "webnn"
+ *
+ * For WebGPU, this function expects:
  *  - WebGPU is enabled in build (BUILD_DEFS.DISABLE_WEBGPU === false).
  *  - WebGPU is available in current environment. (a valid GPUAdapter is passed in)
+ *
+ * For WebNN, this function expects:
+ * - WebNN is enabled in build (BUILD_DEFS.DISABLE_WEBGPU === false).
+ * - WebNN is available in current environment. (navigator.ml is not undefined)
+ *
  * If the WebAssembly module is not built with JSEP support, this function will throw an error. This will invalidate
- * 'webgpu' backend.
+ * 'webgpu'/'webnn' backend.
  *
+ * @param name - the name of the EP, either "webgpu" or "webnn"
  * @param module - the ORT WebAssembly module
  * @param env - the ORT environment variable (ort.env)
  * @param gpuAdapter - the pre-created GPU adapter
  */
-export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapter): Promise<void> => {
+export const init =
+    async(name: 'webgpu'|'webnn', module: OrtWasmModule, env: Env, gpuAdapter?: GPUAdapter): Promise<void> => {
   const jsepInit = module.jsepInit;
   if (!jsepInit) {
     throw new Error('Failed to initialize JSEP. The WebAssembly module is not built with JSEP support.');
   }
 
-  const backend = new WebGpuBackend();
-  await backend.initialize(env, gpuAdapter);
+  if (name === 'webgpu') {
+    const backend = new WebGpuBackend();
+    await backend.initialize(env, gpuAdapter!);
 
-  jsepInit(
+    jsepInit('webgpu', [
       // backend
       backend,
 
@@ -187,8 +202,8 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte
           },
 
       // jsepCreateKernel
-      (kernelType: string, kernelId: number, attribute: unknown) =>
-          backend.createKernel(kernelType, kernelId, attribute, module.UTF8ToString(module._JsepGetNodeName(kernelId))),
+      (kernelType: string, kernelId: number, attribute: unknown) => backend.createKernel(
+          kernelType, kernelId, attribute, module.UTF8ToString(module._JsepGetNodeName!(kernelId))),
 
       // jsepReleaseKernel
       (kernel: number) => backend.releaseKernel(kernel),
@@ -207,5 +222,9 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte
       // jsepCaptureEnd
       () => backend.captureEnd(),
       // jsepReplay
-      () => backend.replay());
+      () => backend.replay()
+    ]);
+  } else {
+    jsepInit('webnn');
+  }
 };
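
init() above now dispatches on the EP name: 'webgpu' constructs the backend and hands jsepInit the full callback tuple, while 'webnn' only triggers the shared Asyncify setup. A simplified sketch of that dispatch shape follows; the interface, placeholder backend, and callback list are invented stand-ins, not the real module API.

type EpName = 'webgpu'|'webnn';

// Simplified stand-in for the wasm module: 'webgpu' takes the callback tuple,
// 'webnn' takes no parameters.
interface JsepModuleSketch {
  jsepInit(name: 'webgpu', params: unknown[]): void;
  jsepInit(name: 'webnn'): void;
}

const initSketch = async (name: EpName, module: JsepModuleSketch): Promise<void> => {
  if (name === 'webgpu') {
    const backend = {run: () => 0};  // placeholder for the real WebGpuBackend
    module.jsepInit('webgpu', [backend /* , alloc, free, upload, download, ... */]);
  } else {
    module.jsepInit('webnn');
  }
};

// Usage with a no-op module stub:
void initSketch('webnn', {jsepInit: () => {}});
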
diff --git a/js/web/lib/wasm/jsep/util.ts b/js/web/lib/wasm/jsep/util.ts
index 6922d7ff5df6e..9a1d5463f7843 100644
--- a/js/web/lib/wasm/jsep/util.ts
+++ b/js/web/lib/wasm/jsep/util.ts
@@ -56,7 +56,16 @@ export class BroadcastUtil {
       if (aLen !== bLen && aLen > 1 && bLen > 1) {
         return undefined;
       }
-      cdims[crank - i] = Math.max(aLen, bLen);
+      const max = Math.max(aLen, bLen);
+      if (aLen && bLen) {
+        cdims[crank - i] = Math.max(aLen, bLen);
+      } else {
+        // when either aLen or bLen is 0, the other should be either 0 or 1, otherwise it is not broadcastable.
+        if (max > 1) {
+          return undefined;
+        }
+        cdims[crank - i] = 0;
+      }
     }
 
     return cdims;
@@ -92,6 +101,34 @@ export class ShapeUtil {
     return ShapeUtil.getSizeFromDimensionRange(dims, 0, dims.length);
   }
 
+  /**
+   * convert dims corresponding to type change to pack. ex. uint8 data to uint32
+   */
+  static convertShape(dims: readonly number[], size = 4): readonly number[] {
+    const rank = dims.length;
+    if (rank === 0) {
+      return [];
+    }
+    const newDims = new Array(rank);
+    let i = rank - 1;
+    while (i >= 0) {
+      if (dims[i] % size === 0) {
+        newDims[i] = dims[i] / size;
+        break;
+      }
+      if (size % dims[i] !== 0) {
+        throw new Error('cannot convert shape');
+      }
+      newDims[i] = 1;
+      size /= dims[i];
+      i--;
+    }
+    for (i--; i >= 0; i--) {
+      newDims[i] = dims[i];
+    }
+    return newDims;
+  }
+
   /**
    * calculate the size (number of elements) from the given axis (inclusive)
    */
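
The new ShapeUtil.convertShape above repacks a shape when elements are widened, for example when uint8 data is viewed as uint32 (4 bytes per element). Below is a standalone copy of the same logic with a couple of illustrative inputs; the sample shapes are invented.

// Standalone copy of the convertShape logic: fold trailing dims into the wider element.
const convertShape = (dims: readonly number[], size = 4): readonly number[] => {
  const rank = dims.length;
  if (rank === 0) {
    return [];
  }
  const newDims = new Array<number>(rank);
  let i = rank - 1;
  while (i >= 0) {
    if (dims[i] % size === 0) {
      newDims[i] = dims[i] / size;
      break;
    }
    if (size % dims[i] !== 0) {
      throw new Error('cannot convert shape');
    }
    newDims[i] = 1;
    size /= dims[i];
    i--;
  }
  for (i--; i >= 0; i--) {
    newDims[i] = dims[i];
  }
  return newDims;
};

console.log(convertShape([2, 3, 8]));     // [2, 3, 2]   -- last dim divisible by 4
console.log(convertShape([2, 3, 2, 2]));  // [2, 3, 1, 1] -- 2*2 elements form one uint32
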
diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
index ac08c5fb1f7ab..ba874c8dd0f80 100644
--- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
+++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
@@ -20,6 +20,7 @@ import {gemm, parseGemmAttributes} from './ops/gemm';
 import {instanceNorm} from './ops/instance-norm';
 import {layerNorm} from './ops/layer-norm';
 import {matMul} from './ops/matmul';
+import {matMulNBits, parseMatMulNBitsAttributes} from './ops/matmulnbits';
 import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion';
 import {pad} from './ops/pad';
 import * as pool from './ops/pool';
@@ -92,6 +93,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
   ['LessOrEqual', [binaryOps.lessOrEqual]],
   ['Log', [unaryOps.log]],
   ['MatMul', [matMul]],
+  ['MatMulNBits', [matMulNBits, parseMatMulNBitsAttributes]],
   // TODO: support new attributes for MaxPool-8 and MaxPool-10
   ['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]],
   ['Mul', [binaryOps.mul]],
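
The MatMulNBits registration above follows the existing resolve-rules pattern: a Map from operator name to a tuple of kernel implementation plus an optional attribute parser. A minimal sketch of that table shape, with invented operator names and simplified context/attribute types:

// Simplified stand-ins for the real ComputeContext and attribute types.
type Ctx = {inputs: number[]};
type AttributeParser = (attr: Record<string, unknown>) => unknown;
type OperatorImplementation = [(context: Ctx, attributes?: unknown) => void, AttributeParser?];

const RESOLVE_RULES = new Map<string, OperatorImplementation>([
  ['Neg', [(context: Ctx) => console.log('run Neg on', context.inputs)]],
  ['FakeQuantMatMul', [
    (context: Ctx, attributes?: unknown) => console.log('run FakeQuantMatMul with', attributes),
    (attr) => ({bits: (attr.bits as number | undefined) ?? 4}),
  ]],
]);

// Lookup and dispatch, mirroring how the WebGPU backend resolves a kernel by op name.
const [impl, parseAttr] = RESOLVE_RULES.get('FakeQuantMatMul')!;
impl({inputs: [0, 1]}, parseAttr?.({bits: 4}));
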
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts
index b5b6a2a15cd8c..11c8778b72335 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts
@@ -23,17 +23,17 @@ import {DataType} from '../../../../wasm-common';
 import {LOG_DEBUG} from '../../../log';
 import {TensorView} from '../../../tensor-view';
 import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
-import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common';
+import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common';
 import {ConvTransposeAttributes} from '../conv-transpose';
 import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils';
 
-import {biasSnippet, typeSnippet} from './activation_util';
+import {biasSnippet} from './activation_util';
 import {utilFunctions} from './conv_util';
 import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu';
 
 const conv2dTransposeCommonSnippet =
-    (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, innerElementSize = 4): string => {
-      const type = typeSnippet(innerElementSize, 'f32');
+    (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, type: string,
+     innerElementSize = 4): string => {
       const getWSnippet = (innerElementSize: number) => {
         switch (innerElementSize) {
           case 1:
@@ -47,7 +47,7 @@ const conv2dTransposeCommonSnippet =
             let v1 = w[getIndexFromCoords4D(coord1, vec4<i32>(uniforms.w_shape))];
             let v2 = w[getIndexFromCoords4D(coord2, vec4<i32>(uniforms.w_shape))];
             let v3 = w[getIndexFromCoords4D(coord3, vec4<i32>(uniforms.w_shape))];
-            return vec4<f32>(v0, v1, v2, v3);
+            return ${type}(v0, v1, v2, v3);
             `;
           default:
             throw new Error(`innerElementSize ${innerElementSize} is not supported.`);
@@ -224,7 +224,7 @@ export const createConv2DTransposeMatMulProgramInfo =
           const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
           inputVariables.push(bias);
           declareFunctions += `
-          fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? 'vec4<f32>' : 'f32'} {
+          fn getBiasByOutputCoords(coords : vec4<i32>) -> ${bias.type.value} {
             return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
           }`;
         }
@@ -236,16 +236,20 @@ export const createConv2DTransposeMatMulProgramInfo =
           {name: 'pads', type: 'i32', length: pads.length}
         ];
         appendActivationUniforms(attributes, uniforms);
+        const elemType = tensorTypeToWsglStorageType(inputs[0].dataType, 1);
+        if (elemType !== 'f16' && elemType !== 'f32') {
+          throw new Error(`elemType ${elemType} is not supported.`);
+        }
         return `
         ${utilFunctions('uniforms.result_strides')}
         ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)};
         ${declareFunctions}
-        ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)}
+        ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, x.type.value, innerElementSize)}
         ${
             isVec4 ? makeMatMulPackedVec4Source(
-                         elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) :
+                         elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner) :
                      makeMatMulPackedSource(
-                         elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false,
+                         elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner, false,
                          undefined, sequentialAccessByThreads)}`;
       };
 
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts
index b06c9fb496d15..010ee589c44fa 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts
@@ -13,25 +13,32 @@ export interface ConcatAttributes extends AttributeWithCacheKey {
   readonly axis: number;
 }
 
-const validateInputs = (inputs: readonly TensorView[]): void => {
+const validateInputs = (inputs: readonly TensorView[], axis: number): void => {
   if (!inputs || inputs.length < 1) {
     throw new Error('too few inputs');
   }
-
-  const inputType = inputs[0].dataType;
-  const inputDimensionality = inputs[0].dims.length;
-
-  for (const input of inputs) {
+  const referenceIndex = 0;
+  const referenceInput = inputs[referenceIndex];
+  const inputType = referenceInput.dataType;
+  const inputRank = referenceInput.dims.length;
+  inputs.forEach((input, i) => {
+    if (i === referenceIndex) {
+      return;
+    }
     // make sure types of all inputs match
     if (input.dataType !== inputType) {
       throw new Error('input tensors should be one type');
     }
-
     // make sure the dimensionality of all inputs are the same
-    if (input.dims.length !== inputDimensionality) {
+    if (input.dims.length !== inputRank) {
       throw new Error('input tensors should have the same shape');
     }
-  }
+    input.dims.forEach((dim, i) => {
+      if (i !== axis && dim !== referenceInput.dims[i]) {
+        throw new Error('non concat dimensions must match');
+      }
+    });
+  });
 };
 
 const calculateInputIndexImpl = (numberOfTensors: number, sizeInConcatAxisStr: string): string => `
@@ -64,65 +71,43 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe
   return codeLines.join('\n');
 };
 
-const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): ProgramInfo => {
-  const inputShape = inputs[0].dims.slice();
-  if (axis >= inputShape.length || axis < (-1 * inputShape.length)) {
-    throw new Error('axis specified for concat doesn\'t match input dimensionality');
-  }
-  const adjustedAxis = (axis < 0) ? inputShape.length + axis : axis;
-  // ensure all of the non-concatenated axes match each other
-  // calculate the shape of the output tensor while we do that
-  const outputShape = inputShape.slice(0);
-  for (let i = 1; i < inputs.length; i++) {
-    const dataNShape = inputs[i].dims.slice();
-    for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) {
-      // add to the placeholder for computing output shape
-      if (axisIndex === adjustedAxis) {
-        outputShape[adjustedAxis] += dataNShape[axisIndex];
+const createConcatProgramInfo =
+    (inputs: readonly TensorView[], adjustedAxis: number, outputShape: number[], dataType: DataType): ProgramInfo => {
+      const outputSize = ShapeUtil.size(outputShape);
+
+      const sizeInConcatAxis = new Array<number>(inputs.length);
+      const inputVars = new Array<IndicesHelper>(inputs.length);
+
+      let previousSum = 0;
+      const inputDependencies: ProgramInputTensorInfoDependency[] = [];
+      const inputRanks = [];
+      const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}];
+      for (let i = 0; i < inputs.length; ++i) {
+        previousSum += inputs[i].dims[adjustedAxis];
+        sizeInConcatAxis[i] = previousSum;
+        inputRanks.push(inputs[i].dims.length);
+        inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]);
+        inputDependencies.push('rank');
+        programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]});
       }
-      // ensure all non-cancatenated axes match each other
-      else if (inputShape[axisIndex] !== dataNShape[axisIndex]) {
-        throw new Error('non concat dimensions must match');
+      for (let i = 0; i < inputs.length; ++i) {
+        programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
       }
-    }
-  }
-
-  const outputSize = ShapeUtil.size(outputShape);
-
-  const sizeInConcatAxis = new Array<number>(inputs.length);
-  const inputVars = new Array<IndicesHelper>(inputs.length);
-  const dataType = inputs[0].dataType;
-
-  let previousSum = 0;
-  const inputDependencies: ProgramInputTensorInfoDependency[] = [];
-  const inputRanks = [];
-  const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}];
-  for (let i = 0; i < inputs.length; ++i) {
-    previousSum += inputs[i].dims[adjustedAxis];
-    sizeInConcatAxis[i] = previousSum;
-    inputRanks.push(inputs[i].dims.length);
-    inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]);
-    inputDependencies.push('rank');
-    programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]});
-  }
-  for (let i = 0; i < inputs.length; ++i) {
-    programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
-  }
-  programUniforms.push(...createTensorShapeVariables(outputShape));
+      programUniforms.push(...createTensorShapeVariables(outputShape));
 
-  const output = outputVariable('output', dataType, outputShape.length);
-  const indicesAxis = output.indicesGet('indices', adjustedAxis);
-  const sizeInConcatAxisStr =
-      Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(',');
-  const getShaderSource = (shaderHelper: ShaderHelper) => `
+      const output = outputVariable('output', dataType, outputShape.length);
+      const indicesAxis = output.indicesGet('indices', adjustedAxis);
+      const sizeInConcatAxisStr =
+          Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(',');
+      const getShaderSource = (shaderHelper: ShaderHelper) => `
 
   ${(() => {
-    shaderHelper.registerUniform('outputSize', 'u32');
-    for (let i = 0; i < inputs.length; i++) {
-      shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32');
-    }
-    return shaderHelper.declareVariables(...inputVars, output);
-  })()}
+        shaderHelper.registerUniform('outputSize', 'u32');
+        for (let i = 0; i < inputs.length; i++) {
+          shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32');
+        }
+        return shaderHelper.declareVariables(...inputVars, output);
+      })()}
 
   ${calculateInputIndexImpl(sizeInConcatAxis.length, sizeInConcatAxisStr)}
 
@@ -140,21 +125,30 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P
     ${assignOutputData(inputVars, output)}
   }`;
 
-  return {
-    name: 'Concat',
-    shaderCache: {hint: `${axis}`, inputDependencies},
-    getRunData: () => ({
-      outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
-      dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
-      programUniforms,
-    }),
-    getShaderSource,
-  };
-};
+      return {
+        name: 'Concat',
+        shaderCache: {hint: `${adjustedAxis}`, inputDependencies},
+        getRunData: () => ({
+          outputs: [{dims: outputShape, dataType}],
+          dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
+          programUniforms,
+        }),
+        getShaderSource,
+      };
+    };
 
 export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => {
-  validateInputs(context.inputs);
-  context.compute(createConcatProgramInfo(context.inputs, attributes.axis));
+  const inputs = context.inputs;
+  const inputShape = inputs[0].dims;
+  const adjustedAxis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length);
+  validateInputs(inputs, adjustedAxis);
+  const outputShape = inputShape.slice();
+  outputShape[adjustedAxis] =
+      inputs.reduce((sum, input) => sum + (input.dims.length > adjustedAxis ? input.dims[adjustedAxis] : 0), 0);
+  // 0 length tensors are valid for concat, remove them
+  const nonEmptyInputs = inputs.filter(input => ShapeUtil.size(input.dims) > 0);
+  context.compute(
+      createConcatProgramInfo(nonEmptyInputs, adjustedAxis, outputShape, inputs[0].dataType), {inputs: nonEmptyInputs});
 };
 
 export const parseConcatAttributes = (attributes: Record<string, unknown>): ConcatAttributes =>
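
The reworked concat() above normalizes the axis, derives the output shape by summing the concat dimension across inputs, and filters out zero-sized inputs before launching the program. A minimal standalone sketch of that preparation step, operating on plain number[][] shapes (the guard for shorter-rank inputs in the real code is omitted here):

const normalizeAxis = (axis: number, rank: number): number => {
  if (axis < -rank || axis >= rank) {
    throw new Error(`axis ${axis} is out of range for rank ${rank}`);
  }
  return axis < 0 ? axis + rank : axis;
};

const size = (dims: readonly number[]): number => dims.reduce((p, d) => p * d, 1);

const prepareConcat = (shapes: number[][], axis: number) => {
  const adjustedAxis = normalizeAxis(axis, shapes[0].length);
  const outputShape = shapes[0].slice();
  outputShape[adjustedAxis] = shapes.reduce((sum, s) => sum + s[adjustedAxis], 0);
  // zero-sized inputs contribute nothing, so they can be skipped by the kernel
  const nonEmpty = shapes.filter(s => size(s) > 0);
  return {adjustedAxis, outputShape, nonEmpty};
};

console.log(prepareConcat([[2, 0, 4], [2, 3, 4]], -2));
// -> { adjustedAxis: 1, outputShape: [2, 3, 4], nonEmpty: [[2, 3, 4]] }
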
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts
index 5afec0389fac8..b68d4dcae4cb9 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts
@@ -148,11 +148,12 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
   // const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */
   const isChannelsLast = attributes.format === 'NHWC';
   if (attributes.group !== 1) {
-    // Temporarily disable createGroupedConvVectorizeProgramInfo path due to bots failures with below two cases:
+    // NVIDIA GPUs with the Ampere architecture fail the 2 cases below, but we couldn't repro them with any other
+    // GPUs. So just disable the vectorized path on NVIDIA Ampere to ensure correct outputs.
     // [webgpu]Conv - conv - vectorize group - B
     // [webgpu]Conv - conv - vectorize group - D
-    const disableGroupedConvVectorize = true;
-    if (!disableGroupedConvVectorize && isChannelsLast && inputs[1].dims[0] === attributes.group &&
+    const enableGroupedConvVectorize = !context.adapterInfo.isArchitecture('ampere');
+    if (enableGroupedConvVectorize && isChannelsLast && inputs[1].dims[0] === attributes.group &&
         inputs[1].dims[1] === 1 && attributes.dilations[0] === 1 && attributes.dilations[1] === 1) {
       const outputShape = calculateOutputShape(
           inputs[0].dims, inputs[1].dims, attributes.dilations, adjustedAttributes.pads, attributes.strides,
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts
index 5c31e6dd86c00..d48bb909f7f8f 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts
@@ -55,7 +55,7 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath
           if (idx${x} < 0) {
             idx${x} = idx${x} + uniforms.axisDimLimit;
           }
-          var dataIndices${x} = ${data.type.indices}(0);
+          var dataIndices${x} : ${data.type.indices};
         `;
       for (let i = 0, j = 0; i < inputRank; i++) {
         if (i === axis) {
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
index 2f652dbd310ab..2c72def089144 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
@@ -207,7 +207,7 @@ const computeMean =
     let offset = currentImageNumber * uniforms.image_size;
     var sum = ${fillVector('f32', components)};
     var squaredSum = ${fillVector('f32', components)};
-    for (var i: u32 = 0; i < ${WG}; i++) {
+    for (var i: u32 = 0; i < min(${WG}, uniforms.H); i++) {
         let value = input[offset + i + currentChannelNumber * ${WG}];
         sum += value[0];
         squaredSum += value[1];
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts
index 3f73d9cb7c5bc..d5f97213e49ce 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts
@@ -85,28 +85,28 @@ const createLayerNormProgramInfo =
   ${shaderHelper.mainStart()}
     ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.norm_count')}
     let offset = global_idx * uniforms.norm_size_vectorized;
-    var meanVector = ${fillVector('f32', components)};
-    var meanSquareVector = ${fillVector('f32', components)};
+    var mean_vector = ${fillVector('f32', components)};
+    var mean_square_vector = ${fillVector('f32', components)};
 
     for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) {
       let value = ${castToF32(dataType, components, 'x[h + offset]')};
-      meanVector += value;
-      meanSquareVector += value * value;
+      mean_vector += value;
+      mean_square_vector += value * value;
     }
-    let mean = ${sumVector('meanVector', components)} / uniforms.norm_size;
-    let invStdDev =
-        inverseSqrt(${sumVector('meanSquareVector', components)} / uniforms.norm_size - mean * mean + uniforms.epsilon);
+    let mean = ${sumVector('mean_vector', components)} / uniforms.norm_size;
+    let inv_std_dev = inverseSqrt(${
+            sumVector('mean_square_vector', components)} / uniforms.norm_size - mean * mean + uniforms.epsilon);
 
     for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) {
       let f32input = ${castToF32(dataType, components, 'x[j + offset]')};
       let f32scale = ${castToF32(dataType, components, 'scale[j]')};
-      output[j + offset] = ${variables[0].type.value}((f32input - mean) * invStdDev * f32scale
+      output[j + offset] = ${variables[0].type.value}((f32input - mean) * inv_std_dev * f32scale
         ${bias ? `+ ${castToF32(dataType, components, 'bias[j]')}` : ''}
       );
     }
 
     ${hasMeanDataOutput ? 'mean_data_output[global_idx] = mean' : ''};
-    ${hasInvStdOutput ? 'inv_std_output[global_idx] = invStdDev' : ''};
+    ${hasInvStdOutput ? 'inv_std_output[global_idx] = inv_std_dev' : ''};
   }`;
       };
       const outputs = [{dims: outputShape, dataType: inputs[0].dataType}];
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts
new file mode 100644
index 0000000000000..9bf5e4066139d
--- /dev/null
+++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts
@@ -0,0 +1,250 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+import {DataType} from '../../../wasm-common';
+import {TensorView} from '../../tensor-view';
+import {ShapeUtil} from '../../util';
+import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
+import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
+
+import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
+
+//  TODO support quantization bits not equal to 4
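+//
+// MatMulNBits multiplies a floating-point input A by a weight matrix B whose elements are block-wise quantized
+// to `bits` bits (currently only 4 is supported). Each block of `blockSize` elements along K shares one scale
+// (inputs[2]) and, optionally, one packed zero point (inputs[3]); the shader generated below dequantizes B on
+// the fly while accumulating the matmul result.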
+export interface MatMulNBitsAttributes extends AttributeWithCacheKey {
+  k: number;
+  n: number;
+  accuracyLevel: number;
+  bits: number;
+  blockSize: number;
+}
+
+const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): void => {
+  if (inputs.length < 3 || inputs.length > 4) {
+    throw new Error('MatMulNBits requires 3 or 4 inputs');
+  }
+  const a = inputs[0];
+  const aRank = a.dims.length;
+  if (a.dims[aRank - 1] !== attributes.k) {
+    throw new Error('The last dim of input shape does not match the k value');
+  }
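+  // Each column of B is split into ceil(k / blockSize) blocks along K; a block of blockSize `bits`-bit elements
+  // occupies blockSize * bits / 8 bytes (the per-block "blob").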
+  const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize);
+  const blobSize = attributes.blockSize / 8 * attributes.bits;
+  const b = inputs[1];
+  if (!ShapeUtil.areEqual(b.dims, [attributes.n, nBlocksPerCol, blobSize])) {
+    throw new Error('The second input must be a 3D tensor with shape N X nBlocksPerCol X blobSize');
+  }
+  const scales = inputs[2];
+  const scalesShape = scales.dims;
+  if (ShapeUtil.size(scalesShape) !== attributes.n * nBlocksPerCol) {
+    throw new Error('scales input size error.');
+  }
+  if (inputs.length === 4) {
+    const zeroPoints = inputs[3];
+    const zeroPointsShape = zeroPoints.dims;
+    const expectedZeroPointsSize =
+        attributes.bits > 4 ? (attributes.n * nBlocksPerCol) : attributes.n * Math.floor((nBlocksPerCol + 1) / 2);
+    if (ShapeUtil.size(zeroPointsShape) !== expectedZeroPointsSize) {
+      throw new Error('zeroPoints input size error.');
+    }
+  }
+};
+
+export const createMatMulNBitsProgramInfo =
+    (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): ProgramInfo => {
+      const inputShape = inputs[0].dims;
+      const aRank = inputShape.length;
+      const outputShape = inputShape.slice(0, aRank - 1).concat(attributes.n);
+      const m = inputShape[aRank - 2];
+      const blobSize = attributes.blockSize / 8 * attributes.bits;
+      const blobSizeInWords = blobSize / 4;
+      const outputNumber = getMaxComponents(m);
+      const components = getMaxComponents(attributes.n);
+      const aComponents = getMaxComponents(attributes.k);
+      const bComponents = getMaxComponents(blobSizeInWords);
+      const outputSize = ShapeUtil.size(outputShape) / components / outputNumber;
+      const programUniforms: ProgramUniform[] = [
+        {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.k},
+        {type: DataType.uint32, data: attributes.n}, {type: DataType.uint32, data: attributes.accuracyLevel},
+        {type: DataType.uint32, data: attributes.bits}, {type: DataType.uint32, data: attributes.blockSize}
+      ];
+      const aShape = inputShape.slice();
+      aShape.splice(-1, 1, attributes.k / aComponents);
+      const bShape = ShapeUtil.convertShape(inputs[1].dims).slice();
+      bShape.splice(-1, 1, blobSizeInWords / bComponents);
+      programUniforms.push(...createTensorShapeVariables(aShape));
+      programUniforms.push(...createTensorShapeVariables(bShape));
+      programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
+      if (inputs.length === 4) {
+        programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims)));
+      }
+      const oShape = outputShape.slice();
+      oShape.splice(-1, 1, attributes.n / components);
+      programUniforms.push(...createTensorShapeVariables(oShape));
+      const getShaderSource = (shaderHelper: ShaderHelper) => {
+        const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents);
+        const b = inputVariable('b', DataType.uint32, bShape.length, bComponents);
+        const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length);
+        const inputVariables = [a, b, scales];
+        const zeroPoints =
+            inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined;
+        if (zeroPoints) {
+          inputVariables.push(zeroPoints);
+        }
+        const output = outputVariable('output', inputs[0].dataType, outputShape.length, components);
+        const uniforms: UniformsArrayType = [
+          {name: 'output_size', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'},
+          {name: 'accuracy_level', type: 'u32'}, {name: 'bits', type: 'u32'}, {name: 'block_size', type: 'u32'}
+        ];
+        const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize);
+        const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
+
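+        // One 32-bit word of B packs eight quantized 4-bit values. Represent the eight dequantized values with a
+        // type that matches the vectorization of A: a plain array for 1 component, or a matrix of 2-/4-wide vectors.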
+        const qDqDataType = (() => {
+          switch (aComponents) {
+            case 1:
+              return `array<${dataType}, 8>`;
+            case 2:
+              return `mat4x2<${dataType}>`;
+            case 4:
+              return `mat2x4<${dataType}>`;
+            default:
+              throw new Error(`${aComponents}-component is not supported.`);
+          }
+        })();
+
+        const dequantizeImpl = `
+        fn dequantize(quantized: ${qDqDataType}, zero_point: ${dataType}, scale: ${dataType}) -> ${qDqDataType} {
+          ${(() => {
+          if (aComponents === 1) {
+            return `var dequantized = ${qDqDataType}(${
+                Array.from({length: 8}, (_, i) => `(quantized[${i}] - zero_point) * scale`).join(', ')});
+              return dequantized;`;
+          } else {
+            return `var zero_points: ${qDqDataType} = ${qDqDataType}(${Array(8).fill('zero_point').join(',')});
+              return (quantized - zero_points) * scale;`;
+          }
+        })()}
+        }`;
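+        // Unpacks the eight 4-bit values of one 32-bit word of B into the qDqDataType layout chosen above.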
+        const ortUnpack8x4snormImpl = `
+        fn ortUnpack8x4snorm(value: u32) -> ${qDqDataType} {
+          var quantized: ${qDqDataType};
+          var offset: u32 = 0;
+          let count: u32 = 4;
+          for (var i: u32 = 0; i < 8u; i++) {
+            var result = ${dataType}(extractBits(value, offset, count));
+            ${(() => {
+          switch (aComponents) {
+            case 1:
+              return 'quantized[i] = result;';
+            case 2:
+              return 'quantized[i / 2][i % 2] = result;';
+            case 4:
+              return 'quantized[i / 4][i % 4] = result;';
+            default:
+              throw new Error(`${aComponents}-component is not supported.`);
+          }
+        })()}
+            offset += count;
+          }
+          return quantized;
+        }`;
+
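+        // Zero points are packed eight per 32-bit word (4 bits each). This snippet advances the 4-bit cursor and
+        // loads the next word once all eight nibbles of the current one have been consumed.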
+        const updateZeroPointIndex = zeroPoints ? `
+          zero_point_offset += 4;
+          if (zero_point_offset == 32) {
+            zero_point_offset = 0;
+            zero_point_index++;
+            zero_point_word = ${zeroPoints.getByOffset('zero_point_index')};
+          }` :
+                                                  '';
+
+        return `
+        ${dequantizeImpl};
+        ${ortUnpack8x4snormImpl};
+        ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
+        ${shaderHelper.mainStart()}
+          ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
+          var output_values: array<${output.type.value}, ${outputNumber}>;
+          var output_indices = ${output.offsetToIndices('global_idx')};
+          var n = ${output.indicesGet('output_indices', aRank - 1)};
+          var m = ${output.indicesGet('output_indices', aRank - 2)};
+          var a_indices: ${a.type.indices} = output_indices;
+          // Two zero points are packed into one byte because uniforms.bits <= 4.
+          // zero_point_offset is the bit offset of the current 4-bit zero point within zero_point_word.
+          // TODO support zero_point_offset for bits > 4
+          ${
+            zeroPoints ? `
+          var zero_point_index: u32 = n * ${components} * ((${nBlocksPerCol} + 1) / 2) / 4;
+          var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')};
+          var zero_point_offset: u32 = 0;` :
+                         ''}
+          var scale_index = n * ${nBlocksPerCol * components};
+          var b_indices: ${b.type.indices};
+          for (var c: u32 = 0; c < ${components}; c++) {
+            ${b.indicesSet('b_indices', '0', `n * ${components} + c`)};
+            var block_offset: u32 = 0;
+            for (var block: u32 = 0; block < ${nBlocksPerCol}; block++) {
+              // The scale and zero points are computed per block.
+              let scale = ${scales.getByOffset('scale_index')};
+              // The default zero point is 8 for unsigned 4-bit quantization.
+              let zero_point = ${dataType}(${zeroPoints ? 'extractBits(zero_point_word, zero_point_offset, 4)' : 8.0});
+              ${b.indicesSet('b_indices', '1', 'block')};
+              var word_offset: u32 = block_offset;
+              for (var word: u32 = 0; word < ${blobSizeInWords}; word += ${bComponents}) {
+                ${b.indicesSet('b_indices', '2', 'word')};
+                let b_data = ${b.getByIndices('b_indices')};
+                for (var i: u32 = 0; i < ${bComponents}; i++) {
+                  let b_value = ${bComponents === 1 ? 'b_data' : 'b_data[word + i]'};
+                  let b_quantized_values: ${qDqDataType} = ortUnpack8x4snorm(b_value);
+                  let b_dequantized_values = dequantize(b_quantized_values, zero_point, scale);
+                  // Number of B elements per 32-bit word is 32/bits = 32/4 = 8
+                  var offset: u32 = word_offset;
+                  for (var j: u32 = 0; j < 8/${aComponents}; j++) {
+                    ${a.indicesSet('a_indices', aRank - 1, `offset/${aComponents}`)};
+                    for (var k: u32 = 0; k < ${outputNumber}u; k++) {
+                      ${a.indicesSet('a_indices', aRank - 2, `m * ${outputNumber} + k`)};
+                      let a_data = ${a.getByIndices('a_indices')};
+                      output_values[k]${components > 1 ? '[c]' : ''} += ${
+            aComponents === 1 ? 'a_data * b_dequantized_values[j]' : 'dot(a_data, b_dequantized_values[j])'};
+                    }
+                    offset += ${aComponents};
+                  }
+                  word_offset += 8;
+                }
+              }
+              scale_index++;
+              ${updateZeroPointIndex}
+              block_offset += uniforms.block_size;
+            }
+            // Drop the trailing 4 bits if zero_point_offset is not at a byte boundary, to align with the next byte.
+            ${
+            zeroPoints ? `if (zero_point_offset % 8 > 0) {
+                ${updateZeroPointIndex}
+              }` :
+                         ''}
+            }
+            for (var k: u32 = 0u; k < ${outputNumber}u; k++) {
+              ${output.indicesSet('output_indices', aRank - 2, `${outputNumber + ' * m + k'}`)};
+              ${output.setByIndices('output_indices', 'output_values[k]')}
+            }
+        }`;
+      };
+      return {
+        name: 'MatMulNBits',
+        shaderCache:
+            {hint: `${attributes.cacheKey};${inputs.length}`, inputDependencies: Array(inputs.length).fill('rank')},
+        getRunData: () => ({
+          outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
+          dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
+          programUniforms
+        }),
+        getShaderSource
+      };
+    };
+
+export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => {
+  validateInputs(context.inputs, attributes);
+  context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes));
+};
+
+export const parseMatMulNBitsAttributes = (attributes: Record<string, unknown>): MatMulNBitsAttributes =>
+    createAttributeWithCacheKey(attributes as Omit<MatMulNBitsAttributes, keyof AttributeWithCacheKey>);
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts
index 4e933573b9137..5521650e8ded4 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts
@@ -381,8 +381,9 @@ const createMaxPoolProgramInfo =
           programUniforms
         }),
         getShaderSource: shaderHelper => generatePoolingCode(
-            shaderHelper, x, input.dims.length, outputShape.length, adjustedAttributes, op1, op2, -1e5, uniforms,
-            hasPads, pwStartEndNotZero, phStartEndNotZero),
+            shaderHelper, x, input.dims.length, outputShape.length, adjustedAttributes, op1, op2,
+            (input.dataType === DataType.float16) ? -65504 : -1e5, uniforms, hasPads, pwStartEndNotZero,
+            phStartEndNotZero),
       };
     };
 
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts
index a9b28d7c034f3..210b3ee7e2fca 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts
@@ -131,7 +131,7 @@ export const createReduceSharedProgramInfo =
       const workgroupSize = 32;
 
       const sharedMemorySnippet = `
-          var<workgroup> aBestValues : array<${output.type.storage}, ${workgroupSize}>;
+          var<workgroup> aBestValues : array<f32, ${workgroupSize}>;
        `;
 
       const getShaderSource = (shaderHelper: ShaderHelper) => `
@@ -145,10 +145,10 @@ export const createReduceSharedProgramInfo =
           let outputIndex = global_idx / ${workgroupSize};
           let offset = outputIndex * uniforms.reduceSize;
 
-          var bestValue = ${output.type.storage}(${reduceInitValues[reduceType]});
+          var bestValue = f32(${reduceInitValues[reduceType]});
           let Length = uniforms.reduceSize;
           for (var k = local_idx; k < Length; k = k + ${workgroupSize}) {
-           let candidate = ${output.type.storage}(${input.getByOffset('offset + k')});
+           let candidate = f32(${input.getByOffset('offset + k')});
            bestValue = ${reduceOps[reduceType]};
           }
           aBestValues[local_idx] = bestValue;
@@ -172,8 +172,8 @@ export const createReduceSharedProgramInfo =
           output.setByOffset(
               'outputIndex',
               `${
-                  reduceType === 'mean' ? `bestValue / ${output.type.storage}(uniforms.reduceSize)` :
-                                          `${reduceOutputValues[reduceType]}`}`)};
+                  reduceType === 'mean' ? `${output.type.storage}(bestValue / f32(uniforms.reduceSize))` :
+                                          `${output.type.storage}(${reduceOutputValues[reduceType]})`}`)};
          }
         }`;
 
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts
index 14d6f37927590..a09ac78b17006 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts
@@ -68,7 +68,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split
   const dataType = inputs[0].dataType;
   const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length);
   const outputs = new Array<IndicesHelper>(attributes.numOutputs);
-  const input = inputVariable('input', dataType, inputShape);
+  const input = inputVariable('input', dataType, inputShape.length);
   const sizeInSplitAxis = new Array<number>(attributes.numOutputs);
   const outputsTensorInfo: TensorInfo[] = [];
   const outputShapes: number[][] = [];
@@ -80,7 +80,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split
     const outputShape = inputShape.slice();
     outputShape[attributes.axis] = attributes.splitSizes[i];
     outputShapes.push(outputShape);
-    outputs[i] = outputVariable(`output${i}`, dataType, outputShape);
+    outputs[i] = outputVariable(`output${i}`, dataType, outputShape.length);
     outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType});
   }
   programUniforms.push(
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts
index cfee07a9239d7..a6375847fc42f 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts
@@ -27,7 +27,7 @@ const createWhereOpProgramShader =
           const expressionA = `a_data[index_a${x}][component_a${x}]`;
           const expressionB = `b_data[index_b${x}][component_b${x}]`;
           // eslint-disable-next-line no-bitwise
-          const expressionC = `bool(c_data[index_c${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`;
+          const expressionC = `bool(c_data[index_c${x}] & (0xffu << (component_c${x} * 8)))`;
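+          // The condition tensor is stored as four bool bytes packed into one u32; the mask above selects the
+          // byte addressed by component_c (computed in the generated shader below).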
           return `
             let output_indices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
             let offset_a${x} = ${a.broadcastedIndicesToOffset(`output_indices${x}`, output)};
@@ -38,6 +38,7 @@ const createWhereOpProgramShader =
             let index_c${x} = offset_c${x} / 4u;
             let component_a${x} = offset_a${x} % 4u;
             let component_b${x} = offset_b${x} % 4u;
+            let component_c${x} = offset_c${x} % 4u;
             ${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)});
           `;
         };
diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts
index ba5b84fcfe067..48e0855f01a97 100644
--- a/js/web/lib/wasm/jsep/webgpu/types.ts
+++ b/js/web/lib/wasm/jsep/webgpu/types.ts
@@ -15,6 +15,13 @@ export enum GpuDataType {
 }
 export type GpuDataId = number;
 
+export type GpuArchitecture = 'ampere';
+export type GpuVendor = 'amd'|'intel'|'nvidia';
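+/**
+ * Information about the active GPU adapter. Kernels may use it to apply architecture- or vendor-specific
+ * workarounds (e.g. the grouped Conv vectorization check in ops/conv.ts).
+ */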
+export interface AdapterInfo {
+  isArchitecture: (architecture: GpuArchitecture) => boolean;
+  isVendor: (vendor: GpuVendor) => boolean;
+}
+
 export interface GpuData {
   type: GpuDataType;
   id: GpuDataId;
@@ -146,6 +153,11 @@ export interface ComputeContextInputsOutputsMapping {
  * A ComputeContext instance carries the states that representing the current running of a kernel.
  */
 export interface ComputeContext {
+  /**
+   * gpu adapter info
+   */
+  readonly adapterInfo: AdapterInfo;
+
   /**
    * stores the pointer to OpKernelContext
    */
diff --git a/js/web/lib/wasm/proxy-worker/main.ts b/js/web/lib/wasm/proxy-worker/main.ts
index 6cbd38c76ccc8..3ce37a2d6b652 100644
--- a/js/web/lib/wasm/proxy-worker/main.ts
+++ b/js/web/lib/wasm/proxy-worker/main.ts
@@ -103,7 +103,7 @@ self.onmessage = (ev: MessageEvent<OrtWasmMessage>): void => {
                   } else {
                     postMessage(
                         {type, out: outputs} as OrtWasmMessage,
-                        extractTransferableBuffers(outputs as SerializableTensorMetadata[]));
+                        extractTransferableBuffers([...inputs, ...outputs] as SerializableTensorMetadata[]));
                   }
                 },
                 err => {
diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts
index 86017a4ec6904..6ff4e86b1235e 100644
--- a/js/web/lib/wasm/proxy-wrapper.ts
+++ b/js/web/lib/wasm/proxy-wrapper.ts
@@ -155,7 +155,7 @@ export const createSession =
             ensureWorker();
             return new Promise<SerializableSessionMetadata>((resolve, reject) => {
               enqueueCallbacks('create', [resolve, reject]);
-              const message: OrtWasmMessage = {type: 'create', in : {model, options}};
+              const message: OrtWasmMessage = {type: 'create', in : {model, options: {...options}}};
               const transferable: Transferable[] = [];
               if (model instanceof Uint8Array) {
                 transferable.push(model.buffer);
diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts
index b9eff45e890c4..54eaf5e0c43cc 100644
--- a/js/web/lib/wasm/wasm-common.ts
+++ b/js/web/lib/wasm/wasm-common.ts
@@ -3,6 +3,12 @@
 
 import {Tensor} from 'onnxruntime-common';
 
+// a dummy type declaration for Float16Array in case any polyfill is available.
+declare global {
+  // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any
+  const Float16Array: any;
+}
+
 // This file includes common definitions. They do NOT have dependency on the WebAssembly instance.
 
 /**
@@ -117,7 +123,8 @@ export const tensorTypeToTypedArrayConstructor = (type: Tensor.Type): Float32Arr
     Uint8ArrayConstructor|Float64ArrayConstructor|Uint32ArrayConstructor|BigUint64ArrayConstructor => {
       switch (type) {
         case 'float16':
-          return Uint16Array;
+          // allow Float16Array polyfill.
+          return typeof Float16Array !== 'undefined' && Float16Array.from ? Float16Array : Uint16Array;
         case 'float32':
           return Float32Array;
         case 'uint8':
@@ -169,7 +176,8 @@ export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'erro
  * Check whether the given tensor type is supported by GPU buffer
  */
 export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuBufferDataTypes => type === 'float32' ||
-    type === 'int32' || type === 'int64' || type === 'bool' || type === 'float16' || type === 'uint32';
+    type === 'float16' || type === 'int32' || type === 'int64' || type === 'uint32' || type === 'uint8' ||
+    type === 'bool';
 
 /**
  * Map string data location to integer value
diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts
index 37b9ed6a1002f..9b27051f1b9fe 100644
--- a/js/web/lib/wasm/wasm-core-impl.ts
+++ b/js/web/lib/wasm/wasm-core-impl.ts
@@ -84,27 +84,57 @@ export const initRuntime = async(env: Env): Promise<void> => {
  * @param epName
  */
 export const initEp = async(env: Env, epName: string): Promise<void> => {
-  if (!BUILD_DEFS.DISABLE_WEBGPU && (epName === 'webgpu' || epName === 'webnn')) {
-    // perform WebGPU availability check
-    if (typeof navigator === 'undefined' || !navigator.gpu) {
-      throw new Error('WebGPU is not supported in current environment');
-    }
-    const adapter = await navigator.gpu.requestAdapter();
-    if (!adapter) {
-      throw new Error(
-          'Failed to get GPU adapter. You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.');
-    }
+  if (!BUILD_DEFS.DISABLE_WEBGPU) {
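+    // Load the JSEP init function once; it is shared by the 'webgpu' and 'webnn' branches below.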
+    // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires
+    const initJsep = require('./jsep/init').init;
 
-    if (!env.wasm.simd) {
-      throw new Error(
-          'Not supported for WebGPU=ON and SIMD=OFF. Please set `env.wasm.simd` to true when using `webgpu` EP');
-    }
+    if (epName === 'webgpu') {
+      // perform WebGPU availability check
+      if (typeof navigator === 'undefined' || !navigator.gpu) {
+        throw new Error('WebGPU is not supported in current environment');
+      }
+
+      let adapter = env.webgpu.adapter as GPUAdapter | null;
+      if (!adapter) {
+        // if adapter is not set, request a new adapter.
+        const powerPreference = env.webgpu.powerPreference;
+        if (powerPreference !== undefined && powerPreference !== 'low-power' &&
+            powerPreference !== 'high-performance') {
+          throw new Error(`Invalid powerPreference setting: "${powerPreference}"`);
+        }
+        const forceFallbackAdapter = env.webgpu.forceFallbackAdapter;
+        if (forceFallbackAdapter !== undefined && typeof forceFallbackAdapter !== 'boolean') {
+          throw new Error(`Invalid forceFallbackAdapter setting: "${forceFallbackAdapter}"`);
+        }
+        adapter = await navigator.gpu.requestAdapter({powerPreference, forceFallbackAdapter});
+        if (!adapter) {
+          throw new Error(
+              'Failed to get GPU adapter. ' +
+              'You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.');
+        }
+      } else {
+        // if adapter is set, validate it.
+        if (typeof adapter.limits !== 'object' || typeof adapter.features !== 'object' ||
+            typeof adapter.requestDevice !== 'function') {
+          throw new Error('Invalid GPU adapter set in `env.webgpu.adapter`. It must be a GPUAdapter object.');
+        }
+      }
 
-    // init JSEP if available
+      if (!env.wasm.simd) {
+        throw new Error(
+            'Not supported for WebGPU=ON and SIMD=OFF. Please set `env.wasm.simd` to true when using `webgpu` EP');
+      }
 
-    // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires
-    const initJsep = require('./jsep/init').init;
-    await initJsep(getInstance(), env, adapter);
+      await initJsep('webgpu', getInstance(), env, adapter);
+    }
+    if (epName === 'webnn') {
+      // perform WebNN availability check
+      if (typeof navigator === 'undefined' || !(navigator as unknown as {ml: unknown}).ml) {
+        throw new Error('WebNN is not supported in current environment');
+      }
+
+      await initJsep('webnn', getInstance(), env);
+    }
   }
 };
 
@@ -372,7 +402,12 @@ export const prepareInputOutputTensor =
         const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer;
         const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!;
         dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes;
-        rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength);
+
+        const registerBuffer = wasm.jsepRegisterBuffer;
+        if (!registerBuffer) {
+          throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.');
+        }
+        rawData = registerBuffer(sessionId, index, gpuBuffer, dataByteLength);
       } else {
         const data = tensor[2];
 
@@ -587,7 +622,11 @@ export const run = async(
           // If a certain output's preferred location is GPU but the tensor is empty, we still need to create a CPU
           // tensor for it. There is no mapping GPU buffer for an empty tensor.
           if (preferredLocation === 'gpu-buffer' && size > 0) {
-            const gpuBuffer = wasm.jsepGetBuffer(dataOffset);
+            const getBuffer = wasm.jsepGetBuffer;
+            if (!getBuffer) {
+              throw new Error('preferredLocation "gpu-buffer" is not supported without using WebGPU.');
+            }
+            const gpuBuffer = getBuffer(dataOffset);
             const elementSize = getTensorElementSize(dataType);
             if (elementSize === undefined || !isGpuBufferSupportedType(type)) {
               throw new Error(`Unsupported data type: ${type}`);
@@ -599,7 +638,7 @@ export const run = async(
             output.push([
               type, dims, {
                 gpuBuffer,
-                download: wasm.jsepCreateDownloader(gpuBuffer, size * elementSize, type),
+                download: wasm.jsepCreateDownloader!(gpuBuffer, size * elementSize, type),
                 dispose: () => {
                   wasm._OrtReleaseTensor(tensor);
                 }
diff --git a/js/web/package-lock.json b/js/web/package-lock.json
index 41c44aaa2679b..5c9113459ff06 100644
--- a/js/web/package-lock.json
+++ b/js/web/package-lock.json
@@ -52,7 +52,7 @@
       "version": "1.18.0",
       "license": "MIT",
       "devDependencies": {
-        "typedoc": "^0.23.22"
+        "typedoc": "^0.25.7"
       }
     },
     "node_modules/@chiragrupani/karma-chromium-edge-launcher": {
@@ -1351,9 +1351,9 @@
       "dev": true
     },
     "node_modules/follow-redirects": {
-      "version": "1.15.4",
-      "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.4.tgz",
-      "integrity": "sha512-Cr4D/5wlrb0z9dgERpUL3LrmPKVDsETIJhaCMeDfuFYcqa5bldGV6wBsAN6X/vxlXQtFBMrXdXxdL8CbDTGniw==",
+      "version": "1.15.6",
+      "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz",
+      "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==",
       "dev": true,
       "funding": [
         {
@@ -4595,9 +4595,9 @@
       "dev": true
     },
     "follow-redirects": {
-      "version": "1.15.4",
-      "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.4.tgz",
-      "integrity": "sha512-Cr4D/5wlrb0z9dgERpUL3LrmPKVDsETIJhaCMeDfuFYcqa5bldGV6wBsAN6X/vxlXQtFBMrXdXxdL8CbDTGniw==",
+      "version": "1.15.6",
+      "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz",
+      "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==",
       "dev": true
     },
     "from": {
@@ -5503,7 +5503,7 @@
     "onnxruntime-common": {
       "version": "file:../common",
       "requires": {
-        "typedoc": "^0.23.22"
+        "typedoc": "^0.25.7"
       }
     },
     "p-cancelable": {
diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts
index ed4dd76a6e315..b2b212bdb9bc1 100644
--- a/js/web/script/test-runner-cli-args.ts
+++ b/js/web/script/test-runner-cli-args.ts
@@ -29,8 +29,10 @@ Options:
 *** General Options ***
 
  -h, --help                    Print this message.
- -d, --debug                   Specify to run test runner in debug mode.
-                                 Debug mode outputs verbose log for test runner, sets up environment debug flag, and keeps karma not to exit after tests completed.
+ -d, --debug                   Specify to run test runner in debug mode. Debug mode does the following:
+                                 - outputs verbose logs for the test runner
+                                 - sets the environment debug flag (env.debug = true)
+                                 - opens the Chromium debug port at 9333 and keeps karma from exiting after the tests complete.
  -b=<...>, --backend=<...>     Specify one or more backend(s) to run the test upon.
                                  Backends can be one or more of the following, splitted by comma:
                                    webgl
@@ -47,38 +49,55 @@ Options:
                                  bs         (for BrowserStack tests)
  -p, --profile                 Enable profiler.
                                  Profiler will generate extra logs which include the information of events time consumption
+ -t, --trace                   Enable trace.
  -P[=<...>], --perf[=<...>]    Generate performance number. Cannot be used with flag --debug.
                                  This flag can be used with a number as value, specifying the total count of test cases to run. The test cases may be used multiple times. Default value is 10.
  -c, --file-cache              Enable file cache.
+
+*** Session Options ***
+ -u=<...>, --optimized-model-file-path=<...>        Specify whether to dump the optimized model.
+ -o=<...>, --graph-optimization-level=<...>         Specify graph optimization level.
+                                                      Default is 'all'. Valid values are 'disabled', 'basic', 'extended', 'all'.
  -i=<...>, --io-binding=<...>  Specify the IO binding testing type. Should be one of the following:
-                                 none          (default)
+                                 none            (default)
                                  gpu-tensor      use pre-allocated GPU tensors for inputs and outputs
                                  gpu-location    use pre-allocated GPU tensors for inputs and set preferredOutputLocation to 'gpu-buffer'
 
-*** Session Options ***
- -u=<...>, --optimized-model-file-path=<...>        Specify whether to dump the optimized model.
- -o=<...>, --graph-optimization-level=<...>    Specify graph optimization level.
-                                                 Default is 'all'. Valid values are 'disabled', 'basic', 'extended', 'all'.
 *** Logging Options ***
 
- --log-verbose=<...>           Set log level to verbose
- --log-info=<...>              Set log level to info
- --log-warning=<...>           Set log level to warning
- --log-error=<...>             Set log level to error
-                                 The 4 flags above specify the logging configuration. Each flag allows to specify one or more category(s), splitted by comma. If use the flags without value, the log level will be applied to all category.
+ --log-verbose                 Set log level to verbose
+ --log-info                    Set log level to info
+ --log-warning                 Set log level to warning
+ --log-error                   Set log level to error
+                                 The 4 flags above specify the logging configuration.
 
 *** Backend Options ***
 
+ --wasm.<...>=<...>            Set global environment flags for each backend.
+ --webgl.<...>=<...>             These flags can be used multiple times to set multiple flags. For example:
+ --webgpu.<...>=<...>            --webgpu.profiling.mode=default --wasm.numThreads=1 --wasm.simd=false
+ --webnn.<...>=<...>
+
+ --webnn-device-type           Set the WebNN device type (cpu/gpu)
+
  -x, --wasm-number-threads     Set the WebAssembly number of threads
+                                ("--wasm-number-threads" is deprecated. use "--wasm.numThreads" or "-x" instead)
  --wasm-init-timeout           Set the timeout for WebAssembly backend initialization, in milliseconds
+                                (deprecated. use "--wasm.initTimeout" instead)
  --wasm-enable-simd            Set whether to enable SIMD
+                                (deprecated. use "--wasm.simd" instead)
  --wasm-enable-proxy           Set whether to enable proxy worker
+                                (deprecated. use "--wasm.proxy" instead)
  --webgl-context-id            Set the WebGL context ID (webgl/webgl2)
+                                (deprecated. use "--webgl.contextId" instead)
  --webgl-matmul-max-batch-size Set the WebGL matmulMaxBatchSize
+                                (deprecated. use "--webgl.matmulMaxBatchSize" instead)
  --webgl-texture-cache-mode    Set the WebGL texture cache mode (initializerOnly/full)
+                                (deprecated. use "--webgl.textureCacheMode" instead)
  --webgl-texture-pack-mode     Set the WebGL texture pack mode (true/false)
+                                (deprecated. use "--webgl.pack" instead)
  --webgpu-profiling-mode       Set the WebGPU profiling mode (off/default)
- --webnn-device-type           Set the WebNN device type (cpu/gpu)
+                                (deprecated. use "--webgpu.profiling.mode" instead)
 
 *** Browser Options ***
 
@@ -171,7 +190,6 @@ export interface TestRunnerCliArgs {
 
   cpuOptions?: InferenceSession.CpuExecutionProviderOption;
   cudaOptions?: InferenceSession.CudaExecutionProviderOption;
-  cudaFlags?: Record<string, unknown>;
   wasmOptions?: InferenceSession.WebAssemblyExecutionProviderOption;
   webglOptions?: InferenceSession.WebGLExecutionProviderOption;
   webnnOptions?: InferenceSession.WebNNExecutionProviderOption;
@@ -260,40 +278,29 @@ function parseCpuOptions(_args: minimist.ParsedArgs): InferenceSession.CpuExecut
   return {name: 'cpu'};
 }
 
-function parseCpuFlags(_args: minimist.ParsedArgs): Record<string, unknown> {
-  return {};
-}
-
 function parseWasmOptions(_args: minimist.ParsedArgs): InferenceSession.WebAssemblyExecutionProviderOption {
   return {name: 'wasm'};
 }
 
 function parseWasmFlags(args: minimist.ParsedArgs): Env.WebAssemblyFlags {
-  const numThreads = args.x || args['wasm-number-threads'];
+  const wasm = args.wasm || {};
+  const numThreads = wasm.numThreads = wasm.numThreads ?? (args.x ?? args['wasm-number-threads']);
   if (typeof numThreads !== 'undefined' && typeof numThreads !== 'number') {
-    throw new Error('Flag "x"/"wasm-number-threads" must be a number value');
+    throw new Error('Flag "wasm.numThreads"/"x"/"wasm-number-threads" must be a number value');
   }
-  const initTimeout = args['wasm-init-timeout'];
+  const initTimeout = wasm.initTimeout = wasm.initTimeout ?? args['wasm-init-timeout'];
   if (typeof initTimeout !== 'undefined' && typeof initTimeout !== 'number') {
-    throw new Error('Flag "wasm-init-timeout" must be a number value');
-  }
-  let simd = args['wasm-enable-simd'];
-  if (simd === 'true') {
-    simd = true;
-  } else if (simd === 'false') {
-    simd = false;
-  } else if (typeof simd !== 'undefined' && typeof simd !== 'boolean') {
-    throw new Error('Flag "wasm-enable-simd" must be a boolean value');
-  }
-  let proxy = args['wasm-enable-proxy'];
-  if (proxy === 'true') {
-    proxy = true;
-  } else if (proxy === 'false') {
-    proxy = false;
-  } else if (typeof proxy !== 'undefined' && typeof proxy !== 'boolean') {
-    throw new Error('Flag "wasm-enable-proxy" must be a boolean value');
-  }
-  return {numThreads, initTimeout, simd, proxy};
+    throw new Error('Flag "wasm.initTimeout"/"wasm-init-timeout" must be a number value');
+  }
+  const simd = wasm.simd = parseBooleanArg(wasm.simd ?? args['wasm-enable-simd']);
+  if (typeof simd !== 'undefined' && typeof simd !== 'boolean') {
+    throw new Error('Flag "wasm.simd"/"wasm-enable-simd" must be a boolean value');
+  }
+  const proxy = wasm.proxy = parseBooleanArg(wasm.proxy ?? args['wasm-enable-proxy']);
+  if (typeof proxy !== 'undefined' && typeof proxy !== 'boolean') {
+    throw new Error('Flag "wasm.proxy"/"wasm-enable-proxy" must be a boolean value');
+  }
+  return wasm;
 }
 
 function parseWebglOptions(_args: minimist.ParsedArgs): InferenceSession.WebGLExecutionProviderOption {
@@ -301,39 +308,43 @@ function parseWebglOptions(_args: minimist.ParsedArgs): InferenceSession.WebGLEx
 }
 
 function parseWebglFlags(args: minimist.ParsedArgs): Partial<Env.WebGLFlags> {
-  const contextId = args['webgl-context-id'];
+  const webgl = args.webgl || {};
+  const contextId = webgl.contextId = webgl.contextId ?? args['webgl-context-id'];
   if (contextId !== undefined && contextId !== 'webgl' && contextId !== 'webgl2') {
-    throw new Error('Flag "webgl-context-id" is invalid');
+    throw new Error('Flag "webgl.contextId"/"webgl-context-id" is invalid');
   }
-  const matmulMaxBatchSize = args['webgl-matmul-max-batch-size'];
+  const matmulMaxBatchSize = webgl.matmulMaxBatchSize = webgl.matmulMaxBatchSize ?? args['webgl-matmul-max-batch-size'];
   if (matmulMaxBatchSize !== undefined && typeof matmulMaxBatchSize !== 'number') {
-    throw new Error('Flag "webgl-matmul-max-batch-size" must be a number value');
+    throw new Error('Flag "webgl.matmulMaxBatchSize"/"webgl-matmul-max-batch-size" must be a number value');
   }
-  const textureCacheMode = args['webgl-texture-cache-mode'];
+  const textureCacheMode = webgl.textureCacheMode = webgl.textureCacheMode ?? args['webgl-texture-cache-mode'];
   if (textureCacheMode !== undefined && textureCacheMode !== 'initializerOnly' && textureCacheMode !== 'full') {
-    throw new Error('Flag "webgl-texture-cache-mode" is invalid');
+    throw new Error('Flag "webgl.textureCacheMode"/"webgl-texture-cache-mode" is invalid');
   }
-  const pack = args['webgl-texture-pack-mode'];
+  const pack = webgl.pack = parseBooleanArg(webgl.pack ?? args['webgl-texture-pack-mode']);
   if (pack !== undefined && typeof pack !== 'boolean') {
-    throw new Error('Flag "webgl-texture-pack-mode" is invalid');
+    throw new Error('Flag "webgl.pack"/"webgl-texture-pack-mode" is invalid');
   }
-  const async = args['webgl-async'];
+  const async = webgl.async = parseBooleanArg(webgl.async ?? args['webgl-async']);
   if (async !== undefined && typeof async !== 'boolean') {
-    throw new Error('Flag "webgl-async" is invalid');
+    throw new Error('Flag "webgl.async"/"webgl-async" is invalid');
   }
-  return {contextId, matmulMaxBatchSize, textureCacheMode, pack};
+  return webgl;
 }
 
 function parseWebgpuFlags(args: minimist.ParsedArgs): Partial<Env.WebGpuFlags> {
-  const profilingMode = args['webgpu-profiling-mode'];
+  const webgpu = args.webgpu || {};
+  const profilingMode = (webgpu.profiling = webgpu.profiling ?? {}).mode =
+      webgpu?.profiling?.mode ?? webgpu.profilingMode ?? args['webgpu-profiling-mode'];
   if (profilingMode !== undefined && profilingMode !== 'off' && profilingMode !== 'default') {
     throw new Error('Flag "webgpu-profiling-mode" is invalid');
   }
-  const validateInputContent = args['webgpu-validate-input-content'];
+  const validateInputContent = webgpu.validateInputContent =
+      parseBooleanArg(webgpu.validateInputContent ?? args['webgpu-validate-input-content']);
   if (validateInputContent !== undefined && typeof validateInputContent !== 'boolean') {
     throw new Error('Flag "webgpu-validate-input-content" is invalid');
   }
-  return {profilingMode, validateInputContent};
+  return webgpu;
 }
 
 function parseWebNNOptions(args: minimist.ParsedArgs): InferenceSession.WebNNExecutionProviderOption {
@@ -344,12 +355,11 @@ function parseWebNNOptions(args: minimist.ParsedArgs): InferenceSession.WebNNExe
   return {name: 'webnn', deviceType};
 }
 
-function parseGlobalEnvFlags(args: minimist.ParsedArgs): NonNullable<TestRunnerCliArgs['globalEnvFlags']> {
+function parseGlobalEnvFlags(args: minimist.ParsedArgs) {
   const wasm = parseWasmFlags(args);
   const webgl = parseWebglFlags(args);
   const webgpu = parseWebgpuFlags(args);
-  const cpuFlags = parseCpuFlags(args);
-  return {webgl, wasm, webgpu, ...cpuFlags};
+  return {webgl, wasm, webgpu};
 }
 
 export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs {
@@ -394,15 +404,14 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs
     }
   }
 
-  const globalEnvFlags = parseGlobalEnvFlags(args);
-
   // Options:
   // --log-verbose=<...>
   // --log-info=<...>
   // --log-warning=<...>
   // --log-error=<...>
   const logConfig = parseLogConfig(args);
-  globalEnvFlags.logLevel = logConfig[0]?.config.minimalSeverity;
+  let logLevel = logConfig[0]?.config.minimalSeverity;
+
   // Option: -p, --profile
   const profile = (args.profile || args.p) ? true : false;
   if (profile) {
@@ -410,9 +419,18 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs
     logConfig.push({category: 'Profiler.node', config: {minimalSeverity: 'verbose'}});
     logConfig.push({category: 'Profiler.op', config: {minimalSeverity: 'verbose'}});
     logConfig.push({category: 'Profiler.backend', config: {minimalSeverity: 'verbose'}});
-    globalEnvFlags.logLevel = 'verbose';
+    logLevel = 'verbose';
   }
 
+  // Option: -t, --trace
+  const trace = parseBooleanArg(args.trace || args.t, false);
+
+  // Options:
+  // --wasm.<...>=<...>
+  // --webgl.<...>=<...>
+  // --webgpu.<...>=<...>
+  const globalEnvFlags = {...parseGlobalEnvFlags(args), debug, trace, logLevel};
+
   // Option: -P[=<...>], --perf[=<...>]
   const perfArg = (args.perf || args.P);
   const perf = perfArg ? true : false;
diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts
index 9105c02412e34..ace64e9532b12 100644
--- a/js/web/script/test-runner-cli.ts
+++ b/js/web/script/test-runner-cli.ts
@@ -542,14 +542,13 @@ async function main() {
       npmlog.info('TestRunnerCli.Run', '(4/4) Running karma to start test runner...');
       const webgpu = args.backends.indexOf('webgpu') > -1;
       const webnn = args.backends.indexOf('webnn') > -1;
-      const browser = getBrowserNameFromEnv(
-          args.env,
-          args.bundleMode === 'perf' ? 'perf' :
-              args.debug             ? 'debug' :
-                                       'test',
-          webgpu);
+      const browser = getBrowserNameFromEnv(args.env);
       const karmaArgs = ['karma', 'start', `--browsers ${browser}`];
       const chromiumFlags = ['--enable-features=SharedArrayBuffer', ...args.chromiumFlags];
+      if (args.bundleMode === 'dev' && !args.debug) {
+        // use headless for 'test' mode (when 'perf' and 'debug' are OFF)
+        chromiumFlags.push('--headless=new');
+      }
       if (args.debug) {
         karmaArgs.push('--log-level info --timeout-mocha 9999999');
         chromiumFlags.push('--remote-debugging-port=9333');
@@ -570,6 +569,9 @@ async function main() {
       if (webnn) {
         chromiumFlags.push('--enable-experimental-web-platform-features');
       }
+      if (process.argv.includes('--karma-debug')) {
+        karmaArgs.push('--log-level debug');
+      }
       karmaArgs.push(`--bundle-mode=${args.bundleMode}`);
       karmaArgs.push(...chromiumFlags.map(flag => `--chromium-flags=${flag}`));
       if (browser.startsWith('Edge')) {
@@ -662,10 +664,10 @@ async function main() {
     fs.writeJSONSync(path.join(TEST_ROOT, './testdata-config.json'), config);
   }
 
-  function getBrowserNameFromEnv(env: TestRunnerCliArgs['env'], mode: 'debug'|'perf'|'test', webgpu: boolean) {
+  function getBrowserNameFromEnv(env: TestRunnerCliArgs['env']) {
     switch (env) {
       case 'chrome':
-        return selectChromeBrowser(mode, webgpu);
+        return 'ChromeTest';
       case 'edge':
         return 'EdgeTest';
       case 'firefox':
@@ -680,20 +682,6 @@ async function main() {
         throw new Error(`env "${env}" not supported.`);
     }
   }
-
-  function selectChromeBrowser(mode: 'debug'|'perf'|'test', webgpu: boolean) {
-    if (webgpu) {
-      return 'ChromeTest';
-    } else {
-      switch (mode) {
-        case 'debug':
-        case 'perf':
-          return 'ChromeTest';
-        default:
-          return 'ChromeTestHeadless';
-      }
-    }
-  }
 }
 
 void main();
diff --git a/js/web/test/data/ops/add_zero-sized.jsonc b/js/web/test/data/ops/add_zero-sized.jsonc
new file mode 100644
index 0000000000000..37e08cd7f20ac
--- /dev/null
+++ b/js/web/test/data/ops/add_zero-sized.jsonc
@@ -0,0 +1,31 @@
+[
+  {
+    "name": "Add with no attributes",
+    "operator": "Add",
+    "attributes": [],
+    "cases": [
+      {
+        "name": "T[2,0] T[2,1]",
+        "inputs": [
+          {
+            "data": [],
+            "dims": [2, 0],
+            "type": "float32"
+          },
+          {
+            "data": [1, 2],
+            "dims": [2, 1],
+            "type": "float32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [],
+            "dims": [2, 0],
+            "type": "float32"
+          }
+        ]
+      }
+    ]
+  }
+]
diff --git a/js/web/test/data/ops/concat_zero-sized.jsonc b/js/web/test/data/ops/concat_zero-sized.jsonc
new file mode 100644
index 0000000000000..be9625145d157
--- /dev/null
+++ b/js/web/test/data/ops/concat_zero-sized.jsonc
@@ -0,0 +1,641 @@
+[
+  {
+    "name": "Concat 2D axis=0",
+    "operator": "Concat",
+    "attributes": [{ "name": "axis", "data": -2, "type": "int" }],
+    "cases": [
+      {
+        "name": "X",
+        "inputs": [
+          {
+            "data": [],
+            "dims": [1, 4, 0, 64],
+            "type": "float32"
+          },
+          {
+            "data": [
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2
+            ],
+            "dims": [1, 4, 36, 64],
+            "type": "float32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2
+            ],
+            "dims": [1, 4, 36, 64],
+            "type": "float32"
+          }
+        ]
+      }
+    ]
+  },
+  {
+    "name": "Concat 2D axis=1; Preserve dims",
+    "operator": "Concat",
+    "attributes": [
+      {
+        "name": "axis",
+        "data": 0,
+        "type": "int"
+      }
+    ],
+    "cases": [
+      {
+        "name": "Some but not all input tensors are zero-sized",
+        "inputs": [
+          {
+            "data": [],
+            "dims": [0, 1],
+            "type": "float32"
+          },
+          {
+            "data": [1],
+            "dims": [1, 1],
+            "type": "float32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [1],
+            "dims": [1, 1],
+            "type": "float32"
+          }
+        ]
+      }
+    ]
+  },
+  {
+    "name": "Concat 2D axis=1; Preserve dims",
+    "operator": "Concat",
+    "attributes": [
+      {
+        "name": "axis",
+        "data": 1,
+        "type": "int"
+      }
+    ],
+    "cases": [
+      {
+        "name": "All input tensors are zero-sized",
+        "inputs": [
+          {
+            "data": [],
+            "dims": [0, 0],
+            "type": "float32"
+          },
+          {
+            "data": [],
+            "dims": [0, 1],
+            "type": "float32"
+          },
+          {
+            "data": [],
+            "dims": [0, 2],
+            "type": "float32"
+          },
+          {
+            "data": [],
+            "dims": [0, 3],
+            "type": "float32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [],
+            "dims": [0, 6],
+            "type": "float32"
+          }
+        ]
+      }
+    ]
+  }
+]
diff --git a/js/web/test/data/ops/instance-norm.jsonc b/js/web/test/data/ops/instance-norm.jsonc
index e89ac2da3795f..f28b016d47ab9 100644
--- a/js/web/test/data/ops/instance-norm.jsonc
+++ b/js/web/test/data/ops/instance-norm.jsonc
@@ -224,5 +224,85 @@
         ]
       }
     ]
+  },
+  {
+    "name": "Simple test with NHWC, components 1, buffer reuse",
+    "operator": "InstanceNormalization",
+    "inputShapeDefinitions": "rankOnly",
+    "opset": {
+      "domain": "",
+      "version": 17
+    },
+    "cases": [
+      {
+        "name": "Simple test",
+        "inputs": [
+          {
+            "data": [1, 2, 3, 4, 5, 6],
+            "dims": [2, 3, 1, 1],
+            "type": "float32"
+          },
+          {
+            "data": [1, 2, 3],
+            "dims": [3],
+            "type": "float32"
+          },
+          {
+            "data": [4, 5, 6],
+            "dims": [3],
+            "type": "float32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [4, 5, 6, 4, 5, 6],
+            "dims": [2, 3, 1, 1],
+            "type": "float32"
+          }
+        ]
+      }
+    ]
+  },
+  {
+    "name": "Simple test with NHWC, components 2, buffer reuse",
+    "operator": "InstanceNormalization",
+    "inputShapeDefinitions": "rankOnly",
+    "opset": {
+      "domain": "",
+      "version": 17
+    },
+    "cases": [
+      {
+        "name": "Simple test",
+        "inputs": [
+          {
+            "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5, 4, 3, 2],
+            "dims": [1, 6, 1, 3],
+            "type": "float32"
+          },
+          {
+            "data": [1, 2, 3, 4, 5, 6],
+            "dims": [6],
+            "type": "float32"
+          },
+          {
+            "data": [4, 5, 6, 7, 8, 9],
+            "dims": [6],
+            "type": "float32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [
+              2.775264263153076, 4, 5.224735260009766, 2.5505285263061523, 5, 7.449470520019531, 2.325794219970703, 6,
+              9.674205780029297, 11.898944854736328, 7, 2.1010589599609375, 14.123676300048828, 8, 1.876321792602539,
+              16.348413467407227, 9, 1.6515865325927734
+            ],
+            "dims": [1, 6, 1, 3],
+            "type": "float32"
+          }
+        ]
+      }
+    ]
   }
 ]
diff --git a/js/web/test/data/ops/matmulnbits.jsonc b/js/web/test/data/ops/matmulnbits.jsonc
new file mode 100644
index 0000000000000..175be78cc0818
--- /dev/null
+++ b/js/web/test/data/ops/matmulnbits.jsonc
@@ -0,0 +1,1584 @@
+[
+  {
+    "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4",
+    "operator": "MatMulNBits",
+    "opset": { "domain": "com.microsoft", "version": 1 },
+    "attributes": [
+      { "name": "K", "data": 16, "type": "int" },
+      { "name": "N", "data": 8, "type": "int" },
+      { "name": "block_size", "data": 16, "type": "int" },
+      { "name": "bits", "data": 4, "type": "int" }
+    ],
+    "cases": [
+      {
+        "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; symmetric",
+        "inputs": [
+          {
+            "data": [
+              0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+              29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
+              55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
+              81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105,
+              106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
+              127
+            ],
+            "dims": [8, 16],
+            "type": "float32"
+          },
+          {
+            "dims": [8, 1, 8],
+            "type": "uint8",
+            "data": [
+              1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+              30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
+              56, 57, 58, 59, 60, 61, 62, 63, 64
+            ]
+          },
+          {
+            "dims": [8],
+            "type": "float32",
+            "data": [0, 1, 2, 3, 4, 5, 6, 7]
+          }
+        ],
+        "outputs": [
+          {
+            "dims": [8, 8],
+            "type": "float32",
+            "data": [
+              0, -385, -1120, -963, -1984, -1285, -2592, -1351, 0, -1073, -3808, -2643, -6848, -3445, -9120, -3479, 0,
+              -1761, -6496, -4323, -11712, -5605, -15648, -5607, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735,
+              0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, 0, -3825, -14560, -9363, -26304, -12085, -35232,
+              -11991, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119, 0, -5201, -19936, -12723, -36032,
+              -16405, -48288, -16247
+            ]
+          }
+        ]
+      }
+    ]
+  },
+  {
+    "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4",
+    "operator": "MatMulNBits",
+    "opset": { "domain": "com.microsoft", "version": 1 },
+    "attributes": [
+      { "name": "K", "data": 16, "type": "int" },
+      { "name": "N", "data": 16, "type": "int" },
+      { "name": "block_size", "data": 16, "type": "int" },
+      { "name": "bits", "data": 4, "type": "int" }
+    ],
+    "cases": [
+      {
+        "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; symmetric",
+        "inputs": [
+          {
+            "data": [
+              0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+              29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
+              55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
+              81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105,
+              106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
+              127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147,
+              148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
+              169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189,
+              190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,
+              211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231,
+              232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252,
+              253, 254, 255
+            ],
+            "dims": [16, 16],
+            "type": "float32"
+          },
+          {
+            "dims": [16, 1, 8],
+            "type": "uint8",
+            "data": [
+              1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+              30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
+              56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
+              82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
+              107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
+              128
+            ]
+          },
+          {
+            "dims": [16],
+            "type": "float32",
+            "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
+          }
+        ],
+        "outputs": [
+          {
+            "dims": [16, 16],
+            "type": "float32",
+            "data": [
+              0, -385, -1120, -963, -1984, -1285, -2592, -1351, -2944, -1161, -3040, -715, -2880, -13, -2464, 945, 0,
+              -1073, -3808, -2643, -6848, -3445, -9120, -3479, -10624, -2745, -11360, -1243, -11328, 1027, -10528, 4065,
+              0, -1761, -6496, -4323, -11712, -5605, -15648, -5607, -18304, -4329, -19680, -1771, -19776, 2067, -18592,
+              7185, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735, -25984, -5913, -28000, -2299, -28224, 3107,
+              -26656, 10305, 0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, -33664, -7497, -36320, -2827,
+              -36672, 4147, -34720, 13425, 0, -3825, -14560, -9363, -26304, -12085, -35232, -11991, -41344, -9081,
+              -44640, -3355, -45120, 5187, -42784, 16545, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119,
+              -49024, -10665, -52960, -3883, -53568, 6227, -50848, 19665, 0, -5201, -19936, -12723, -36032, -16405,
+              -48288, -16247, -56704, -12249, -61280, -4411, -62016, 7267, -58912, 22785, 0, -5889, -22624, -14403,
+              -40896, -18565, -54816, -18375, -64384, -13833, -69600, -4939, -70464, 8307, -66976, 25905, 0, -6577,
+              -25312, -16083, -45760, -20725, -61344, -20503, -72064, -15417, -77920, -5467, -78912, 9347, -75040,
+              29025, 0, -7265, -28000, -17763, -50624, -22885, -67872, -22631, -79744, -17001, -86240, -5995, -87360,
+              10387, -83104, 32145, 0, -7953, -30688, -19443, -55488, -25045, -74400, -24759, -87424, -18585, -94560,
+              -6523, -95808, 11427, -91168, 35265, 0, -8641, -33376, -21123, -60352, -27205, -80928, -26887, -95104,
+              -20169, -102880, -7051, -104256, 12467, -99232, 38385, 0, -9329, -36064, -22803, -65216, -29365, -87456,
+              -29015, -102784, -21753, -111200, -7579, -112704, 13507, -107296, 41505, 0, -10017, -38752, -24483,
+              -70080, -31525, -93984, -31143, -110464, -23337, -119520, -8107, -121152, 14547, -115360, 44625, 0,
+              -10705, -41440, -26163, -74944, -33685, -100512, -33271, -118144, -24921, -127840, -8635, -129600, 15587,
+              -123424, 47745
+            ]
+          }
+        ]
+      }
+    ]
+  },
+  {
+    "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4",
+    "operator": "MatMulNBits",
+    "opset": { "domain": "com.microsoft", "version": 1 },
+    "attributes": [
+      { "name": "K", "data": 16, "type": "int" },
+      { "name": "N", "data": 16, "type": "int" },
+      { "name": "block_size", "data": 16, "type": "int" },
+      { "name": "bits", "data": 4, "type": "int" }
+    ],
+    "cases": [
+      {
+        "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; asymmetric",
+        "inputs": [
+          {
+            "data": [
+              0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+              29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
+              55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
+              81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105,
+              106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
+              127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147,
+              148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
+              169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189,
+              190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,
+              211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231,
+              232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252,
+              253, 254, 255
+            ],
+            "dims": [16, 16],
+            "type": "float32"
+          },
+          {
+            "dims": [16, 1, 8],
+            "type": "uint8",
+            "data": [
+              0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+              29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
+              55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
+              81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105,
+              106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
+              127
+            ]
+          },
+          {
+            "dims": [16],
+            "type": "float32",
+            "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
+          },
+          {
+            "dims": [16],
+            "type": "uint8",
+            "data": [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128]
+          }
+        ],
+        "outputs": [
+          {
+            "dims": [16, 16],
+            "type": "float32",
+            "data": [
+              0, 728, 688, 2376, 1632, 4280, 2832, 6440, 4288, 8856, 6000, 11528, 7968, 14456, 10192, 17640, 0, 2200,
+              1840, 7176, 4448, 12920, 7824, 19432, 11968, 26712, 16880, 34760, 22560, 43576, 29008, 53160, 0, 3672,
+              2992, 11976, 7264, 21560, 12816, 32424, 19648, 44568, 27760, 57992, 37152, 72696, 47824, 88680, 0, 5144,
+              4144, 16776, 10080, 30200, 17808, 45416, 27328, 62424, 38640, 81224, 51744, 101816, 66640, 124200, 0,
+              6616, 5296, 21576, 12896, 38840, 22800, 58408, 35008, 80280, 49520, 104456, 66336, 130936, 85456, 159720,
+              0, 8088, 6448, 26376, 15712, 47480, 27792, 71400, 42688, 98136, 60400, 127688, 80928, 160056, 104272,
+              195240, 0, 9560, 7600, 31176, 18528, 56120, 32784, 84392, 50368, 115992, 71280, 150920, 95520, 189176,
+              123088, 230760, 0, 11032, 8752, 35976, 21344, 64760, 37776, 97384, 58048, 133848, 82160, 174152, 110112,
+              218296, 141904, 266280, 0, 12504, 9904, 40776, 24160, 73400, 42768, 110376, 65728, 151704, 93040, 197384,
+              124704, 247416, 160720, 301800, 0, 13976, 11056, 45576, 26976, 82040, 47760, 123368, 73408, 169560,
+              103920, 220616, 139296, 276536, 179536, 337320, 0, 15448, 12208, 50376, 29792, 90680, 52752, 136360,
+              81088, 187416, 114800, 243848, 153888, 305656, 198352, 372840, 0, 16920, 13360, 55176, 32608, 99320,
+              57744, 149352, 88768, 205272, 125680, 267080, 168480, 334776, 217168, 408360, 0, 18392, 14512, 59976,
+              35424, 107960, 62736, 162344, 96448, 223128, 136560, 290312, 183072, 363896, 235984, 443880, 0, 19864,
+              15664, 64776, 38240, 116600, 67728, 175336, 104128, 240984, 147440, 313544, 197664, 393016, 254800,
+              479400, 0, 21336, 16816, 69576, 41056, 125240, 72720, 188328, 111808, 258840, 158320, 336776, 212256,
+              422136, 273616, 514920, 0, 22808, 17968, 74376, 43872, 133880, 77712, 201320, 119488, 276696, 169200,
+              360008, 226848, 451256, 292432, 550440
+            ]
+          }
+        ]
+      }
+    ]
+  },
+  {
+    "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4",
+    "operator": "MatMulNBits",
+    "opset": { "domain": "com.microsoft", "version": 1 },
+    "attributes": [
+      { "name": "K", "data": 16, "type": "int" },
+      { "name": "N", "data": 32, "type": "int" },
+      { "name": "block_size", "data": 16, "type": "int" },
+      { "name": "bits", "data": 4, "type": "int" }
+    ],
+    "cases": [
+      {
+        "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4; symmetric",
+        "inputs": [
+          {
+            "data": [
+              1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+              30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
+              56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
+              82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
+              107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
+              128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148,
+              149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
+              170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190,
+              191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
+              212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232,
+              233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
+              254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274,
+              275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295,
+              296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316,
+              317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337,
+              338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358,
+              359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379,
+              380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400,
+              401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421,
+              422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442,
+              443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463,
+              464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484,
+              485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505,
+              506, 507, 508, 509, 510, 511, 512
+            ],
+            "dims": [32, 16],
+            "type": "float32"
+          },
+          {
+            "dims": [32, 1, 8],
+            "type": "uint8",
+            "data": [
+              1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+              30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
+              56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
+              82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
+              107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
+              128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148,
+              149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
+              170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190,
+              191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
+              212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232,
+              233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
+              254, 255, 256
+            ]
+          },
+          {
+            "dims": [32],
+            "type": "float32",
+            "data": [
+              0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+              29, 30, 31
+            ]
+          }
+        ],
+        "outputs": [
+          {
+            "dims": [32, 32],
+            "type": "float32",
+            "data": [
+              0, -428, -1288, -1068, -2288, -1420, -3000, -1484, -3424, -1260, -3560, -748, -3408, 52, -2968, 1140,
+              -2272, 2516, -1224, 4180, 80, 6132, 1672, 8372, 3552, 10900, 5720, 13716, 8176, 16820, 10920, 12276, 0,
+              -1116, -3976, -2748, -7152, -3580, -9528, -3612, -11104, -2844, -11880, -1276, -11856, 1092, -11032, 4260,
+              -8160, 8228, -6984, 12996, -3760, 18564, 264, 24932, 5088, 32100, 10712, 40068, 17136, 48836, 24360,
+              42532, 0, -1804, -6664, -4428, -12016, -5740, -16056, -5740, -18784, -4428, -20200, -1804, -20304, 2132,
+              -19096, 7380, -14048, 13940, -12744, 21812, -7600, 30996, -1144, 41492, 6624, 53300, 15704, 66420, 26096,
+              80852, 37800, 72788, 0, -2492, -9352, -6108, -16880, -7900, -22584, -7868, -26464, -6012, -28520, -2332,
+              -28752, 3172, -27160, 10500, -19936, 19652, -18504, 30628, -11440, 43428, -2552, 58052, 8160, 74500,
+              20696, 92772, 35056, 112868, 51240, 103044, 0, -3180, -12040, -7788, -21744, -10060, -29112, -9996,
+              -34144, -7596, -36840, -2860, -37200, 4212, -35224, 13620, -25824, 25364, -24264, 39444, -15280, 55860,
+              -3960, 74612, 9696, 95700, 25688, 119124, 44016, 144884, 64680, 133300, 0, -3868, -14728, -9468, -26608,
+              -12220, -35640, -12124, -41824, -9180, -45160, -3388, -45648, 5252, -43288, 16740, -31712, 31076, -30024,
+              48260, -19120, 68292, -5368, 91172, 11232, 116900, 30680, 145476, 52976, 176900, 78120, 163556, 0, -4556,
+              -17416, -11148, -31472, -14380, -42168, -14252, -49504, -10764, -53480, -3916, -54096, 6292, -51352,
+              19860, -37600, 36788, -35784, 57076, -22960, 80724, -6776, 107732, 12768, 138100, 35672, 171828, 61936,
+              208916, 91560, 193812, 0, -5244, -20104, -12828, -36336, -16540, -48696, -16380, -57184, -12348, -61800,
+              -4444, -62544, 7332, -59416, 22980, -43488, 42500, -41544, 65892, -26800, 93156, -8184, 124292, 14304,
+              159300, 40664, 198180, 70896, 240932, 105000, 224068, 0, -5932, -22792, -14508, -41200, -18700, -55224,
+              -18508, -64864, -13932, -70120, -4972, -70992, 8372, -67480, 26100, -49376, 48212, -47304, 74708, -30640,
+              105588, -9592, 140852, 15840, 180500, 45656, 224532, 79856, 272948, 118440, 254324, 0, -6620, -25480,
+              -16188, -46064, -20860, -61752, -20636, -72544, -15516, -78440, -5500, -79440, 9412, -75544, 29220,
+              -55264, 53924, -53064, 83524, -34480, 118020, -11000, 157412, 17376, 201700, 50648, 250884, 88816, 304964,
+              131880, 284580, 0, -7308, -28168, -17868, -50928, -23020, -68280, -22764, -80224, -17100, -86760, -6028,
+              -87888, 10452, -83608, 32340, -61152, 59636, -58824, 92340, -38320, 130452, -12408, 173972, 18912, 222900,
+              55640, 277236, 97776, 336980, 145320, 314836, 0, -7996, -30856, -19548, -55792, -25180, -74808, -24892,
+              -87904, -18684, -95080, -6556, -96336, 11492, -91672, 35460, -67040, 65348, -64584, 101156, -42160,
+              142884, -13816, 190532, 20448, 244100, 60632, 303588, 106736, 368996, 158760, 345092, 0, -8684, -33544,
+              -21228, -60656, -27340, -81336, -27020, -95584, -20268, -103400, -7084, -104784, 12532, -99736, 38580,
+              -72928, 71060, -70344, 109972, -46000, 155316, -15224, 207092, 21984, 265300, 65624, 329940, 115696,
+              401012, 172200, 375348, 0, -9372, -36232, -22908, -65520, -29500, -87864, -29148, -103264, -21852,
+              -111720, -7612, -113232, 13572, -107800, 41700, -78816, 76772, -76104, 118788, -49840, 167748, -16632,
+              223652, 23520, 286500, 70616, 356292, 124656, 433028, 185640, 405604, 0, -10060, -38920, -24588, -70384,
+              -31660, -94392, -31276, -110944, -23436, -120040, -8140, -121680, 14612, -115864, 44820, -84704, 82484,
+              -81864, 127604, -53680, 180180, -18040, 240212, 25056, 307700, 75608, 382644, 133616, 465044, 199080,
+              435860, 0, -10748, -41608, -26268, -75248, -33820, -100920, -33404, -118624, -25020, -128360, -8668,
+              -130128, 15652, -123928, 47940, -90592, 88196, -87624, 136420, -57520, 192612, -19448, 256772, 26592,
+              328900, 80600, 408996, 142576, 497060, 212520, 466116, 0, -11436, -44296, -27948, -80112, -35980, -107448,
+              -35532, -126304, -26604, -136680, -9196, -138576, 16692, -131992, 51060, -96480, 93908, -93384, 145236,
+              -61360, 205044, -20856, 273332, 28128, 350100, 85592, 435348, 151536, 529076, 225960, 496372, 0, -12124,
+              -46984, -29628, -84976, -38140, -113976, -37660, -133984, -28188, -145000, -9724, -147024, 17732, -140056,
+              54180, -102368, 99620, -99144, 154052, -65200, 217476, -22264, 289892, 29664, 371300, 90584, 461700,
+              160496, 561092, 239400, 526628, 0, -12812, -49672, -31308, -89840, -40300, -120504, -39788, -141664,
+              -29772, -153320, -10252, -155472, 18772, -148120, 57300, -108256, 105332, -104904, 162868, -69040, 229908,
+              -23672, 306452, 31200, 392500, 95576, 488052, 169456, 593108, 252840, 556884, 0, -13500, -52360, -32988,
+              -94704, -42460, -127032, -41916, -149344, -31356, -161640, -10780, -163920, 19812, -156184, 60420,
+              -114144, 111044, -110664, 171684, -72880, 242340, -25080, 323012, 32736, 413700, 100568, 514404, 178416,
+              625124, 266280, 587140, 0, -14188, -55048, -34668, -99568, -44620, -133560, -44044, -157024, -32940,
+              -169960, -11308, -172368, 20852, -164248, 63540, -120032, 116756, -116424, 180500, -76720, 254772, -26488,
+              339572, 34272, 434900, 105560, 540756, 187376, 657140, 279720, 617396, 0, -14876, -57736, -36348, -104432,
+              -46780, -140088, -46172, -164704, -34524, -178280, -11836, -180816, 21892, -172312, 66660, -125920,
+              122468, -122184, 189316, -80560, 267204, -27896, 356132, 35808, 456100, 110552, 567108, 196336, 689156,
+              293160, 647652, 0, -15564, -60424, -38028, -109296, -48940, -146616, -48300, -172384, -36108, -186600,
+              -12364, -189264, 22932, -180376, 69780, -131808, 128180, -127944, 198132, -84400, 279636, -29304, 372692,
+              37344, 477300, 115544, 593460, 205296, 721172, 306600, 677908, 0, -16252, -63112, -39708, -114160, -51100,
+              -153144, -50428, -180064, -37692, -194920, -12892, -197712, 23972, -188440, 72900, -137696, 133892,
+              -133704, 206948, -88240, 292068, -30712, 389252, 38880, 498500, 120536, 619812, 214256, 753188, 320040,
+              708164, 0, -16940, -65800, -41388, -119024, -53260, -159672, -52556, -187744, -39276, -203240, -13420,
+              -206160, 25012, -196504, 76020, -143584, 139604, -139464, 215764, -92080, 304500, -32120, 405812, 40416,
+              519700, 125528, 646164, 223216, 785204, 333480, 738420, 0, -17628, -68488, -43068, -123888, -55420,
+              -166200, -54684, -195424, -40860, -211560, -13948, -214608, 26052, -204568, 79140, -149472, 145316,
+              -145224, 224580, -95920, 316932, -33528, 422372, 41952, 540900, 130520, 672516, 232176, 817220, 346920,
+              768676, 0, -18316, -71176, -44748, -128752, -57580, -172728, -56812, -203104, -42444, -219880, -14476,
+              -223056, 27092, -212632, 82260, -155360, 151028, -150984, 233396, -99760, 329364, -34936, 438932, 43488,
+              562100, 135512, 698868, 241136, 849236, 360360, 798932, 0, -19004, -73864, -46428, -133616, -59740,
+              -179256, -58940, -210784, -44028, -228200, -15004, -231504, 28132, -220696, 85380, -161248, 156740,
+              -156744, 242212, -103600, 341796, -36344, 455492, 45024, 583300, 140504, 725220, 250096, 881252, 373800,
+              829188, 0, -19692, -76552, -48108, -138480, -61900, -185784, -61068, -218464, -45612, -236520, -15532,
+              -239952, 29172, -228760, 88500, -167136, 162452, -162504, 251028, -107440, 354228, -37752, 472052, 46560,
+              604500, 145496, 751572, 259056, 913268, 387240, 859444, 0, -20380, -79240, -49788, -143344, -64060,
+              -192312, -63196, -226144, -47196, -244840, -16060, -248400, 30212, -236824, 91620, -173024, 168164,
+              -168264, 259844, -111280, 366660, -39160, 488612, 48096, 625700, 150488, 777924, 268016, 945284, 400680,
+              889700, 0, -21068, -81928, -51468, -148208, -66220, -198840, -65324, -233824, -48780, -253160, -16588,
+              -256848, 31252, -244888, 94740, -178912, 173876, -174024, 268660, -115120, 379092, -40568, 505172, 49632,
+              646900, 155480, 804276, 276976, 977300, 414120, 919956, 0, -21756, -84616, -53148, -153072, -68380,
+              -205368, -67452, -241504, -50364, -261480, -17116, -265296, 32292, -252952, 97860, -184800, 179588,
+              -179784, 277476, -118960, 391524, -41976, 521732, 51168, 668100, 160472, 830628, 285936, 1009316, 427560,
+              950212
+            ]
+          }
+        ]
+      }
+    ]
+  },
+  {
+    "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4",
+    "operator": "MatMulNBits",
+    "opset": { "domain": "com.microsoft", "version": 1 },
+    "attributes": [
+      { "name": "K", "data": 16, "type": "int" },
+      { "name": "N", "data": 32, "type": "int" },
+      { "name": "block_size", "data": 16, "type": "int" },
+      { "name": "bits", "data": 4, "type": "int" }
+    ],
+    "cases": [
+      {
+        "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4; asymmetric",
+        "inputs": [
+          {
+            "data": [
+              1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+              30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
+              56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
+              82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
+              107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
+              128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148,
+              149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
+              170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190,
+              191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
+              212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232,
+              233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
+              254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274,
+              275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295,
+              296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316,
+              317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337,
+              338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358,
+              359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379,
+              380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400,
+              401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421,
+              422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442,
+              443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463,
+              464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484,
+              485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505,
+              506, 507, 508, 509, 510, 511, 512
+            ],
+            "dims": [32, 16],
+            "type": "float32"
+          },
+          {
+            "dims": [32, 1, 8],
+            "type": "uint8",
+            "data": [
+              1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+              30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
+              56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
+              82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
+              107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
+              128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148,
+              149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
+              170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190,
+              191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
+              212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232,
+              233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
+              254, 255, 256
+            ]
+          },
+          {
+            "dims": [32],
+            "type": "float32",
+            "data": [
+              0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+              29, 30, 31
+            ]
+          },
+          {
+            "dims": [32],
+            "type": "uint8",
+            "data": [
+              128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
+              128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128
+            ]
+          }
+        ],
+        "outputs": [
+          {
+            "dims": [32, 32],
+            "type": "float32",
+            "data": [
+              0, 660, 888, 2196, 2064, 4020, 3528, 6132, 5280, 8532, 7320, 11220, 9648, 14196, 12264, 17460, 15136,
+              21012, 18360, 24852, 21840, 28980, 25608, 33396, 29664, 38100, 34008, 43092, 38640, 48372, 43560, 46004,
+              0, 2020, 2296, 6660, 5392, 12100, 9288, 18340, 13984, 25380, 19480, 33220, 25776, 41860, 32872, 51300,
+              42016, 61540, 49464, 72580, 58960, 84420, 69256, 97060, 80352, 110500, 92248, 124740, 104944, 139780,
+              118440, 139748, 0, 3380, 3704, 11124, 8720, 20180, 15048, 30548, 22688, 42228, 31640, 55220, 41904, 69524,
+              53480, 85140, 68896, 102068, 80568, 120308, 96080, 139860, 112904, 160724, 131040, 182900, 150488, 206388,
+              171248, 231188, 193320, 233492, 0, 4740, 5112, 15588, 12048, 28260, 20808, 42756, 31392, 59076, 43800,
+              77220, 58032, 97188, 74088, 118980, 95776, 142596, 111672, 168036, 133200, 195300, 156552, 224388, 181728,
+              255300, 208728, 288036, 237552, 322596, 268200, 327236, 0, 6100, 6520, 20052, 15376, 36340, 26568, 54964,
+              40096, 75924, 55960, 99220, 74160, 124852, 94696, 152820, 122656, 183124, 142776, 215764, 170320, 250740,
+              200200, 288052, 232416, 327700, 266968, 369684, 303856, 414004, 343080, 420980, 0, 7460, 7928, 24516,
+              18704, 44420, 32328, 67172, 48800, 92772, 68120, 121220, 90288, 152516, 115304, 186660, 149536, 223652,
+              173880, 263492, 207440, 306180, 243848, 351716, 283104, 400100, 325208, 451332, 370160, 505412, 417960,
+              514724, 0, 8820, 9336, 28980, 22032, 52500, 38088, 79380, 57504, 109620, 80280, 143220, 106416, 180180,
+              135912, 220500, 176416, 264180, 204984, 311220, 244560, 361620, 287496, 415380, 333792, 472500, 383448,
+              532980, 436464, 596820, 492840, 608468, 0, 10180, 10744, 33444, 25360, 60580, 43848, 91588, 66208, 126468,
+              92440, 165220, 122544, 207844, 156520, 254340, 203296, 304708, 236088, 358948, 281680, 417060, 331144,
+              479044, 384480, 544900, 441688, 614628, 502768, 688228, 567720, 702212, 0, 11540, 12152, 37908, 28688,
+              68660, 49608, 103796, 74912, 143316, 104600, 187220, 138672, 235508, 177128, 288180, 230176, 345236,
+              267192, 406676, 318800, 472500, 374792, 542708, 435168, 617300, 499928, 696276, 569072, 779636, 642600,
+              795956, 0, 12900, 13560, 42372, 32016, 76740, 55368, 116004, 83616, 160164, 116760, 209220, 154800,
+              263172, 197736, 322020, 257056, 385764, 298296, 454404, 355920, 527940, 418440, 606372, 485856, 689700,
+              558168, 777924, 635376, 871044, 717480, 889700, 0, 14260, 14968, 46836, 35344, 84820, 61128, 128212,
+              92320, 177012, 128920, 231220, 170928, 290836, 218344, 355860, 283936, 426292, 329400, 502132, 393040,
+              583380, 462088, 670036, 536544, 762100, 616408, 859572, 701680, 962452, 792360, 983444, 0, 15620, 16376,
+              51300, 38672, 92900, 66888, 140420, 101024, 193860, 141080, 253220, 187056, 318500, 238952, 389700,
+              310816, 466820, 360504, 549860, 430160, 638820, 505736, 733700, 587232, 834500, 674648, 941220, 767984,
+              1053860, 867240, 1077188, 0, 16980, 17784, 55764, 42000, 100980, 72648, 152628, 109728, 210708, 153240,
+              275220, 203184, 346164, 259560, 423540, 337696, 507348, 391608, 597588, 467280, 694260, 549384, 797364,
+              637920, 906900, 732888, 1022868, 834288, 1145268, 942120, 1170932, 0, 18340, 19192, 60228, 45328, 109060,
+              78408, 164836, 118432, 227556, 165400, 297220, 219312, 373828, 280168, 457380, 364576, 547876, 422712,
+              645316, 504400, 749700, 593032, 861028, 688608, 979300, 791128, 1104516, 900592, 1236676, 1017000,
+              1264676, 0, 19700, 20600, 64692, 48656, 117140, 84168, 177044, 127136, 244404, 177560, 319220, 235440,
+              401492, 300776, 491220, 391456, 588404, 453816, 693044, 541520, 805140, 636680, 924692, 739296, 1051700,
+              849368, 1186164, 966896, 1328084, 1091880, 1358420, 0, 21060, 22008, 69156, 51984, 125220, 89928, 189252,
+              135840, 261252, 189720, 341220, 251568, 429156, 321384, 525060, 418336, 628932, 484920, 740772, 578640,
+              860580, 680328, 988356, 789984, 1124100, 907608, 1267812, 1033200, 1419492, 1166760, 1452164, 0, 22420,
+              23416, 73620, 55312, 133300, 95688, 201460, 144544, 278100, 201880, 363220, 267696, 456820, 341992,
+              558900, 445216, 669460, 516024, 788500, 615760, 916020, 723976, 1052020, 840672, 1196500, 965848, 1349460,
+              1099504, 1510900, 1241640, 1545908, 0, 23780, 24824, 78084, 58640, 141380, 101448, 213668, 153248, 294948,
+              214040, 385220, 283824, 484484, 362600, 592740, 472096, 709988, 547128, 836228, 652880, 971460, 767624,
+              1115684, 891360, 1268900, 1024088, 1431108, 1165808, 1602308, 1316520, 1639652, 0, 25140, 26232, 82548,
+              61968, 149460, 107208, 225876, 161952, 311796, 226200, 407220, 299952, 512148, 383208, 626580, 498976,
+              750516, 578232, 883956, 690000, 1026900, 811272, 1179348, 942048, 1341300, 1082328, 1512756, 1232112,
+              1693716, 1391400, 1733396, 0, 26500, 27640, 87012, 65296, 157540, 112968, 238084, 170656, 328644, 238360,
+              429220, 316080, 539812, 403816, 660420, 525856, 791044, 609336, 931684, 727120, 1082340, 854920, 1243012,
+              992736, 1413700, 1140568, 1594404, 1298416, 1785124, 1466280, 1827140, 0, 27860, 29048, 91476, 68624,
+              165620, 118728, 250292, 179360, 345492, 250520, 451220, 332208, 567476, 424424, 694260, 552736, 831572,
+              640440, 979412, 764240, 1137780, 898568, 1306676, 1043424, 1486100, 1198808, 1676052, 1364720, 1876532,
+              1541160, 1920884, 0, 29220, 30456, 95940, 71952, 173700, 124488, 262500, 188064, 362340, 262680, 473220,
+              348336, 595140, 445032, 728100, 579616, 872100, 671544, 1027140, 801360, 1193220, 942216, 1370340,
+              1094112, 1558500, 1257048, 1757700, 1431024, 1967940, 1616040, 2014628, 0, 30580, 31864, 100404, 75280,
+              181780, 130248, 274708, 196768, 379188, 274840, 495220, 364464, 622804, 465640, 761940, 606496, 912628,
+              702648, 1074868, 838480, 1248660, 985864, 1434004, 1144800, 1630900, 1315288, 1839348, 1497328, 2059348,
+              1690920, 2108372, 0, 31940, 33272, 104868, 78608, 189860, 136008, 286916, 205472, 396036, 287000, 517220,
+              380592, 650468, 486248, 795780, 633376, 953156, 733752, 1122596, 875600, 1304100, 1029512, 1497668,
+              1195488, 1703300, 1373528, 1920996, 1563632, 2150756, 1765800, 2202116, 0, 33300, 34680, 109332, 81936,
+              197940, 141768, 299124, 214176, 412884, 299160, 539220, 396720, 678132, 506856, 829620, 660256, 993684,
+              764856, 1170324, 912720, 1359540, 1073160, 1561332, 1246176, 1775700, 1431768, 2002644, 1629936, 2242164,
+              1840680, 2295860, 0, 34660, 36088, 113796, 85264, 206020, 147528, 311332, 222880, 429732, 311320, 561220,
+              412848, 705796, 527464, 863460, 687136, 1034212, 795960, 1218052, 949840, 1414980, 1116808, 1624996,
+              1296864, 1848100, 1490008, 2084292, 1696240, 2333572, 1915560, 2389604, 0, 36020, 37496, 118260, 88592,
+              214100, 153288, 323540, 231584, 446580, 323480, 583220, 428976, 733460, 548072, 897300, 714016, 1074740,
+              827064, 1265780, 986960, 1470420, 1160456, 1688660, 1347552, 1920500, 1548248, 2165940, 1762544, 2424980,
+              1990440, 2483348, 0, 37380, 38904, 122724, 91920, 222180, 159048, 335748, 240288, 463428, 335640, 605220,
+              445104, 761124, 568680, 931140, 740896, 1115268, 858168, 1313508, 1024080, 1525860, 1204104, 1752324,
+              1398240, 1992900, 1606488, 2247588, 1828848, 2516388, 2065320, 2577092, 0, 38740, 40312, 127188, 95248,
+              230260, 164808, 347956, 248992, 480276, 347800, 627220, 461232, 788788, 589288, 964980, 767776, 1155796,
+              889272, 1361236, 1061200, 1581300, 1247752, 1815988, 1448928, 2065300, 1664728, 2329236, 1895152, 2607796,
+              2140200, 2670836, 0, 40100, 41720, 131652, 98576, 238340, 170568, 360164, 257696, 497124, 359960, 649220,
+              477360, 816452, 609896, 998820, 794656, 1196324, 920376, 1408964, 1098320, 1636740, 1291400, 1879652,
+              1499616, 2137700, 1722968, 2410884, 1961456, 2699204, 2215080, 2764580, 0, 41460, 43128, 136116, 101904,
+              246420, 176328, 372372, 266400, 513972, 372120, 671220, 493488, 844116, 630504, 1032660, 821536, 1236852,
+              951480, 1456692, 1135440, 1692180, 1335048, 1943316, 1550304, 2210100, 1781208, 2492532, 2027760, 2790612,
+              2289960, 2858324, 0, 42820, 44536, 140580, 105232, 254500, 182088, 384580, 275104, 530820, 384280, 693220,
+              509616, 871780, 651112, 1066500, 848416, 1277380, 982584, 1504420, 1172560, 1747620, 1378696, 2006980,
+              1600992, 2282500, 1839448, 2574180, 2094064, 2882020, 2364840, 2952068
+            ]
+          }
+        ]
+      }
+    ]
+  },
+  {
+    "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4",
+    "operator": "MatMulNBits",
+    "opset": { "domain": "com.microsoft", "version": 1 },
+    "attributes": [
+      { "name": "K", "data": 32, "type": "int" },
+      { "name": "N", "data": 16, "type": "int" },
+      { "name": "block_size", "data": 16, "type": "int" },
+      { "name": "bits", "data": 4, "type": "int" }
+    ],
+    "cases": [
+      {
+        "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4; symmetric",
+        "inputs": [
+          {
+            "data": [
+              1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+              30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
+              56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
+              82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
+              107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
+              128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148,
+              149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
+              170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190,
+              191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
+              212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232,
+              233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
+              254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274,
+              275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295,
+              296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316,
+              317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337,
+              338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358,
+              359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379,
+              380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400,
+              401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421,
+              422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442,
+              443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463,
+              464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484,
+              485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505,
+              506, 507, 508, 509, 510, 511, 512
+            ],
+            "dims": [16, 32],
+            "type": "float32"
+          },
+          {
+            "dims": [16, 2, 8],
+            "type": "uint8",
+            "data": [
+              1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+              30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
+              56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
+              82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
+              107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
+              128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148,
+              149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
+              170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190,
+              191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
+              212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232,
+              233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
+              254, 255, 256
+            ]
+          },
+          {
+            "dims": [32],
+            "type": "float32",
+            "data": [
+              0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+              29, 30, 31
+            ]
+          }
+        ],
+        "outputs": [
+          {
+            "dims": [16, 16],
+            "type": "float32",
+            "data": [
+              -1116, -4036, -5868, -6612, -6268, -4836, -2316, 1292, 5956, 11772, 18644, 26604, 35652, 45788, 57012,
+              53452, -2492, -12772, -19916, -23924, -24796, -22532, -17132, -8596, 5604, 17884, 35828, 56908, 81124,
+              108476, 138964, 140844, -3868, -21508, -33964, -41236, -43324, -40228, -31948, -18484, 5252, 23996, 53012,
+              87212, 126596, 171164, 220916, 228236, -5244, -30244, -48012, -58548, -61852, -57924, -46764, -28372,
+              4900, 30108, 70196, 117516, 172068, 233852, 302868, 315628, -6620, -38980, -62060, -75860, -80380, -75620,
+              -61580, -38260, 4548, 36220, 87380, 147820, 217540, 296540, 384820, 403020, -7996, -47716, -76108, -93172,
+              -98908, -93316, -76396, -48148, 4196, 42332, 104564, 178124, 263012, 359228, 466772, 490412, -9372,
+              -56452, -90156, -110484, -117436, -111012, -91212, -58036, 3844, 48444, 121748, 208428, 308484, 421916,
+              548724, 577804, -10748, -65188, -104204, -127796, -135964, -128708, -106028, -67924, 3492, 54556, 138932,
+              238732, 353956, 484604, 630676, 665196, -12124, -73924, -118252, -145108, -154492, -146404, -120844,
+              -77812, 3140, 60668, 156116, 269036, 399428, 547292, 712628, 752588, -13500, -82660, -132300, -162420,
+              -173020, -164100, -135660, -87700, 2788, 66780, 173300, 299340, 444900, 609980, 794580, 839980, -14876,
+              -91396, -146348, -179732, -191548, -181796, -150476, -97588, 2436, 72892, 190484, 329644, 490372, 672668,
+              876532, 927372, -16252, -100132, -160396, -197044, -210076, -199492, -165292, -107476, 2084, 79004,
+              207668, 359948, 535844, 735356, 958484, 1014764, -17628, -108868, -174444, -214356, -228604, -217188,
+              -180108, -117364, 1732, 85116, 224852, 390252, 581316, 798044, 1040436, 1102156, -19004, -117604, -188492,
+              -231668, -247132, -234884, -194924, -127252, 1380, 91228, 242036, 420556, 626788, 860732, 1122388,
+              1189548, -20380, -126340, -202540, -248980, -265660, -252580, -209740, -137140, 1028, 97340, 259220,
+              450860, 672260, 923420, 1204340, 1276940, -21756, -135076, -216588, -266292, -284188, -270276, -224556,
+              -147028, 676, 103452, 276404, 481164, 717732, 986108, 1286292, 1364332
+            ]
+          }
+        ]
+      }
+    ]
+  },
+  {
+    "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4",
+    "operator": "MatMulNBits",
+    "opset": { "domain": "com.microsoft", "version": 1 },
+    "attributes": [
+      { "name": "K", "data": 32, "type": "int" },
+      { "name": "N", "data": 16, "type": "int" },
+      { "name": "block_size", "data": 16, "type": "int" },
+      { "name": "bits", "data": 4, "type": "int" }
+    ],
+    "cases": [
+      {
+        "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4; asymmetric",
+        "inputs": [
+          {
+            "data": [
+              1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+              30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
+              56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
+              82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
+              107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
+              128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148,
+              149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
+              170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190,
+              191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
+              212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232,
+              233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
+              254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274,
+              275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295,
+              296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316,
+              317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337,
+              338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358,
+              359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379,
+              380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400,
+              401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421,
+              422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442,
+              443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463,
+              464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484,
+              485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505,
+              506, 507, 508, 509, 510, 511, 512
+            ],
+            "dims": [16, 32],
+            "type": "float32"
+          },
+          {
+            "dims": [16, 2, 8],
+            "type": "uint8",
+            "data": [
+              1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+              30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
+              56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
+              82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
+              107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
+              128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148,
+              149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
+              170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190,
+              191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
+              212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232,
+              233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
+              254, 255, 256
+            ]
+          },
+          {
+            "dims": [32],
+            "type": "float32",
+            "data": [
+              0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+              29, 30, 31
+            ]
+          },
+          {
+            "dims": [16],
+            "type": "uint8",
+            "data": [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128]
+          }
+        ],
+        "outputs": [
+          {
+            "dims": [16, 16],
+            "type": "float32",
+            "data": [
+              -1116, -1860, -1516, -84, 2436, 6044, 10740, 16524, 23364, 31356, 40404, 50540, 61764, 74076, 87476,
+              86092, -2492, -2404, 820, 7180, 16676, 29308, 45076, 63980, 88548, 111196, 139508, 170956, 205540, 243260,
+              284116, 296364, -3868, -2948, 3156, 14444, 30916, 52572, 79412, 111436, 153732, 191036, 238612, 291372,
+              349316, 412444, 480756, 506636, -5244, -3492, 5492, 21708, 45156, 75836, 113748, 158892, 218916, 270876,
+              337716, 411788, 493092, 581628, 677396, 716908, -6620, -4036, 7828, 28972, 59396, 99100, 148084, 206348,
+              284100, 350716, 436820, 532204, 636868, 750812, 874036, 927180, -7996, -4580, 10164, 36236, 73636, 122364,
+              182420, 253804, 349284, 430556, 535924, 652620, 780644, 919996, 1070676, 1137452, -9372, -5124, 12500,
+              43500, 87876, 145628, 216756, 301260, 414468, 510396, 635028, 773036, 924420, 1089180, 1267316, 1347724,
+              -10748, -5668, 14836, 50764, 102116, 168892, 251092, 348716, 479652, 590236, 734132, 893452, 1068196,
+              1258364, 1463956, 1557996, -12124, -6212, 17172, 58028, 116356, 192156, 285428, 396172, 544836, 670076,
+              833236, 1013868, 1211972, 1427548, 1660596, 1768268, -13500, -6756, 19508, 65292, 130596, 215420, 319764,
+              443628, 610020, 749916, 932340, 1134284, 1355748, 1596732, 1857236, 1978540, -14876, -7300, 21844, 72556,
+              144836, 238684, 354100, 491084, 675204, 829756, 1031444, 1254700, 1499524, 1765916, 2053876, 2188812,
+              -16252, -7844, 24180, 79820, 159076, 261948, 388436, 538540, 740388, 909596, 1130548, 1375116, 1643300,
+              1935100, 2250516, 2399084, -17628, -8388, 26516, 87084, 173316, 285212, 422772, 585996, 805572, 989436,
+              1229652, 1495532, 1787076, 2104284, 2447156, 2609356, -19004, -8932, 28852, 94348, 187556, 308476, 457108,
+              633452, 870756, 1069276, 1328756, 1615948, 1930852, 2273468, 2643796, 2819628, -20380, -9476, 31188,
+              101612, 201796, 331740, 491444, 680908, 935940, 1149116, 1427860, 1736364, 2074628, 2442652, 2840436,
+              3029900, -21756, -10020, 33524, 108876, 216036, 355004, 525780, 728364, 1001124, 1228956, 1526964,
+              1856780, 2218404, 2611836, 3037076, 3240172
+            ]
+          }
+        ]
+      }
+    ]
+  },
+  {
+    "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4",
+    "operator": "MatMulNBits",
+    "opset": { "domain": "com.microsoft", "version": 1 },
+    "attributes": [
+      { "name": "K", "data": 32, "type": "int" },
+      { "name": "N", "data": 32, "type": "int" },
+      { "name": "block_size", "data": 16, "type": "int" },
+      { "name": "bits", "data": 4, "type": "int" }
+    ],
+    "cases": [
+      {
+        "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4; symmetric",
+        "inputs": [
+          {
+            "data": [
+              1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+              30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
+              56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
+              82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
+              107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
+              128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148,
+              149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
+              170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190,
+              191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
+              212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232,
+              233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
+              254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274,
+              275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295,
+              296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316,
+              317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337,
+              338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358,
+              359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379,
+              380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400,
+              401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421,
+              422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442,
+              443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463,
+              464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484,
+              485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505,
+              506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526,
+              527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547,
+              548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568,
+              569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589,
+              590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610,
+              611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631,
+              632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652,
+              653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673,
+              674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694,
+              695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715,
+              716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736,
+              737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757,
+              758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778,
+              779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799,
+              800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820,
+              821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841,
+              842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862,
+              863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883,
+              884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904,
+              905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925,
+              926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946,
+              947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967,
+              968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988,
+              989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007,
+              1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024
+            ],
+            "dims": [32, 32],
+            "type": "float32"
+          },
+          {
+            "dims": [32, 2, 8],
+            "type": "uint8",
+            "data": [
+              1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+              30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
+              56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
+              82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
+              107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
+              128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148,
+              149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
+              170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190,
+              191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
+              212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232,
+              233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
+              254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274,
+              275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295,
+              296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316,
+              317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337,
+              338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358,
+              359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379,
+              380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400,
+              401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421,
+              422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442,
+              443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463,
+              464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484,
+              485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505,
+              506, 507, 508, 509, 510, 511, 512
+            ]
+          },
+          {
+            "dims": [64],
+            "type": "float32",
+            "data": [
+              0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+              29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
+              55, 56, 57, 58, 59, 60, 61, 62, 63
+            ]
+          }
+        ],
+        "outputs": [
+          {
+            "dims": [32, 32],
+            "type": "float32",
+            "data": [
+              -1116, -4036, -5868, -6612, -6268, -4836, -2316, 1292, 5956, 11772, 18644, 26604, 35652, 45788, 57012,
+              53452, -59740, -53956, -47084, -39124, -30076, -19940, -8716, 3596, 16996, 31484, 47060, 63724, 81476,
+              100316, 120244, 109004, -2492, -12772, -19916, -23924, -24796, -22532, -17132, -8596, 5604, 17884, 35828,
+              56908, 81124, 108476, 138964, 140844, -199356, -184548, -166604, -145524, -121308, -93956, -63468, -29844,
+              6916, 46812, 89844, 136012, 185316, 237756, 293332, 287532, -3868, -21508, -33964, -41236, -43324, -40228,
+              -31948, -18484, 5252, 23996, 53012, 87212, 126596, 171164, 220916, 228236, -338972, -315140, -286124,
+              -251924, -212540, -167972, -118220, -63284, -3164, 62140, 132628, 208300, 289156, 375196, 466420, 466060,
+              -5244, -30244, -48012, -58548, -61852, -57924, -46764, -28372, 4900, 30108, 70196, 117516, 172068, 233852,
+              302868, 315628, -478588, -445732, -405644, -358324, -303772, -241988, -172972, -96724, -13244, 77468,
+              175412, 280588, 392996, 512636, 639508, 644588, -6620, -38980, -62060, -75860, -80380, -75620, -61580,
+              -38260, 4548, 36220, 87380, 147820, 217540, 296540, 384820, 403020, -618204, -576324, -525164, -464724,
+              -395004, -316004, -227724, -130164, -23324, 92796, 218196, 352876, 496836, 650076, 812596, 823116, -7996,
+              -47716, -76108, -93172, -98908, -93316, -76396, -48148, 4196, 42332, 104564, 178124, 263012, 359228,
+              466772, 490412, -757820, -706916, -644684, -571124, -486236, -390020, -282476, -163604, -33404, 108124,
+              260980, 425164, 600676, 787516, 985684, 1001644, -9372, -56452, -90156, -110484, -117436, -111012, -91212,
+              -58036, 3844, 48444, 121748, 208428, 308484, 421916, 548724, 577804, -897436, -837508, -764204, -677524,
+              -577468, -464036, -337228, -197044, -43484, 123452, 303764, 497452, 704516, 924956, 1158772, 1180172,
+              -10748, -65188, -104204, -127796, -135964, -128708, -106028, -67924, 3492, 54556, 138932, 238732, 353956,
+              484604, 630676, 665196, -1037052, -968100, -883724, -783924, -668700, -538052, -391980, -230484, -53564,
+              138780, 346548, 569740, 808356, 1062396, 1331860, 1358700, -12124, -73924, -118252, -145108, -154492,
+              -146404, -120844, -77812, 3140, 60668, 156116, 269036, 399428, 547292, 712628, 752588, -1176668, -1098692,
+              -1003244, -890324, -759932, -612068, -446732, -263924, -63644, 154108, 389332, 642028, 912196, 1199836,
+              1504948, 1537228, -13500, -82660, -132300, -162420, -173020, -164100, -135660, -87700, 2788, 66780,
+              173300, 299340, 444900, 609980, 794580, 839980, -1316284, -1229284, -1122764, -996724, -851164, -686084,
+              -501484, -297364, -73724, 169436, 432116, 714316, 1016036, 1337276, 1678036, 1715756, -14876, -91396,
+              -146348, -179732, -191548, -181796, -150476, -97588, 2436, 72892, 190484, 329644, 490372, 672668, 876532,
+              927372, -1455900, -1359876, -1242284, -1103124, -942396, -760100, -556236, -330804, -83804, 184764,
+              474900, 786604, 1119876, 1474716, 1851124, 1894284, -16252, -100132, -160396, -197044, -210076, -199492,
+              -165292, -107476, 2084, 79004, 207668, 359948, 535844, 735356, 958484, 1014764, -1595516, -1490468,
+              -1361804, -1209524, -1033628, -834116, -610988, -364244, -93884, 200092, 517684, 858892, 1223716, 1612156,
+              2024212, 2072812, -17628, -108868, -174444, -214356, -228604, -217188, -180108, -117364, 1732, 85116,
+              224852, 390252, 581316, 798044, 1040436, 1102156, -1735132, -1621060, -1481324, -1315924, -1124860,
+              -908132, -665740, -397684, -103964, 215420, 560468, 931180, 1327556, 1749596, 2197300, 2251340, -19004,
+              -117604, -188492, -231668, -247132, -234884, -194924, -127252, 1380, 91228, 242036, 420556, 626788,
+              860732, 1122388, 1189548, -1874748, -1751652, -1600844, -1422324, -1216092, -982148, -720492, -431124,
+              -114044, 230748, 603252, 1003468, 1431396, 1887036, 2370388, 2429868, -20380, -126340, -202540, -248980,
+              -265660, -252580, -209740, -137140, 1028, 97340, 259220, 450860, 672260, 923420, 1204340, 1276940,
+              -2014364, -1882244, -1720364, -1528724, -1307324, -1056164, -775244, -464564, -124124, 246076, 646036,
+              1075756, 1535236, 2024476, 2543476, 2608396, -21756, -135076, -216588, -266292, -284188, -270276, -224556,
+              -147028, 676, 103452, 276404, 481164, 717732, 986108, 1286292, 1364332, -2153980, -2012836, -1839884,
+              -1635124, -1398556, -1130180, -829996, -498004, -134204, 261404, 688820, 1148044, 1639076, 2161916,
+              2716564, 2786924, -23132, -143812, -230636, -283604, -302716, -287972, -239372, -156916, 324, 109564,
+              293588, 511468, 763204, 1048796, 1368244, 1451724, -2293596, -2143428, -1959404, -1741524, -1489788,
+              -1204196, -884748, -531444, -144284, 276732, 731604, 1220332, 1742916, 2299356, 2889652, 2965452, -24508,
+              -152548, -244684, -300916, -321244, -305668, -254188, -166804, -28, 115676, 310772, 541772, 808676,
+              1111484, 1450196, 1539116, -2433212, -2274020, -2078924, -1847924, -1581020, -1278212, -939500, -564884,
+              -154364, 292060, 774388, 1292620, 1846756, 2436796, 3062740, 3143980, -25884, -161284, -258732, -318228,
+              -339772, -323364, -269004, -176692, -380, 121788, 327956, 572076, 854148, 1174172, 1532148, 1626508,
+              -2572828, -2404612, -2198444, -1954324, -1672252, -1352228, -994252, -598324, -164444, 307388, 817172,
+              1364908, 1950596, 2574236, 3235828, 3322508, -27260, -170020, -272780, -335540, -358300, -341060, -283820,
+              -186580, -732, 127900, 345140, 602380, 899620, 1236860, 1614100, 1713900, -2712444, -2535204, -2317964,
+              -2060724, -1763484, -1426244, -1049004, -631764, -174524, 322716, 859956, 1437196, 2054436, 2711676,
+              3408916, 3501036, -28636, -178756, -286828, -352852, -376828, -358756, -298636, -196468, -1084, 134012,
+              362324, 632684, 945092, 1299548, 1696052, 1801292, -2852060, -2665796, -2437484, -2167124, -1854716,
+              -1500260, -1103756, -665204, -184604, 338044, 902740, 1509484, 2158276, 2849116, 3582004, 3679564, -30012,
+              -187492, -300876, -370164, -395356, -376452, -313452, -206356, -1436, 140124, 379508, 662988, 990564,
+              1362236, 1778004, 1888684, -2991676, -2796388, -2557004, -2273524, -1945948, -1574276, -1158508, -698644,
+              -194684, 353372, 945524, 1581772, 2262116, 2986556, 3755092, 3858092, -31388, -196228, -314924, -387476,
+              -413884, -394148, -328268, -216244, -1788, 146236, 396692, 693292, 1036036, 1424924, 1859956, 1976076,
+              -3131292, -2926980, -2676524, -2379924, -2037180, -1648292, -1213260, -732084, -204764, 368700, 988308,
+              1654060, 2365956, 3123996, 3928180, 4036620, -32764, -204964, -328972, -404788, -432412, -411844, -343084,
+              -226132, -2140, 152348, 413876, 723596, 1081508, 1487612, 1941908, 2063468, -3270908, -3057572, -2796044,
+              -2486324, -2128412, -1722308, -1268012, -765524, -214844, 384028, 1031092, 1726348, 2469796, 3261436,
+              4101268, 4215148, -34140, -213700, -343020, -422100, -450940, -429540, -357900, -236020, -2492, 158460,
+              431060, 753900, 1126980, 1550300, 2023860, 2150860, -3410524, -3188164, -2915564, -2592724, -2219644,
+              -1796324, -1322764, -798964, -224924, 399356, 1073876, 1798636, 2573636, 3398876, 4274356, 4393676,
+              -35516, -222436, -357068, -439412, -469468, -447236, -372716, -245908, -2844, 164572, 448244, 784204,
+              1172452, 1612988, 2105812, 2238252, -3550140, -3318756, -3035084, -2699124, -2310876, -1870340, -1377516,
+              -832404, -235004, 414684, 1116660, 1870924, 2677476, 3536316, 4447444, 4572204, -36892, -231172, -371116,
+              -456724, -487996, -464932, -387532, -255796, -3196, 170684, 465428, 814508, 1217924, 1675676, 2187764,
+              2325644, -3689756, -3449348, -3154604, -2805524, -2402108, -1944356, -1432268, -865844, -245084, 430012,
+              1159444, 1943212, 2781316, 3673756, 4620532, 4750732, -38268, -239908, -385164, -474036, -506524, -482628,
+              -402348, -265684, -3548, 176796, 482612, 844812, 1263396, 1738364, 2269716, 2413036, -3829372, -3579940,
+              -3274124, -2911924, -2493340, -2018372, -1487020, -899284, -255164, 445340, 1202228, 2015500, 2885156,
+              3811196, 4793620, 4929260, -39644, -248644, -399212, -491348, -525052, -500324, -417164, -275572, -3900,
+              182908, 499796, 875116, 1308868, 1801052, 2351668, 2500428, -3968988, -3710532, -3393644, -3018324,
+              -2584572, -2092388, -1541772, -932724, -265244, 460668, 1245012, 2087788, 2988996, 3948636, 4966708,
+              5107788, -41020, -257380, -413260, -508660, -543580, -518020, -431980, -285460, -4252, 189020, 516980,
+              905420, 1354340, 1863740, 2433620, 2587820, -4108604, -3841124, -3513164, -3124724, -2675804, -2166404,
+              -1596524, -966164, -275324, 475996, 1287796, 2160076, 3092836, 4086076, 5139796, 5286316, -42396, -266116,
+              -427308, -525972, -562108, -535716, -446796, -295348, -4604, 195132, 534164, 935724, 1399812, 1926428,
+              2515572, 2675212, -4248220, -3971716, -3632684, -3231124, -2767036, -2240420, -1651276, -999604, -285404,
+              491324, 1330580, 2232364, 3196676, 4223516, 5312884, 5464844, -43772, -274852, -441356, -543284, -580636,
+              -553412, -461612, -305236, -4956, 201244, 551348, 966028, 1445284, 1989116, 2597524, 2762604, -4387836,
+              -4102308, -3752204, -3337524, -2858268, -2314436, -1706028, -1033044, -295484, 506652, 1373364, 2304652,
+              3300516, 4360956, 5485972, 5643372
+            ]
+          }
+        ]
+      }
+    ]
+  },
+  {
+    "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4",
+    "operator": "MatMulNBits",
+    "opset": { "domain": "com.microsoft", "version": 1 },
+    "attributes": [
+      { "name": "K", "data": 32, "type": "int" },
+      { "name": "N", "data": 32, "type": "int" },
+      { "name": "block_size", "data": 16, "type": "int" },
+      { "name": "bits", "data": 4, "type": "int" }
+    ],
+    "cases": [
+      {
+        "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4; asymmetric",
+        "inputs": [
+          {
+            "data": [
+              1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+              30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
+              56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
+              82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
+              107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
+              128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148,
+              149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
+              170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190,
+              191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
+              212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232,
+              233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
+              254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274,
+              275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295,
+              296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316,
+              317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337,
+              338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358,
+              359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379,
+              380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400,
+              401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421,
+              422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442,
+              443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463,
+              464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484,
+              485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505,
+              506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526,
+              527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547,
+              548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568,
+              569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589,
+              590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610,
+              611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631,
+              632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652,
+              653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673,
+              674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694,
+              695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715,
+              716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736,
+              737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757,
+              758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778,
+              779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799,
+              800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820,
+              821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841,
+              842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862,
+              863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883,
+              884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904,
+              905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925,
+              926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946,
+              947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967,
+              968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988,
+              989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007,
+              1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024
+            ],
+            "dims": [32, 32],
+            "type": "float32"
+          },
+          {
+            "dims": [32, 2, 8],
+            "type": "uint8",
+            "data": [
+              1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+              30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
+              56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
+              82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
+              107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
+              128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148,
+              149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
+              170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190,
+              191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
+              212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232,
+              233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
+              254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274,
+              275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295,
+              296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316,
+              317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337,
+              338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358,
+              359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379,
+              380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400,
+              401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421,
+              422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442,
+              443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463,
+              464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484,
+              485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505,
+              506, 507, 508, 509, 510, 511, 512
+            ]
+          },
+          {
+            "dims": [64],
+            "type": "float32",
+            "data": [
+              0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+              29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
+              55, 56, 57, 58, 59, 60, 61, 62, 63
+            ]
+          },
+          {
+            "dims": [32],
+            "type": "uint8",
+            "data": [
+              128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
+              128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128
+            ]
+          }
+        ],
+        "outputs": [
+          {
+            "dims": [32, 32],
+            "type": "float32",
+            "data": [
+              -1116, -1860, -1516, -84, 2436, 6044, 10740, 16524, 23364, 31356, 40404, 50540, 61764, 74076, 87476,
+              86092, -24924, -16964, -7916, 2220, 13444, 25756, 39156, 53644, 69220, 85884, 103636, 122476, 142404,
+              163420, 185524, 176460, -2492, -2404, 820, 7180, 16676, 29308, 45076, 63980, 88548, 111196, 139508,
+              170956, 205540, 243260, 284116, 296364, -33468, -8292, 20020, 51468, 86052, 123772, 164628, 208620,
+              255748, 306012, 359412, 415948, 475620, 538428, 604372, 608940, -3868, -2948, 3156, 14444, 30916, 52572,
+              79412, 111436, 153732, 191036, 238612, 291372, 349316, 412444, 480756, 506636, -42012, 380, 47956, 100716,
+              158660, 221788, 290100, 363596, 442276, 526140, 615188, 709420, 808836, 913436, 1023220, 1041420, -5244,
+              -3492, 5492, 21708, 45156, 75836, 113748, 158892, 218916, 270876, 337716, 411788, 493092, 581628, 677396,
+              716908, -50556, 9052, 75892, 149964, 231268, 319804, 415572, 518572, 628804, 746268, 870964, 1002892,
+              1142052, 1288444, 1442068, 1473900, -6620, -4036, 7828, 28972, 59396, 99100, 148084, 206348, 284100,
+              350716, 436820, 532204, 636868, 750812, 874036, 927180, -59100, 17724, 103828, 199212, 303876, 417820,
+              541044, 673548, 815332, 966396, 1126740, 1296364, 1475268, 1663452, 1860916, 1906380, -7996, -4580, 10164,
+              36236, 73636, 122364, 182420, 253804, 349284, 430556, 535924, 652620, 780644, 919996, 1070676, 1137452,
+              -67644, 26396, 131764, 248460, 376484, 515836, 666516, 828524, 1001860, 1186524, 1382516, 1589836,
+              1808484, 2038460, 2279764, 2338860, -9372, -5124, 12500, 43500, 87876, 145628, 216756, 301260, 414468,
+              510396, 635028, 773036, 924420, 1089180, 1267316, 1347724, -76188, 35068, 159700, 297708, 449092, 613852,
+              791988, 983500, 1188388, 1406652, 1638292, 1883308, 2141700, 2413468, 2698612, 2771340, -10748, -5668,
+              14836, 50764, 102116, 168892, 251092, 348716, 479652, 590236, 734132, 893452, 1068196, 1258364, 1463956,
+              1557996, -84732, 43740, 187636, 346956, 521700, 711868, 917460, 1138476, 1374916, 1626780, 1894068,
+              2176780, 2474916, 2788476, 3117460, 3203820, -12124, -6212, 17172, 58028, 116356, 192156, 285428, 396172,
+              544836, 670076, 833236, 1013868, 1211972, 1427548, 1660596, 1768268, -93276, 52412, 215572, 396204,
+              594308, 809884, 1042932, 1293452, 1561444, 1846908, 2149844, 2470252, 2808132, 3163484, 3536308, 3636300,
+              -13500, -6756, 19508, 65292, 130596, 215420, 319764, 443628, 610020, 749916, 932340, 1134284, 1355748,
+              1596732, 1857236, 1978540, -101820, 61084, 243508, 445452, 666916, 907900, 1168404, 1448428, 1747972,
+              2067036, 2405620, 2763724, 3141348, 3538492, 3955156, 4068780, -14876, -7300, 21844, 72556, 144836,
+              238684, 354100, 491084, 675204, 829756, 1031444, 1254700, 1499524, 1765916, 2053876, 2188812, -110364,
+              69756, 271444, 494700, 739524, 1005916, 1293876, 1603404, 1934500, 2287164, 2661396, 3057196, 3474564,
+              3913500, 4374004, 4501260, -16252, -7844, 24180, 79820, 159076, 261948, 388436, 538540, 740388, 909596,
+              1130548, 1375116, 1643300, 1935100, 2250516, 2399084, -118908, 78428, 299380, 543948, 812132, 1103932,
+              1419348, 1758380, 2121028, 2507292, 2917172, 3350668, 3807780, 4288508, 4792852, 4933740, -17628, -8388,
+              26516, 87084, 173316, 285212, 422772, 585996, 805572, 989436, 1229652, 1495532, 1787076, 2104284, 2447156,
+              2609356, -127452, 87100, 327316, 593196, 884740, 1201948, 1544820, 1913356, 2307556, 2727420, 3172948,
+              3644140, 4140996, 4663516, 5211700, 5366220, -19004, -8932, 28852, 94348, 187556, 308476, 457108, 633452,
+              870756, 1069276, 1328756, 1615948, 1930852, 2273468, 2643796, 2819628, -135996, 95772, 355252, 642444,
+              957348, 1299964, 1670292, 2068332, 2494084, 2947548, 3428724, 3937612, 4474212, 5038524, 5630548, 5798700,
+              -20380, -9476, 31188, 101612, 201796, 331740, 491444, 680908, 935940, 1149116, 1427860, 1736364, 2074628,
+              2442652, 2840436, 3029900, -144540, 104444, 383188, 691692, 1029956, 1397980, 1795764, 2223308, 2680612,
+              3167676, 3684500, 4231084, 4807428, 5413532, 6049396, 6231180, -21756, -10020, 33524, 108876, 216036,
+              355004, 525780, 728364, 1001124, 1228956, 1526964, 1856780, 2218404, 2611836, 3037076, 3240172, -153084,
+              113116, 411124, 740940, 1102564, 1495996, 1921236, 2378284, 2867140, 3387804, 3940276, 4524556, 5140644,
+              5788540, 6468244, 6663660, -23132, -10564, 35860, 116140, 230276, 378268, 560116, 775820, 1066308,
+              1308796, 1626068, 1977196, 2362180, 2781020, 3233716, 3450444, -161628, 121788, 439060, 790188, 1175172,
+              1594012, 2046708, 2533260, 3053668, 3607932, 4196052, 4818028, 5473860, 6163548, 6887092, 7096140, -24508,
+              -11108, 38196, 123404, 244516, 401532, 594452, 823276, 1131492, 1388636, 1725172, 2097612, 2505956,
+              2950204, 3430356, 3660716, -170172, 130460, 466996, 839436, 1247780, 1692028, 2172180, 2688236, 3240196,
+              3828060, 4451828, 5111500, 5807076, 6538556, 7305940, 7528620, -25884, -11652, 40532, 130668, 258756,
+              424796, 628788, 870732, 1196676, 1468476, 1824276, 2218028, 2649732, 3119388, 3626996, 3870988, -178716,
+              139132, 494932, 888684, 1320388, 1790044, 2297652, 2843212, 3426724, 4048188, 4707604, 5404972, 6140292,
+              6913564, 7724788, 7961100, -27260, -12196, 42868, 137932, 272996, 448060, 663124, 918188, 1261860,
+              1548316, 1923380, 2338444, 2793508, 3288572, 3823636, 4081260, -187260, 147804, 522868, 937932, 1392996,
+              1888060, 2423124, 2998188, 3613252, 4268316, 4963380, 5698444, 6473508, 7288572, 8143636, 8393580, -28636,
+              -12740, 45204, 145196, 287236, 471324, 697460, 965644, 1327044, 1628156, 2022484, 2458860, 2937284,
+              3457756, 4020276, 4291532, -195804, 156476, 550804, 987180, 1465604, 1986076, 2548596, 3153164, 3799780,
+              4488444, 5219156, 5991916, 6806724, 7663580, 8562484, 8826060, -30012, -13284, 47540, 152460, 301476,
+              494588, 731796, 1013100, 1392228, 1707996, 2121588, 2579276, 3081060, 3626940, 4216916, 4501804, -204348,
+              165148, 578740, 1036428, 1538212, 2084092, 2674068, 3308140, 3986308, 4708572, 5474932, 6285388, 7139940,
+              8038588, 8981332, 9258540, -31388, -13828, 49876, 159724, 315716, 517852, 766132, 1060556, 1457412,
+              1787836, 2220692, 2699692, 3224836, 3796124, 4413556, 4712076, -212892, 173820, 606676, 1085676, 1610820,
+              2182108, 2799540, 3463116, 4172836, 4928700, 5730708, 6578860, 7473156, 8413596, 9400180, 9691020, -32764,
+              -14372, 52212, 166988, 329956, 541116, 800468, 1108012, 1522596, 1867676, 2319796, 2820108, 3368612,
+              3965308, 4610196, 4922348, -221436, 182492, 634612, 1134924, 1683428, 2280124, 2925012, 3618092, 4359364,
+              5148828, 5986484, 6872332, 7806372, 8788604, 9819028, 10123500, -34140, -14916, 54548, 174252, 344196,
+              564380, 834804, 1155468, 1587780, 1947516, 2418900, 2940524, 3512388, 4134492, 4806836, 5132620, -229980,
+              191164, 662548, 1184172, 1756036, 2378140, 3050484, 3773068, 4545892, 5368956, 6242260, 7165804, 8139588,
+              9163612, 10237876, 10555980, -35516, -15460, 56884, 181516, 358436, 587644, 869140, 1202924, 1652964,
+              2027356, 2518004, 3060940, 3656164, 4303676, 5003476, 5342892, -238524, 199836, 690484, 1233420, 1828644,
+              2476156, 3175956, 3928044, 4732420, 5589084, 6498036, 7459276, 8472804, 9538620, 10656724, 10988460,
+              -36892, -16004, 59220, 188780, 372676, 610908, 903476, 1250380, 1718148, 2107196, 2617108, 3181356,
+              3799940, 4472860, 5200116, 5553164, -247068, 208508, 718420, 1282668, 1901252, 2574172, 3301428, 4083020,
+              4918948, 5809212, 6753812, 7752748, 8806020, 9913628, 11075572, 11420940, -38268, -16548, 61556, 196044,
+              386916, 634172, 937812, 1297836, 1783332, 2187036, 2716212, 3301772, 3943716, 4642044, 5396756, 5763436,
+              -255612, 217180, 746356, 1331916, 1973860, 2672188, 3426900, 4237996, 5105476, 6029340, 7009588, 8046220,
+              9139236, 10288636, 11494420, 11853420, -39644, -17092, 63892, 203308, 401156, 657436, 972148, 1345292,
+              1848516, 2266876, 2815316, 3422188, 4087492, 4811228, 5593396, 5973708, -264156, 225852, 774292, 1381164,
+              2046468, 2770204, 3552372, 4392972, 5292004, 6249468, 7265364, 8339692, 9472452, 10663644, 11913268,
+              12285900, -41020, -17636, 66228, 210572, 415396, 680700, 1006484, 1392748, 1913700, 2346716, 2914420,
+              3542604, 4231268, 4980412, 5790036, 6183980, -272700, 234524, 802228, 1430412, 2119076, 2868220, 3677844,
+              4547948, 5478532, 6469596, 7521140, 8633164, 9805668, 11038652, 12332116, 12718380, -42396, -18180, 68564,
+              217836, 429636, 703964, 1040820, 1440204, 1978884, 2426556, 3013524, 3663020, 4375044, 5149596, 5986676,
+              6394252, -281244, 243196, 830164, 1479660, 2191684, 2966236, 3803316, 4702924, 5665060, 6689724, 7776916,
+              8926636, 10138884, 11413660, 12750964, 13150860, -43772, -18724, 70900, 225100, 443876, 727228, 1075156,
+              1487660, 2044068, 2506396, 3112628, 3783436, 4518820, 5318780, 6183316, 6604524, -289788, 251868, 858100,
+              1528908, 2264292, 3064252, 3928788, 4857900, 5851588, 6909852, 8032692, 9220108, 10472100, 11788668,
+              13169812, 13583340
+            ]
+          }
+        ]
+      }
+    ]
+  },
+  {
+    "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4",
+    "operator": "MatMulNBits",
+    "opset": { "domain": "com.microsoft", "version": 1 },
+    "attributes": [
+      { "name": "K", "data": 32, "type": "int" },
+      { "name": "N", "data": 32, "type": "int" },
+      { "name": "block_size", "data": 32, "type": "int" },
+      { "name": "bits", "data": 4, "type": "int" }
+    ],
+    "cases": [
+      {
+        "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4; symmetric",
+        "inputs": [
+          {
+            "data": [
+              1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+              30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
+              56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
+              82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
+              107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
+              128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148,
+              149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
+              170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190,
+              191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
+              212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232,
+              233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
+              254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274,
+              275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295,
+              296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316,
+              317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337,
+              338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358,
+              359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379,
+              380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400,
+              401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421,
+              422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442,
+              443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463,
+              464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484,
+              485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505,
+              506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526,
+              527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547,
+              548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568,
+              569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589,
+              590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610,
+              611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631,
+              632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652,
+              653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673,
+              674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694,
+              695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715,
+              716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736,
+              737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757,
+              758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778,
+              779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799,
+              800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820,
+              821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841,
+              842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862,
+              863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883,
+              884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904,
+              905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925,
+              926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946,
+              947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967,
+              968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988,
+              989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007,
+              1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024
+            ],
+            "dims": [32, 32],
+            "type": "float32"
+          },
+          {
+            "dims": [32, 1, 16],
+            "type": "uint8",
+            "data": [
+              1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+              30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
+              56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
+              82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
+              107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
+              128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148,
+              149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
+              170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190,
+              191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
+              212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232,
+              233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
+              254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274,
+              275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295,
+              296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316,
+              317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337,
+              338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358,
+              359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379,
+              380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400,
+              401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421,
+              422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442,
+              443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463,
+              464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484,
+              485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505,
+              506, 507, 508, 509, 510, 511, 512
+            ]
+          },
+          {
+            "dims": [32],
+            "type": "float32",
+            "data": [
+              0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+              29, 30, 31
+            ]
+          }
+        ],
+        "outputs": [
+          {
+            "dims": [32, 32],
+            "type": "float32",
+            "data": [
+              0, -1560, -2576, -3048, -2976, -2360, -1200, 504, 2736, 5544, 8880, 12760, 17184, 22152, 27664, 26040,
+              -29312, -26520, -23184, -19304, -14880, -9912, -4400, 1656, 8256, 15400, 23088, 31320, 40096, 49416,
+              59280, 53816, 0, -5368, -9168, -11400, -12064, -11160, -8688, -4648, 2224, 8136, 16880, 27192, 39072,
+              52520, 67536, 68760, -98432, -91256, -82512, -72200, -60320, -46872, -31856, -15272, 2880, 22600, 43888,
+              66744, 91168, 117160, 144720, 142104, 0, -9176, -15760, -19752, -21152, -19960, -16176, -9800, 1712,
+              10728, 24880, 41624, 60960, 82888, 107408, 111480, -167552, -155992, -141840, -125096, -105760, -83832,
+              -59312, -32200, -2496, 29800, 64688, 102168, 142240, 184904, 230160, 230392, 0, -12984, -22352, -28104,
+              -30240, -28760, -23664, -14952, 1200, 13320, 32880, 56056, 82848, 113256, 147280, 154200, -236672,
+              -220728, -201168, -177992, -151200, -120792, -86768, -49128, -7872, 37000, 85488, 137592, 193312, 252648,
+              315600, 318680, 0, -16792, -28944, -36456, -39328, -37560, -31152, -20104, 688, 15912, 40880, 70488,
+              104736, 143624, 187152, 196920, -305792, -285464, -260496, -230888, -196640, -157752, -114224, -66056,
+              -13248, 44200, 106288, 173016, 244384, 320392, 401040, 406968, 0, -20600, -35536, -44808, -48416, -46360,
+              -38640, -25256, 176, 18504, 48880, 84920, 126624, 173992, 227024, 239640, -374912, -350200, -319824,
+              -283784, -242080, -194712, -141680, -82984, -18624, 51400, 127088, 208440, 295456, 388136, 486480, 495256,
+              0, -24408, -42128, -53160, -57504, -55160, -46128, -30408, -336, 21096, 56880, 99352, 148512, 204360,
+              266896, 282360, -444032, -414936, -379152, -336680, -287520, -231672, -169136, -99912, -24000, 58600,
+              147888, 243864, 346528, 455880, 571920, 583544, 0, -28216, -48720, -61512, -66592, -63960, -53616, -35560,
+              -848, 23688, 64880, 113784, 170400, 234728, 306768, 325080, -513152, -479672, -438480, -389576, -332960,
+              -268632, -196592, -116840, -29376, 65800, 168688, 279288, 397600, 523624, 657360, 671832, 0, -32024,
+              -55312, -69864, -75680, -72760, -61104, -40712, -1360, 26280, 72880, 128216, 192288, 265096, 346640,
+              367800, -582272, -544408, -497808, -442472, -378400, -305592, -224048, -133768, -34752, 73000, 189488,
+              314712, 448672, 591368, 742800, 760120, 0, -35832, -61904, -78216, -84768, -81560, -68592, -45864, -1872,
+              28872, 80880, 142648, 214176, 295464, 386512, 410520, -651392, -609144, -557136, -495368, -423840,
+              -342552, -251504, -150696, -40128, 80200, 210288, 350136, 499744, 659112, 828240, 848408, 0, -39640,
+              -68496, -86568, -93856, -90360, -76080, -51016, -2384, 31464, 88880, 157080, 236064, 325832, 426384,
+              453240, -720512, -673880, -616464, -548264, -469280, -379512, -278960, -167624, -45504, 87400, 231088,
+              385560, 550816, 726856, 913680, 936696, 0, -43448, -75088, -94920, -102944, -99160, -83568, -56168, -2896,
+              34056, 96880, 171512, 257952, 356200, 466256, 495960, -789632, -738616, -675792, -601160, -514720,
+              -416472, -306416, -184552, -50880, 94600, 251888, 420984, 601888, 794600, 999120, 1024984, 0, -47256,
+              -81680, -103272, -112032, -107960, -91056, -61320, -3408, 36648, 104880, 185944, 279840, 386568, 506128,
+              538680, -858752, -803352, -735120, -654056, -560160, -453432, -333872, -201480, -56256, 101800, 272688,
+              456408, 652960, 862344, 1084560, 1113272, 0, -51064, -88272, -111624, -121120, -116760, -98544, -66472,
+              -3920, 39240, 112880, 200376, 301728, 416936, 546000, 581400, -927872, -868088, -794448, -706952, -605600,
+              -490392, -361328, -218408, -61632, 109000, 293488, 491832, 704032, 930088, 1170000, 1201560, 0, -54872,
+              -94864, -119976, -130208, -125560, -106032, -71624, -4432, 41832, 120880, 214808, 323616, 447304, 585872,
+              624120, -996992, -932824, -853776, -759848, -651040, -527352, -388784, -235336, -67008, 116200, 314288,
+              527256, 755104, 997832, 1255440, 1289848, 0, -58680, -101456, -128328, -139296, -134360, -113520, -76776,
+              -4944, 44424, 128880, 229240, 345504, 477672, 625744, 666840, -1066112, -997560, -913104, -812744,
+              -696480, -564312, -416240, -252264, -72384, 123400, 335088, 562680, 806176, 1065576, 1340880, 1378136, 0,
+              -62488, -108048, -136680, -148384, -143160, -121008, -81928, -5456, 47016, 136880, 243672, 367392, 508040,
+              665616, 709560, -1135232, -1062296, -972432, -865640, -741920, -601272, -443696, -269192, -77760, 130600,
+              355888, 598104, 857248, 1133320, 1426320, 1466424, 0, -66296, -114640, -145032, -157472, -151960, -128496,
+              -87080, -5968, 49608, 144880, 258104, 389280, 538408, 705488, 752280, -1204352, -1127032, -1031760,
+              -918536, -787360, -638232, -471152, -286120, -83136, 137800, 376688, 633528, 908320, 1201064, 1511760,
+              1554712, 0, -70104, -121232, -153384, -166560, -160760, -135984, -92232, -6480, 52200, 152880, 272536,
+              411168, 568776, 745360, 795000, -1273472, -1191768, -1091088, -971432, -832800, -675192, -498608, -303048,
+              -88512, 145000, 397488, 668952, 959392, 1268808, 1597200, 1643000, 0, -73912, -127824, -161736, -175648,
+              -169560, -143472, -97384, -6992, 54792, 160880, 286968, 433056, 599144, 785232, 837720, -1342592,
+              -1256504, -1150416, -1024328, -878240, -712152, -526064, -319976, -93888, 152200, 418288, 704376, 1010464,
+              1336552, 1682640, 1731288, 0, -77720, -134416, -170088, -184736, -178360, -150960, -102536, -7504, 57384,
+              168880, 301400, 454944, 629512, 825104, 880440, -1411712, -1321240, -1209744, -1077224, -923680, -749112,
+              -553520, -336904, -99264, 159400, 439088, 739800, 1061536, 1404296, 1768080, 1819576, 0, -81528, -141008,
+              -178440, -193824, -187160, -158448, -107688, -8016, 59976, 176880, 315832, 476832, 659880, 864976, 923160,
+              -1480832, -1385976, -1269072, -1130120, -969120, -786072, -580976, -353832, -104640, 166600, 459888,
+              775224, 1112608, 1472040, 1853520, 1907864, 0, -85336, -147600, -186792, -202912, -195960, -165936,
+              -112840, -8528, 62568, 184880, 330264, 498720, 690248, 904848, 965880, -1549952, -1450712, -1328400,
+              -1183016, -1014560, -823032, -608432, -370760, -110016, 173800, 480688, 810648, 1163680, 1539784, 1938960,
+              1996152, 0, -89144, -154192, -195144, -212000, -204760, -173424, -117992, -9040, 65160, 192880, 344696,
+              520608, 720616, 944720, 1008600, -1619072, -1515448, -1387728, -1235912, -1060000, -859992, -635888,
+              -387688, -115392, 181000, 501488, 846072, 1214752, 1607528, 2024400, 2084440, 0, -92952, -160784, -203496,
+              -221088, -213560, -180912, -123144, -9552, 67752, 200880, 359128, 542496, 750984, 984592, 1051320,
+              -1688192, -1580184, -1447056, -1288808, -1105440, -896952, -663344, -404616, -120768, 188200, 522288,
+              881496, 1265824, 1675272, 2109840, 2172728, 0, -96760, -167376, -211848, -230176, -222360, -188400,
+              -128296, -10064, 70344, 208880, 373560, 564384, 781352, 1024464, 1094040, -1757312, -1644920, -1506384,
+              -1341704, -1150880, -933912, -690800, -421544, -126144, 195400, 543088, 916920, 1316896, 1743016, 2195280,
+              2261016, 0, -100568, -173968, -220200, -239264, -231160, -195888, -133448, -10576, 72936, 216880, 387992,
+              586272, 811720, 1064336, 1136760, -1826432, -1709656, -1565712, -1394600, -1196320, -970872, -718256,
+              -438472, -131520, 202600, 563888, 952344, 1367968, 1810760, 2280720, 2349304, 0, -104376, -180560,
+              -228552, -248352, -239960, -203376, -138600, -11088, 75528, 224880, 402424, 608160, 842088, 1104208,
+              1179480, -1895552, -1774392, -1625040, -1447496, -1241760, -1007832, -745712, -455400, -136896, 209800,
+              584688, 987768, 1419040, 1878504, 2366160, 2437592, 0, -108184, -187152, -236904, -257440, -248760,
+              -210864, -143752, -11600, 78120, 232880, 416856, 630048, 872456, 1144080, 1222200, -1964672, -1839128,
+              -1684368, -1500392, -1287200, -1044792, -773168, -472328, -142272, 217000, 605488, 1023192, 1470112,
+              1946248, 2451600, 2525880, 0, -111992, -193744, -245256, -266528, -257560, -218352, -148904, -12112,
+              80712, 240880, 431288, 651936, 902824, 1183952, 1264920, -2033792, -1903864, -1743696, -1553288, -1332640,
+              -1081752, -800624, -489256, -147648, 224200, 626288, 1058616, 1521184, 2013992, 2537040, 2614168, 0,
+              -115800, -200336, -253608, -275616, -266360, -225840, -154056, -12624, 83304, 248880, 445720, 673824,
+              933192, 1223824, 1307640, -2102912, -1968600, -1803024, -1606184, -1378080, -1118712, -828080, -506184,
+              -153024, 231400, 647088, 1094040, 1572256, 2081736, 2622480, 2702456, 0, -119608, -206928, -261960,
+              -284704, -275160, -233328, -159208, -13136, 85896, 256880, 460152, 695712, 963560, 1263696, 1350360,
+              -2172032, -2033336, -1862352, -1659080, -1423520, -1155672, -855536, -523112, -158400, 238600, 667888,
+              1129464, 1623328, 2149480, 2707920, 2790744
+            ]
+          }
+        ]
+      }
+    ]
+  },
+  {
+    "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4",
+    "operator": "MatMulNBits",
+    "opset": { "domain": "com.microsoft", "version": 1 },
+    "attributes": [
+      { "name": "K", "data": 32, "type": "int" },
+      { "name": "N", "data": 32, "type": "int" },
+      { "name": "block_size", "data": 32, "type": "int" },
+      { "name": "bits", "data": 4, "type": "int" }
+    ],
+    "cases": [
+      {
+        "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4; asymmetric",
+        "inputs": [
+          {
+            "data": [
+              1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+              30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
+              56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
+              82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
+              107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
+              128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148,
+              149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
+              170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190,
+              191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
+              212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232,
+              233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
+              254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274,
+              275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295,
+              296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316,
+              317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337,
+              338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358,
+              359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379,
+              380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400,
+              401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421,
+              422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442,
+              443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463,
+              464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484,
+              485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505,
+              506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526,
+              527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547,
+              548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568,
+              569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589,
+              590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610,
+              611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631,
+              632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652,
+              653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673,
+              674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694,
+              695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715,
+              716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736,
+              737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757,
+              758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778,
+              779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799,
+              800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820,
+              821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841,
+              842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862,
+              863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883,
+              884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904,
+              905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925,
+              926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946,
+              947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967,
+              968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988,
+              989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007,
+              1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024
+            ],
+            "dims": [32, 32],
+            "type": "float32"
+          },
+          {
+            "dims": [32, 1, 16],
+            "type": "uint8",
+            "data": [
+              1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+              30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
+              56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
+              82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
+              107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
+              128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148,
+              149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
+              170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190,
+              191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
+              212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232,
+              233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
+              254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274,
+              275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295,
+              296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316,
+              317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337,
+              338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358,
+              359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379,
+              380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400,
+              401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421,
+              422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442,
+              443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463,
+              464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484,
+              485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505,
+              506, 507, 508, 509, 510, 511, 512
+            ]
+          },
+          {
+            "dims": [32],
+            "type": "float32",
+            "data": [
+              0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+              29, 30, 31
+            ]
+          },
+          {
+            "dims": [32],
+            "type": "uint8",
+            "data": [
+              128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
+              128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128
+            ]
+          }
+        ],
+        "outputs": [
+          {
+            "dims": [32, 32],
+            "type": "float32",
+            "data": [
+              0, 2664, 5872, 9624, 13920, 18760, 24144, 30072, 36528, 43560, 51120, 59224, 67872, 77064, 86800, 89400,
+              38272, 45288, 52848, 60952, 69600, 78792, 88528, 98808, 109632, 121000, 132912, 145368, 158368, 171912,
+              186000, 184760, 0, 7048, 15664, 25848, 37600, 50920, 65808, 82264, 101552, 119880, 141040, 163768, 188064,
+              213928, 241360, 255000, 100224, 119816, 140976, 163704, 188000, 213864, 241296, 270296, 300864, 333000,
+              366704, 401976, 438816, 477224, 517200, 527000, 0, 11432, 25456, 42072, 61280, 83080, 107472, 134456,
+              166576, 196200, 230960, 268312, 308256, 350792, 395920, 420600, 162176, 194344, 229104, 266456, 306400,
+              348936, 394064, 441784, 492096, 545000, 600496, 658584, 719264, 782536, 848400, 869240, 0, 15816, 35248,
+              58296, 84960, 115240, 149136, 186648, 231600, 272520, 320880, 372856, 428448, 487656, 550480, 586200,
+              224128, 268872, 317232, 369208, 424800, 484008, 546832, 613272, 683328, 757000, 834288, 915192, 999712,
+              1087848, 1179600, 1211480, 0, 20200, 45040, 74520, 108640, 147400, 190800, 238840, 296624, 348840, 410800,
+              477400, 548640, 624520, 705040, 751800, 286080, 343400, 405360, 471960, 543200, 619080, 699600, 784760,
+              874560, 969000, 1068080, 1171800, 1280160, 1393160, 1510800, 1553720, 0, 24584, 54832, 90744, 132320,
+              179560, 232464, 291032, 361648, 425160, 500720, 581944, 668832, 761384, 859600, 917400, 348032, 417928,
+              493488, 574712, 661600, 754152, 852368, 956248, 1065792, 1181000, 1301872, 1428408, 1560608, 1698472,
+              1842000, 1895960, 0, 28968, 64624, 106968, 156000, 211720, 274128, 343224, 426672, 501480, 590640, 686488,
+              789024, 898248, 1014160, 1083000, 409984, 492456, 581616, 677464, 780000, 889224, 1005136, 1127736,
+              1257024, 1393000, 1535664, 1685016, 1841056, 2003784, 2173200, 2238200, 0, 33352, 74416, 123192, 179680,
+              243880, 315792, 395416, 491696, 577800, 680560, 791032, 909216, 1035112, 1168720, 1248600, 471936, 566984,
+              669744, 780216, 898400, 1024296, 1157904, 1299224, 1448256, 1605000, 1769456, 1941624, 2121504, 2309096,
+              2504400, 2580440, 0, 37736, 84208, 139416, 203360, 276040, 357456, 447608, 556720, 654120, 770480, 895576,
+              1029408, 1171976, 1323280, 1414200, 533888, 641512, 757872, 882968, 1016800, 1159368, 1310672, 1470712,
+              1639488, 1817000, 2003248, 2198232, 2401952, 2614408, 2835600, 2922680, 0, 42120, 94000, 155640, 227040,
+              308200, 399120, 499800, 621744, 730440, 860400, 1000120, 1149600, 1308840, 1477840, 1579800, 595840,
+              716040, 846000, 985720, 1135200, 1294440, 1463440, 1642200, 1830720, 2029000, 2237040, 2454840, 2682400,
+              2919720, 3166800, 3264920, 0, 46504, 103792, 171864, 250720, 340360, 440784, 551992, 686768, 806760,
+              950320, 1104664, 1269792, 1445704, 1632400, 1745400, 657792, 790568, 934128, 1088472, 1253600, 1429512,
+              1616208, 1813688, 2021952, 2241000, 2470832, 2711448, 2962848, 3225032, 3498000, 3607160, 0, 50888,
+              113584, 188088, 274400, 372520, 482448, 604184, 751792, 883080, 1040240, 1209208, 1389984, 1582568,
+              1786960, 1911000, 719744, 865096, 1022256, 1191224, 1372000, 1564584, 1768976, 1985176, 2213184, 2453000,
+              2704624, 2968056, 3243296, 3530344, 3829200, 3949400, 0, 55272, 123376, 204312, 298080, 404680, 524112,
+              656376, 816816, 959400, 1130160, 1313752, 1510176, 1719432, 1941520, 2076600, 781696, 939624, 1110384,
+              1293976, 1490400, 1699656, 1921744, 2156664, 2404416, 2665000, 2938416, 3224664, 3523744, 3835656,
+              4160400, 4291640, 0, 59656, 133168, 220536, 321760, 436840, 565776, 708568, 881840, 1035720, 1220080,
+              1418296, 1630368, 1856296, 2096080, 2242200, 843648, 1014152, 1198512, 1396728, 1608800, 1834728, 2074512,
+              2328152, 2595648, 2877000, 3172208, 3481272, 3804192, 4140968, 4491600, 4633880, 0, 64040, 142960, 236760,
+              345440, 469000, 607440, 760760, 946864, 1112040, 1310000, 1522840, 1750560, 1993160, 2250640, 2407800,
+              905600, 1088680, 1286640, 1499480, 1727200, 1969800, 2227280, 2499640, 2786880, 3089000, 3406000, 3737880,
+              4084640, 4446280, 4822800, 4976120, 0, 68424, 152752, 252984, 369120, 501160, 649104, 812952, 1011888,
+              1188360, 1399920, 1627384, 1870752, 2130024, 2405200, 2573400, 967552, 1163208, 1374768, 1602232, 1845600,
+              2104872, 2380048, 2671128, 2978112, 3301000, 3639792, 3994488, 4365088, 4751592, 5154000, 5318360, 0,
+              72808, 162544, 269208, 392800, 533320, 690768, 865144, 1076912, 1264680, 1489840, 1731928, 1990944,
+              2266888, 2559760, 2739000, 1029504, 1237736, 1462896, 1704984, 1964000, 2239944, 2532816, 2842616,
+              3169344, 3513000, 3873584, 4251096, 4645536, 5056904, 5485200, 5660600, 0, 77192, 172336, 285432, 416480,
+              565480, 732432, 917336, 1141936, 1341000, 1579760, 1836472, 2111136, 2403752, 2714320, 2904600, 1091456,
+              1312264, 1551024, 1807736, 2082400, 2375016, 2685584, 3014104, 3360576, 3725000, 4107376, 4507704,
+              4925984, 5362216, 5816400, 6002840, 0, 81576, 182128, 301656, 440160, 597640, 774096, 969528, 1206960,
+              1417320, 1669680, 1941016, 2231328, 2540616, 2868880, 3070200, 1153408, 1386792, 1639152, 1910488,
+              2200800, 2510088, 2838352, 3185592, 3551808, 3937000, 4341168, 4764312, 5206432, 5667528, 6147600,
+              6345080, 0, 85960, 191920, 317880, 463840, 629800, 815760, 1021720, 1271984, 1493640, 1759600, 2045560,
+              2351520, 2677480, 3023440, 3235800, 1215360, 1461320, 1727280, 2013240, 2319200, 2645160, 2991120,
+              3357080, 3743040, 4149000, 4574960, 5020920, 5486880, 5972840, 6478800, 6687320, 0, 90344, 201712, 334104,
+              487520, 661960, 857424, 1073912, 1337008, 1569960, 1849520, 2150104, 2471712, 2814344, 3178000, 3401400,
+              1277312, 1535848, 1815408, 2115992, 2437600, 2780232, 3143888, 3528568, 3934272, 4361000, 4808752,
+              5277528, 5767328, 6278152, 6810000, 7029560, 0, 94728, 211504, 350328, 511200, 694120, 899088, 1126104,
+              1402032, 1646280, 1939440, 2254648, 2591904, 2951208, 3332560, 3567000, 1339264, 1610376, 1903536,
+              2218744, 2556000, 2915304, 3296656, 3700056, 4125504, 4573000, 5042544, 5534136, 6047776, 6583464,
+              7141200, 7371800, 0, 99112, 221296, 366552, 534880, 726280, 940752, 1178296, 1467056, 1722600, 2029360,
+              2359192, 2712096, 3088072, 3487120, 3732600, 1401216, 1684904, 1991664, 2321496, 2674400, 3050376,
+              3449424, 3871544, 4316736, 4785000, 5276336, 5790744, 6328224, 6888776, 7472400, 7714040, 0, 103496,
+              231088, 382776, 558560, 758440, 982416, 1230488, 1532080, 1798920, 2119280, 2463736, 2832288, 3224936,
+              3641680, 3898200, 1463168, 1759432, 2079792, 2424248, 2792800, 3185448, 3602192, 4043032, 4507968,
+              4997000, 5510128, 6047352, 6608672, 7194088, 7803600, 8056280, 0, 107880, 240880, 399000, 582240, 790600,
+              1024080, 1282680, 1597104, 1875240, 2209200, 2568280, 2952480, 3361800, 3796240, 4063800, 1525120,
+              1833960, 2167920, 2527000, 2911200, 3320520, 3754960, 4214520, 4699200, 5209000, 5743920, 6303960,
+              6889120, 7499400, 8134800, 8398520, 0, 112264, 250672, 415224, 605920, 822760, 1065744, 1334872, 1662128,
+              1951560, 2299120, 2672824, 3072672, 3498664, 3950800, 4229400, 1587072, 1908488, 2256048, 2629752,
+              3029600, 3455592, 3907728, 4386008, 4890432, 5421000, 5977712, 6560568, 7169568, 7804712, 8466000,
+              8740760, 0, 116648, 260464, 431448, 629600, 854920, 1107408, 1387064, 1727152, 2027880, 2389040, 2777368,
+              3192864, 3635528, 4105360, 4395000, 1649024, 1983016, 2344176, 2732504, 3148000, 3590664, 4060496,
+              4557496, 5081664, 5633000, 6211504, 6817176, 7450016, 8110024, 8797200, 9083000, 0, 121032, 270256,
+              447672, 653280, 887080, 1149072, 1439256, 1792176, 2104200, 2478960, 2881912, 3313056, 3772392, 4259920,
+              4560600, 1710976, 2057544, 2432304, 2835256, 3266400, 3725736, 4213264, 4728984, 5272896, 5845000,
+              6445296, 7073784, 7730464, 8415336, 9128400, 9425240, 0, 125416, 280048, 463896, 676960, 919240, 1190736,
+              1491448, 1857200, 2180520, 2568880, 2986456, 3433248, 3909256, 4414480, 4726200, 1772928, 2132072,
+              2520432, 2938008, 3384800, 3860808, 4366032, 4900472, 5464128, 6057000, 6679088, 7330392, 8010912,
+              8720648, 9459600, 9767480, 0, 129800, 289840, 480120, 700640, 951400, 1232400, 1543640, 1922224, 2256840,
+              2658800, 3091000, 3553440, 4046120, 4569040, 4891800, 1834880, 2206600, 2608560, 3040760, 3503200,
+              3995880, 4518800, 5071960, 5655360, 6269000, 6912880, 7587000, 8291360, 9025960, 9790800, 10109720, 0,
+              134184, 299632, 496344, 724320, 983560, 1274064, 1595832, 1987248, 2333160, 2748720, 3195544, 3673632,
+              4182984, 4723600, 5057400, 1896832, 2281128, 2696688, 3143512, 3621600, 4130952, 4671568, 5243448,
+              5846592, 6481000, 7146672, 7843608, 8571808, 9331272, 10122000, 10451960, 0, 138568, 309424, 512568,
+              748000, 1015720, 1315728, 1648024, 2052272, 2409480, 2838640, 3300088, 3793824, 4319848, 4878160, 5223000,
+              1958784, 2355656, 2784816, 3246264, 3740000, 4266024, 4824336, 5414936, 6037824, 6693000, 7380464,
+              8100216, 8852256, 9636584, 10453200, 10794200
+            ]
+          }
+        ]
+      }
+    ]
+  }
+]
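Note for readers checking the MatMulNBits fixtures above by hand: the inputs are the float activations A, the packed 4-bit weights B, the per-block scales, and (in the asymmetric case) the zero points, with K = N = block_size = 32. A minimal dequantization sketch follows, assuming the common packing of two 4-bit values per byte (low nibble first) and the formula w = scale * (q - zero_point); the function name and layout assumptions are illustrative, not the runtime's MlasDequantizeBlockwise.

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Illustrative only: dequantize one length-`count` run of packed 4-bit weights
    // that shares a single scale and zero point (i.e. one quantization block).
    std::vector<float> DequantizeBlockSketch(const uint8_t* packed, size_t count,
                                             float scale, uint8_t zero_point) {
      std::vector<float> out(count);
      for (size_t i = 0; i < count; ++i) {
        const uint8_t byte = packed[i / 2];
        const uint8_t q = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);  // assumed nibble order
        out[i] = scale * (static_cast<float>(q) - static_cast<float>(zero_point));
      }
      return out;
    }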
diff --git a/js/web/test/data/ops/where.jsonc b/js/web/test/data/ops/where.jsonc
index 047fd6fd7511b..990120dd3708e 100644
--- a/js/web/test/data/ops/where.jsonc
+++ b/js/web/test/data/ops/where.jsonc
@@ -168,5 +168,39 @@
         ]
       }
     ]
+  },
+  {
+    "name": "Where with no attributes",
+    "operator": "Where",
+    "attributes": [],
+    "cases": [
+      {
+        "name": "T[1 1 2 1] T[1 4] T[1 1 2 4] float32 broadcast 1",
+        "inputs": [
+          {
+            "data": [true, false],
+            "dims": [1, 1, 2, 1],
+            "type": "bool"
+          },
+          {
+            "data": [1, 2, 3, 4],
+            "dims": [1, 4],
+            "type": "float32"
+          },
+          {
+            "data": [5, 6, 7, 8, 9, 10, 11, 12],
+            "dims": [1, 1, 2, 4],
+            "type": "float32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [1, 2, 3, 4, 9, 10, 11, 12],
+            "dims": [1, 1, 2, 4],
+            "type": "float32"
+          }
+        ]
+      }
+    ]
   }
 ]
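The new Where case above exercises multidirectional broadcasting: the condition [1,1,2,1] and X [1,4] are both broadcast against Y [1,1,2,4], so row 0 of the output comes from X and row 1 from Y. A standalone sketch that reproduces the expected output for exactly these shapes (dims hardcoded, illustrative only):

    #include <cstdio>

    int main() {
      const bool cond[2] = {true, false};              // dims [1, 1, 2, 1]
      const float x[4] = {1, 2, 3, 4};                 // dims [1, 4], broadcast over rows
      const float y[8] = {5, 6, 7, 8, 9, 10, 11, 12};  // dims [1, 1, 2, 4]
      float out[8];                                    // dims [1, 1, 2, 4]
      for (int r = 0; r < 2; ++r) {
        for (int c = 0; c < 4; ++c) {
          // condition is broadcast along the last axis, x along the row axis
          out[r * 4 + c] = cond[r] ? x[c] : y[r * 4 + c];
        }
      }
      for (float v : out) std::printf("%g ", v);  // prints: 1 2 3 4 9 10 11 12
      std::printf("\n");
      return 0;
    }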
diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc
index 55b21283025c2..e96a0aa045bc8 100644
--- a/js/web/test/suite-test-list.jsonc
+++ b/js/web/test/suite-test-list.jsonc
@@ -1231,7 +1231,7 @@
       "test_split_variable_parts_1d",
       "test_split_variable_parts_2d",
       "test_split_variable_parts_default_axis",
-      // // "test_split_zero_size_splits",
+      "test_split_zero_size_splits",
       "test_sqrt_example",
       "test_sqrt",
       "test_squeeze_negative_axes",
@@ -1334,6 +1334,7 @@
       "acos.jsonc",
       "add.jsonc",
       "add_int32.jsonc",
+      "add_zero-sized.jsonc",
       //"and.jsonc",
       "asin.jsonc",
       "attention.jsonc",
@@ -1343,6 +1344,7 @@
       "ceil.jsonc",
       "concat.jsonc",
       "concat_int32.jsonc",
+      "concat_zero-sized.jsonc",
       "cast.jsonc",
       "conv.jsonc",
       "cos.jsonc",
@@ -1354,6 +1356,7 @@
       "expand.jsonc",
       "fast-gelu.jsonc",
       "floor.jsonc",
+      "fused-conv.jsonc",
       "gather-elements.jsonc",
       "gemm.jsonc",
       "global-average-pool.jsonc",
@@ -1362,6 +1365,7 @@
       "less.jsonc",
       "log.jsonc",
       "matmul.jsonc",
+      "matmulnbits.jsonc",
       "matmul-broadcast.jsonc",
       "mul.jsonc",
       "mul_int32.jsonc",
diff --git a/js/web/test/test-main.ts b/js/web/test/test-main.ts
index 2d83ce1e095ce..96e374f87aed1 100644
--- a/js/web/test/test-main.ts
+++ b/js/web/test/test-main.ts
@@ -19,49 +19,7 @@ if (ORT_WEB_TEST_CONFIG.model.some(testGroup => testGroup.tests.some(test => tes
 }
 
 // set flags
-const options = ORT_WEB_TEST_CONFIG.options;
-if (options.debug !== undefined) {
-  ort.env.debug = options.debug;
-}
-if (options.globalEnvFlags) {
-  const flags = options.globalEnvFlags;
-  if (flags.logLevel !== undefined) {
-    ort.env.logLevel = flags.logLevel;
-  }
-  if (flags.webgl?.contextId !== undefined) {
-    ort.env.webgl.contextId = flags.webgl.contextId;
-  }
-  if (flags.webgl?.matmulMaxBatchSize !== undefined) {
-    ort.env.webgl.matmulMaxBatchSize = flags.webgl.matmulMaxBatchSize;
-  }
-  if (flags.webgl?.textureCacheMode !== undefined) {
-    ort.env.webgl.textureCacheMode = flags.webgl.textureCacheMode;
-  }
-  if (flags.webgl?.pack !== undefined) {
-    ort.env.webgl.pack = flags.webgl.pack;
-  }
-  if (flags.webgl?.async !== undefined) {
-    ort.env.webgl.async = flags.webgl.async;
-  }
-  if (flags.wasm?.numThreads !== undefined) {
-    ort.env.wasm.numThreads = flags.wasm.numThreads;
-  }
-  if (flags.wasm?.simd !== undefined) {
-    ort.env.wasm.simd = flags.wasm.simd;
-  }
-  if (flags.wasm?.proxy !== undefined) {
-    ort.env.wasm.proxy = flags.wasm.proxy;
-  }
-  if (flags.wasm?.initTimeout !== undefined) {
-    ort.env.wasm.initTimeout = flags.wasm.initTimeout;
-  }
-  if (flags.webgpu?.profilingMode !== undefined) {
-    ort.env.webgpu.profiling = {mode: flags.webgpu.profilingMode};
-  }
-  if (flags.webgpu?.validateInputContent !== undefined) {
-    ort.env.webgpu.validateInputContent = flags.webgpu.validateInputContent;
-  }
-}
+Object.assign(ort.env, ORT_WEB_TEST_CONFIG.options.globalEnvFlags);
 
 // Set logging configuration
 for (const logConfig of ORT_WEB_TEST_CONFIG.log) {
diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts
index b01d474788f25..7c03e5b915fd7 100644
--- a/js/web/test/test-runner.ts
+++ b/js/web/test/test-runner.ts
@@ -39,10 +39,6 @@ const ONNXRUNTIME_THRESHOLD_RELATIVE_ERROR = 1.00001;
  */
 const now = (typeof performance !== 'undefined' && performance.now) ? () => performance.now() : Date.now;
 
-function toInternalTensor(tensor: ort.Tensor): Tensor {
-  return new Tensor(
-      tensor.dims, tensor.type as Tensor.DataType, undefined, undefined, tensor.data as Tensor.NumberType);
-}
 function fromInternalTensor(tensor: Tensor): ort.Tensor {
   return new ort.Tensor(tensor.type, tensor.data as ort.Tensor.DataType, tensor.dims);
 }
@@ -330,6 +326,10 @@ export class TensorResultValidator {
   }
 
   checkTensorResult(actual: Tensor[], expected: Tensor[]): void {
+    this.checkApiTensorResult(actual.map(fromInternalTensor), expected.map(fromInternalTensor));
+  }
+
+  checkApiTensorResult(actual: ort.Tensor[], expected: ort.Tensor[]): void {
     // check output size
     expect(actual.length, 'size of output tensors').to.equal(expected.length);
 
@@ -347,10 +347,6 @@ export class TensorResultValidator {
     }
   }
 
-  checkApiTensorResult(actual: ort.Tensor[], expected: ort.Tensor[]): void {
-    this.checkTensorResult(actual.map(toInternalTensor), expected.map(toInternalTensor));
-  }
-
   checkNamedTensorResult(actual: Record<string, ort.Tensor>, expected: Test.NamedTensor[]): void {
     // check output size
     expect(Object.getOwnPropertyNames(actual).length, 'size of output tensors').to.equal(expected.length);
@@ -364,7 +360,7 @@ export class TensorResultValidator {
   }
 
   // This function check whether 2 tensors should be considered as 'match' or not
-  areEqual(actual: Tensor, expected: Tensor): boolean {
+  areEqual(actual: ort.Tensor, expected: ort.Tensor): boolean {
     if (!actual || !expected) {
       return false;
     }
@@ -392,13 +388,13 @@ export class TensorResultValidator {
 
     switch (actualType) {
       case 'string':
-        return this.strictEqual(actual.stringData, expected.stringData);
+        return this.strictEqual(actual.data, expected.data);
 
       case 'float32':
       case 'float64':
         return this.floatEqual(
-            actual.numberData as number[] | Float32Array | Float64Array,
-            expected.numberData as number[] | Float32Array | Float64Array);
+            actual.data as number[] | Float32Array | Float64Array,
+            expected.data as number[] | Float32Array | Float64Array);
 
       case 'uint8':
       case 'int8':
@@ -409,10 +405,8 @@ export class TensorResultValidator {
       case 'int64':
       case 'bool':
         return TensorResultValidator.integerEqual(
-            actual.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array |
-                Int32Array,
-            expected.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array |
-                Int32Array);
+            actual.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array,
+            expected.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array);
 
       default:
         throw new Error('type not implemented or not supported');
@@ -579,7 +573,9 @@ export async function sessionRun(options: {
       // replace the CPU tensors in feeds into GPU tensors
       for (const name in feeds) {
         if (Object.hasOwnProperty.call(feeds, name)) {
-          feeds[name] = createGpuTensorForInput(feeds[name]);
+          if (feeds[name].size > 0) {
+            feeds[name] = createGpuTensorForInput(feeds[name]);
+          }
         }
       }
     }
@@ -588,7 +584,11 @@ export async function sessionRun(options: {
       for (const name in options.outputsMetaInfo) {
         if (Object.hasOwnProperty.call(options.outputsMetaInfo, name)) {
           const {type, dims} = options.outputsMetaInfo[name];
-          fetches[name] = createGpuTensorForOutput(type, dims);
+          if (dims.some(d => d === 0)) {
+            fetches[name] = new ort.Tensor(type, [], dims);
+          } else {
+            fetches[name] = createGpuTensorForOutput(type, dims);
+          }
         }
       }
     }
@@ -633,8 +633,8 @@ export async function runModelTestSet(
   try {
     const feeds: Record<string, ort.Tensor> = {};
     const outputsMetaInfo: Record<string, ort.Tensor> = {};
-    testCase.inputs!.forEach((tensor, i) => feeds[context.session.inputNames[i]] = tensor);
-    testCase.outputs!.forEach((tensor, i) => outputsMetaInfo[context.session.outputNames[i]] = tensor);
+    testCase.inputs!.forEach((tensor) => feeds[tensor.name] = tensor);
+    testCase.outputs!.forEach((tensor) => outputsMetaInfo[tensor.name] = tensor);
     const [start, end, outputs] =
         await sessionRun({session: context.session, feeds, outputsMetaInfo, ioBinding: context.ioBinding});
     if (context.perfData.count === 0) {
diff --git a/js/web/test/unittests/backends/webgl/test-conv-new.ts b/js/web/test/unittests/backends/webgl/test-conv-new.ts
index 8c186b9b36451..014fc57f21558 100644
--- a/js/web/test/unittests/backends/webgl/test-conv-new.ts
+++ b/js/web/test/unittests/backends/webgl/test-conv-new.ts
@@ -893,7 +893,9 @@ describe('New Conv tests', () => {
             const expected = cpuConv(
                 inputTensor, kernelTensor, biasTensor, testData.autoPad, testData.dilations, testData.pads,
                 testData.strides);
-            if (!validator.areEqual(actual, expected)) {
+            try {
+              validator.checkTensorResult([actual], [expected]);
+            } catch {
               console.log(actual.dims, `[${actual.numberData.slice(0, 20).join(',')},...]`);
               console.log(expected.dims, `[${expected.numberData.slice(0, 20).join(',')},...]`);
               throw new Error('Expected and Actual did not match');
diff --git a/objectivec/include/ort_coreml_execution_provider.h b/objectivec/include/ort_coreml_execution_provider.h
index a015b6fd60c8f..6ff18176ebeb2 100644
--- a/objectivec/include/ort_coreml_execution_provider.h
+++ b/objectivec/include/ort_coreml_execution_provider.h
@@ -41,6 +41,17 @@ NS_ASSUME_NONNULL_BEGIN
  */
 @property BOOL onlyEnableForDevicesWithANE;
 
+/**
+ * Only allow the CoreML EP to take nodes whose inputs have static shapes. By default it will also allow inputs
+ * with dynamic shapes, but performance may be negatively impacted when input shapes are dynamic.
+ */
+@property BOOL onlyAllowStaticInputShapes;
+
+/**
+ * Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or later.
+ */
+@property BOOL createMLProgram;
+
 @end
 
 @interface ORTSessionOptions (ORTSessionOptionsCoreMLEP)
diff --git a/objectivec/ort_coreml_execution_provider.mm b/objectivec/ort_coreml_execution_provider.mm
index 6340fdea1c3a7..58b47d68eea63 100644
--- a/objectivec/ort_coreml_execution_provider.mm
+++ b/objectivec/ort_coreml_execution_provider.mm
@@ -26,7 +26,10 @@ - (BOOL)appendCoreMLExecutionProviderWithOptions:(ORTCoreMLExecutionProviderOpti
     const uint32_t flags =
         (options.useCPUOnly ? COREML_FLAG_USE_CPU_ONLY : 0) |
         (options.enableOnSubgraphs ? COREML_FLAG_ENABLE_ON_SUBGRAPH : 0) |
-        (options.onlyEnableForDevicesWithANE ? COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE : 0);
+        (options.onlyEnableForDevicesWithANE ? COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE : 0) |
+        (options.onlyAllowStaticInputShapes ? COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES : 0) |
+        (options.createMLProgram ? COREML_FLAG_CREATE_MLPROGRAM : 0);
+
     Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(
         [self CXXAPIOrtSessionOptions], flags));
     return YES;
diff --git a/objectivec/ort_value.mm b/objectivec/ort_value.mm
index b9dc1a9885c61..c61a7ea809237 100644
--- a/objectivec/ort_value.mm
+++ b/objectivec/ort_value.mm
@@ -148,6 +148,9 @@ - (nullable ORTValueTypeInfo*)typeInfoWithError:(NSError**)error {
 - (nullable ORTTensorTypeAndShapeInfo*)tensorTypeAndShapeInfoWithError:(NSError**)error {
   try {
     const auto tensorTypeAndShapeInfo = _typeInfo->GetTensorTypeAndShapeInfo();
+    if (!tensorTypeAndShapeInfo) {
+      ORT_CXX_API_THROW("ORTValue is not a tensor.", ORT_RUNTIME_EXCEPTION);
+    }
     return CXXAPIToPublicTensorTypeAndShapeInfo(tensorTypeAndShapeInfo);
   }
   ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
@@ -156,6 +159,9 @@ - (nullable ORTTensorTypeAndShapeInfo*)tensorTypeAndShapeInfoWithError:(NSError*
 - (nullable NSMutableData*)tensorDataWithError:(NSError**)error {
   try {
     const auto tensorTypeAndShapeInfo = _typeInfo->GetTensorTypeAndShapeInfo();
+    if (!tensorTypeAndShapeInfo) {
+      ORT_CXX_API_THROW("ORTValue is not a tensor.", ORT_RUNTIME_EXCEPTION);
+    }
     if (tensorTypeAndShapeInfo.GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
       ORT_CXX_API_THROW(
           "This ORTValue holds string data. Please call tensorStringDataWithError: "
@@ -182,6 +188,9 @@ - (nullable NSMutableData*)tensorDataWithError:(NSError**)error {
 - (nullable NSArray<NSString*>*)tensorStringDataWithError:(NSError**)error {
   try {
     const auto tensorTypeAndShapeInfo = _typeInfo->GetTensorTypeAndShapeInfo();
+    if (!tensorTypeAndShapeInfo) {
+      ORT_CXX_API_THROW("ORTValue is not a tensor.", ORT_RUNTIME_EXCEPTION);
+    }
     const size_t elementCount = tensorTypeAndShapeInfo.GetElementCount();
     const size_t tensorStringDataLength = _value->GetStringTensorDataLength();
     std::vector<char> tensorStringData(tensorStringDataLength, '\0');
diff --git a/onnxruntime/contrib_ops/cpu/activations.cc b/onnxruntime/contrib_ops/cpu/activations.cc
index 556699192d2eb..3e0533dd8b9e5 100644
--- a/onnxruntime/contrib_ops/cpu/activations.cc
+++ b/onnxruntime/contrib_ops/cpu/activations.cc
@@ -2,7 +2,7 @@
 // Licensed under the MIT License.
 
 #include "core/providers/cpu/activation/activations.h"
-#include "activations.h"
+#include "contrib_ops/cpu/activations.h"
 
 namespace onnxruntime {
 namespace contrib {
@@ -26,14 +26,6 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
     KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
     ThresholdedRelu<float>);
 
-ONNX_OPERATOR_KERNEL_EX(
-    Gelu,
-    kMSDomain,
-    1,
-    kCpuExecutionProvider,
-    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
-    Gelu<float>);
-
 ONNX_OPERATOR_KERNEL_EX(
     QuickGelu,
     kMSDomain,
diff --git a/onnxruntime/contrib_ops/cpu/activations.h b/onnxruntime/contrib_ops/cpu/activations.h
index aed4c2229215d..7e64235d3fc3d 100644
--- a/onnxruntime/contrib_ops/cpu/activations.h
+++ b/onnxruntime/contrib_ops/cpu/activations.h
@@ -54,47 +54,6 @@ namespace contrib {
 DEFINE_ELE_KERNEL(ScaledTanh);
 DEFINE_ELE_KERNEL(ParametricSoftplus);
 
-template <typename T>
-class Gelu : public OpKernel {
- public:
-  Gelu(const OpKernelInfo& info) : OpKernel(info) {
-  }
-
-  Status Compute(OpKernelContext* context) const override {
-    const Tensor* input = context->Input<Tensor>(0);
-    const T* input_data = input->Data<T>();
-
-    Tensor* output = context->Output(0, input->Shape());
-    T* output_data = output->MutableData<T>();
-
-    concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
-    int64_t elem_count = input->Shape().Size();
-    constexpr int64_t length_per_task = 4096;  // this number comes from FastGelu.
-    int64_t task_count = (elem_count + length_per_task - 1) / length_per_task;
-    concurrency::ThreadPool::TryBatchParallelFor(
-        tp, static_cast<int32_t>(task_count),
-        [&](ptrdiff_t task_idx) {
-          const auto start = task_idx * length_per_task;
-          const T* p_input = input_data + start;
-          T* p_output = output_data + start;
-          int64_t count = std::min(length_per_task, elem_count - start);
-
-          for (int64_t i = 0; i < count; i++) {
-            T value = p_input[i];
-            p_output[i] = value * static_cast<T>(M_SQRT1_2);
-          }
-
-          MlasComputeErf(p_output, p_output, narrow<size_t>(count));
-
-          for (int64_t i = 0; i < count; i++) {
-            p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f);
-          }
-        },
-        0);
-    return Status::OK();
-  }
-};
-
 // Implement a new one instead of inheriting from ElementWiseRangedTransform so that we can call
 // MlasComputeLogistic instead of using Eigen for better perf.
 template <typename T>
diff --git a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h
index d72868cd8fa9f..56c8e2911e280 100644
--- a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h
+++ b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h
@@ -10,7 +10,7 @@ namespace onnxruntime {
 namespace contrib {
 namespace aten_ops {
 
-typedef bool (*IsCpuArgumentFunc)(const char* op_name, const char* overload_name, size_t index, bool is_input);
+typedef bool (*IsTensorArgumentFunc)(const char* op_name, const char* overload_name, size_t index, bool is_input);
 typedef void (*ExecuteATenOperatorFunc)(const char* op_name, const char* overload_name, size_t input_size,
                                         DLManagedTensor** dlpack_inputs, size_t output_size,
                                         DLManagedTensor** dlpack_outputs);
@@ -22,17 +22,17 @@ class ATenOperatorExecutor {
     return instance;
   }
 
-  void Initialize(void* p_is_cpu_argument_func_raw, void* p_execute_aten_op_func_raw) {
-    ORT_ENFORCE(p_is_cpu_argument_func_raw && p_execute_aten_op_func_raw);
-    p_is_cpu_argument_func_ = reinterpret_cast<IsCpuArgumentFunc>(p_is_cpu_argument_func_raw);
+  void Initialize(void* p_is_tensor_argument_func_raw, void* p_execute_aten_op_func_raw) {
+    ORT_ENFORCE(p_is_tensor_argument_func_raw && p_execute_aten_op_func_raw);
+    p_is_tensor_argument_func_ = reinterpret_cast<IsTensorArgumentFunc>(p_is_tensor_argument_func_raw);
     p_execute_aten_op_func_ = reinterpret_cast<ExecuteATenOperatorFunc>(p_execute_aten_op_func_raw);
   }
 
   bool IsInitialized() { return p_execute_aten_op_func_ != nullptr; }
 
-  bool IsCpuArgument(const std::string& op_name, const std::string& overload_name, size_t index, bool is_input) {
-    ORT_ENFORCE(p_is_cpu_argument_func_, "ATenOperatorExecutor is not initialized.");
-    return p_is_cpu_argument_func_(op_name.c_str(), overload_name.c_str(), index, is_input);
+  bool IsTensorArgument(const std::string& op_name, const std::string& overload_name, size_t index, bool is_input) {
+    ORT_ENFORCE(p_is_tensor_argument_func_, "ATenOperatorExecutor is not initialized.");
+    return p_is_tensor_argument_func_(op_name.c_str(), overload_name.c_str(), index, is_input);
   }
 
   void operator()(const std::string& op_name, const std::string& overload_name, size_t input_size,
@@ -43,7 +43,7 @@ class ATenOperatorExecutor {
   }
 
  private:
-  IsCpuArgumentFunc p_is_cpu_argument_func_ = nullptr;
+  IsTensorArgumentFunc p_is_tensor_argument_func_ = nullptr;
   ExecuteATenOperatorFunc p_execute_aten_op_func_ = nullptr;
 };
 
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
index a34f41d2938c6..5a0c3af05c9da 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -98,6 +98,7 @@ struct GroupQueryAttentionParameters {
   int kv_hidden_size;
   int kv_num_heads;
   int num_splits;          // number of splits for splitkv
+  int rotary_dim;          // rotary embedding dimension
   bool is_unidirectional;  // causal
   int local_window_size;
   bool kv_share_buffer;
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
index b761b1afd8529..c617533319a18 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -140,17 +140,6 @@ class AttentionCPUBase : public AttentionBase {
       if (mask_data != nullptr) {
         PrepareMask(mask_index, mask_index_dims, mask_data,
                     causal, batch_size, sequence_length, past_sequence_length, mask_filter_value_);
-      } else {  // no any mask
-        const int memset_loop_len = batch_size * num_heads_;
-        const double memset_cost = static_cast<double>(sequence_length) * total_sequence_length;
-
-        ThreadPool::TryParallelFor(tp, memset_loop_len, memset_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
-          for (std::ptrdiff_t i = begin; i != end; ++i) {
-            const int output_offset = static_cast<int>(i) * sequence_length * total_sequence_length;
-            T* output = attention_probs + output_offset;
-            memset(output, 0, static_cast<size_t>(sequence_length) * total_sequence_length * sizeof(T));
-          }
-        });
       }
 
       const int loop_len = batch_size * num_heads_;
@@ -188,7 +177,7 @@ class AttentionCPUBase : public AttentionBase {
           // B: K'               (B x N x) T x H          (B x N x) H x T        H x T
           // C: attention_probs  (B x N x) S x T          (B x N x) S x T        S x T
           math::Gemm<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_sequence_length, head_size, alpha,
-                                    Q + q_input_chunk_length * i, k, 1.0,
+                                    Q + q_input_chunk_length * i, k, mask_data != nullptr ? 1.0f : 0.0f,
                                     output, nullptr);
 
           if (relative_position_bias_data != nullptr) {
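For context on the attention_cpu_base.h hunk above: the explicit zero-fill of attention_probs is dropped because GEMM's beta parameter already covers both cases. GEMM computes C = alpha * op(A) * op(B) + beta * C, and by convention C is not read when beta == 0, so passing beta = 0 when there is no mask makes the memset redundant, while beta = 1 keeps the pre-written mask as an additive bias. A naive reference (not the runtime's math::Gemm) showing that convention:

    // Illustrative reference GEMM over row-major A (MxK), B (KxN), C (MxN).
    void NaiveGemm(int M, int N, int K, float alpha, const float* A, const float* B,
                   float beta, float* C) {
      for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
          float acc = 0.0f;
          for (int k = 0; k < K; ++k) {
            acc += A[m * K + k] * B[k * N + n];
          }
          // When beta == 0 the prior contents of C are ignored entirely (even if
          // uninitialized), which is what lets the memset be removed.
          C[m * N + n] = (beta == 0.0f) ? alpha * acc : alpha * acc + beta * C[m * N + n];
        }
      }
    }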
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
index eb25d0fd7cc1e..c4e4b4ec707fb 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -58,11 +58,12 @@ Status Reshape_BSD_to_BSNH(Tensor* qkv,
 
 // Transpose Q/K/V from BxSxNxH to BxNxSxH
 Status Transpose_BSNH_to_BNSH(const Tensor* qkv,
-                              OrtValue& qkv_transposed) {
+                              OrtValue& qkv_transposed,
+                              concurrency::ThreadPool* tp = nullptr) {
   std::vector<size_t> permutations({0, 2, 1, 3});
   gsl::span<const size_t> permutations_span{permutations};
   size_t from = 2, to = 1;
-  SingleAxisTranspose(permutations_span, *qkv, *qkv_transposed.GetMutable<Tensor>(), from, to);
+  SingleAxisTranspose(permutations_span, *qkv, *qkv_transposed.GetMutable<Tensor>(), from, to, nullptr, tp);
   return Status::OK();
 }
 
@@ -143,7 +144,8 @@ Status AddBiasTranspose(const Tensor* qkv,                   // Input: Q/K/V dat
   ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(qkv_with_bias.GetMutable<Tensor>(), batch_size, sequence_length, num_heads, head_size));
 
   // Transpose Q from BxSxNxH to BxNxSxH
-  ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable<Tensor>(), qkv_with_bias_transposed));
+  auto tp = context->GetOperatorThreadPool();
+  ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable<Tensor>(), qkv_with_bias_transposed, tp));
 
   return Status::OK();
 }
diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
index 166f5c8f52f54..602dd98d8c0d6 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -1,6 +1,12 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h"
+
+#include <cstdint>
+#include <type_traits>
+
+#include "core/common/common.h"
 #include "core/common/narrow.h"
 #include "core/common/safeint.h"
 #include "core/framework/op_kernel.h"
@@ -50,6 +56,17 @@ int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level
 }
 }  // namespace
 
+bool GetType(const NodeArg& node_arg, int32_t& type) {
+  type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
+  const auto* type_proto = node_arg.TypeAsProto();
+  if (!type_proto || !type_proto->has_tensor_type() || !type_proto->tensor_type().has_elem_type()) {
+    return false;
+  }
+
+  type = type_proto->tensor_type().elem_type();
+  return true;
+}
+
 class MatMulNBits final : public OpKernel {
  public:
   MatMulNBits(const OpKernelInfo& info)
@@ -59,6 +76,17 @@ class MatMulNBits final : public OpKernel {
         block_size_{narrow<size_t>(info.GetAttr<int64_t>("block_size"))},
         nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))},
         accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr<int64_t>("accuracy_level"))} {
+    const auto& node = info.node();
+    auto input_defs = node.InputDefs();
+    // g_idx
+    if (input_defs.size() > 4) {
+      act_order_ = true;
+    }
+    int32_t type;
+    if (input_defs.size() > 3 && GetType(*input_defs[3], type)) {
+      zero_point_is_not_quant_ = type != ONNX_NAMESPACE::TensorProto_DataType_UINT8;
+    }
+
     ORT_ENFORCE(nbits_ == 4,
                 "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
 #ifdef ORT_NEURAL_SPEED
@@ -88,6 +116,8 @@ class MatMulNBits final : public OpKernel {
   const size_t N_;
   const size_t block_size_;
   const size_t nbits_;
+  bool act_order_{false};
+  bool zero_point_is_not_quant_{false};
   const int64_t accuracy_level_;
   const bool column_wise_quant_{true};
   IAllocatorUniquePtr<void> packed_b_;
@@ -105,7 +135,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
                             /*out*/ bool& is_packed,
                             /*out*/ PrePackedWeights* prepacked_weights) {
   is_packed = false;
-
+  if (act_order_ || zero_point_is_not_quant_) {
+    return Status::OK();
+  }
 #if defined(ORT_NEURAL_SPEED)
 
   if (!all_constant_) {
@@ -212,7 +244,6 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prep
 
 Status MatMulNBits::Compute(OpKernelContext* ctx) const {
   concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
-
   const Tensor* a = ctx->Input<Tensor>(0);
   const auto* a_data = a->Data<float>();
 
@@ -257,11 +288,14 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
 #endif  // defined(ORT_NEURAL_SPEED)
 
   const Tensor* scales = ctx->Input<Tensor>(2);
-  const Tensor* zero_points = ctx->Input<Tensor>(3);
+  const Tensor* zero_points = ctx->InputCount() > 3 ? ctx->Input<Tensor>(3) : nullptr;
+  const Tensor* reorder_idx = ctx->InputCount() > 4 ? ctx->Input<Tensor>(4) : nullptr;
+
   const auto* scales_data = scales->Data<float>();
-  const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data<uint8_t>();
+  const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw();
 
   TensorShape b_shape({static_cast<int64_t>(N_), static_cast<int64_t>(K_)});
+  const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data<int32_t>();
 
   MatMulComputeHelper helper;
   ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true));
@@ -281,8 +315,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
   const size_t K = static_cast<size_t>(helper.K());
   const size_t lda = helper.Lda(false);
 
-  const bool has_single_b_matrix = std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(),
-                                               [](size_t offset) { return offset == 0; });
+  const bool has_single_b_matrix =
+      (!act_order_) && (!zero_point_is_not_quant_) &&
+      std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), [](size_t offset) { return offset == 0; });
 
   if (has_single_b_matrix) {
     const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level_);
@@ -328,22 +363,50 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
   const uint8_t* b_data = b->Data<uint8_t>();
 
   const size_t ldb = helper.Ldb(true);
-
   AllocatorPtr allocator;
   ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator));
   auto tmp_b_data_ptr = IAllocator::MakeUniquePtr<float>(allocator, SafeInt<size_t>(K_) * N_);
-  // dequantize b, only 4b quantization is supported for now
-  MlasDequantizeBlockwise<float, 4>(
-      tmp_b_data_ptr.get(),               // dequantized output
-      b_data,                             // quantized input
-      scales_data,                        // quantization scales
-      zero_points_data,                   // quantization zero points
-      static_cast<int32_t>(block_size_),  // quantization block size
-      column_wise_quant_,                 // columnwise quantization or row-wise
-      static_cast<int32_t>(K_),           // number of rows in quantized input
-      static_cast<int32_t>(N_),           // number of columns in quantized input
-      thread_pool);
-
+  if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType<float>())) {
+    // dequantize b, only 4b quantization is supported for now
+    MlasDequantizeBlockwise<float, 4>(
+        tmp_b_data_ptr.get(),                           // dequantized output
+        b_data,                                         // quantized input
+        scales_data,                                    // quantization scales
+        static_cast<const uint8_t*>(zero_points_data),  // quantization zero points
+        static_cast<int32_t>(block_size_),              // quantization block size
+        column_wise_quant_,                             // columnwise quantization or row-wise
+        static_cast<int32_t>(K_),                       // number of rows in quantized input
+        static_cast<int32_t>(N_),                       // number of columns in quantized input
+        thread_pool);
+  } else {
+    ORT_ENFORCE(column_wise_quant_, "Row-wise quantization is not supported yet");
+    // TODO: naive implementation, needs to be optimized
+    if (zero_points && zero_points->IsDataType<float>()) {
+      DequantizeBlockwise<float, float>(
+          tmp_b_data_ptr.get(),                         // dequantized output
+          b_data,                                       // quantized input
+          scales_data,                                  // quantization scales
+          static_cast<const float*>(zero_points_data),  // quantization zero points
+          reorder_idx_data,
+          static_cast<int32_t>(block_size_),  // quantization block size
+          column_wise_quant_,                 // columnwise quantization or row-wise
+          static_cast<int32_t>(K_),           // number of rows in quantized input
+          static_cast<int32_t>(N_),           // number of columns in quantized input
+          thread_pool);
+    } else {
+      DequantizeBlockwise<float, uint8_t>(
+          tmp_b_data_ptr.get(),                           // dequantized output
+          b_data,                                         // quantized input
+          scales_data,                                    // quantization scales
+          static_cast<const uint8_t*>(zero_points_data),  // quantization zero points
+          reorder_idx_data,
+          static_cast<int32_t>(block_size_),  // quantization block size
+          column_wise_quant_,                 // columnwise quantization or row-wise
+          static_cast<int32_t>(K_),           // number of rows in quantized input
+          static_cast<int32_t>(N_),           // number of columns in quantized input
+          thread_pool);
+    }
+  }
 #if 0  // for debug
   auto tm_b_data_ptr_trans = IAllocator::MakeUniquePtr<float>(allocator, SafeInt<size_t>(K_) * N_);
   MlasTranspose(tmp_b_data_ptr.get(), tm_b_data_ptr_trans.get(), N_, K_);
@@ -374,7 +437,9 @@ ONNX_OPERATOR_KERNEL_EX(
     kCpuExecutionProvider,
     KernelDefBuilder()
         .TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
-        .TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>()),
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>())
+        .TypeConstraint("T3", {DataTypeImpl::GetTensorType<uint8_t>(), DataTypeImpl::GetTensorType<float>()})
+        .TypeConstraint("T4", DataTypeImpl::GetTensorType<int32_t>()),
     MatMulNBits);
 
 }  // namespace contrib
diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc
new file mode 100644
index 0000000000000..7e343d85f4048
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc
@@ -0,0 +1,109 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h"
+
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <cstdint>
+#include <type_traits>
+
+#include "core/common/common.h"
+#include "core/framework/float16.h"
+#include "core/providers/common.h"
+#include "core/platform/threadpool.h"
+
+namespace onnxruntime {
+namespace contrib {
+
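+// Note (added comment): dequantizes a strip of up to 8 packed 4-bit values for one (block, thread)
+// pair. blockIdx_x / threadIdx_x mirror CUDA-style grid indexing so the same work partitioning can
+// be driven from a CPU thread pool. reorder_idx, when provided, maps each element along K to its
+// quantization group; block_size is assumed to be a power of two (the bit masks below rely on it).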
+template <class T, class zeroT>
+void Dequantize4BitsKernelReOrder(
+    T* output, const uint8_t* quant_data, const T* scale_data,
+    const zeroT* zero_points, const int32_t* reorder_idx, int block_size,
+    int groups_per_threadblock, int total_groups, int out_rows, int out_cols,
+    int blockIdx_x, int threadIdx_x) {
+  const int group_id = blockIdx_x * groups_per_threadblock + ((threadIdx_x * 8) / block_size);
+  if (group_id >= total_groups) {
+    return;
+  }
+  const int scales_shape_x = (out_cols + block_size - 1) / block_size;
+  const int zero_point_shape_x = (scales_shape_x + 1) / 2;
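+  // 4-bit zero points are packed two per byte, so a zero-point row is half the length of a
+  // scale row (rounded up).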
+
+  int n_idx = group_id / scales_shape_x;
+  int kb_idx = group_id % scales_shape_x;
+  int element_offset = group_id * block_size + ((threadIdx_x * 8) & (block_size - 1));
+
+  const int out_x = element_offset % (scales_shape_x * block_size);
+  const int out_y = element_offset / (scales_shape_x * block_size);
+  if (out_y >= out_rows || out_x >= out_cols) {
+    return;
+  }
+  T* output_i = output + out_y * out_cols + out_x;
+  uint32_t quant_value = *(reinterpret_cast<const uint32_t*>(quant_data + element_offset / 2));
+  const int remain_x = std::min(8, out_cols - out_x);
+  const int32_t* reorder_idx_with_off = reorder_idx + kb_idx * block_size + ((threadIdx_x * 8) & (block_size - 1));
+  for (int i = 0; i < remain_x; i++) {
+    int32_t rid = reorder_idx ? reorder_idx_with_off[i] : kb_idx;
+    T scale = *(scale_data + n_idx * scales_shape_x + rid);
+    float zp_f = 8;
+    if (zero_points) {
+      if constexpr (std::is_same_v<zeroT, T>) {
+        zp_f = *(zero_points + n_idx * scales_shape_x + rid);
+      } else {
+        uint8_t zp = zero_points[n_idx * zero_point_shape_x + rid / 2];
+        zp = (rid & 0x01) ? (zp >> 4) : (zp & 0x0f);
+        zp_f = static_cast<float>(zp);
+      }
+    }
+
+    if constexpr (std::is_same_v<T, MLFloat16>) {
+      T zp_adjust = -scale * MLFloat16(zp_f);
+      output_i[i] = static_cast<float>((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust;
+    } else {
+      T zp_adjust = -scale * zp_f;
+      output_i[i] = T((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust;
+    }
+  }
+}
+
+template <typename inputT, typename zeroT>
+void DequantizeBlockwise(
+    inputT* output,              // dequantized output
+    const uint8_t* quant_data,   // quantized input
+    const inputT* scales_data,   // quantization scales
+    const zeroT* zero_points,    // quantization zero points
+    const int32_t* reorder_idx,  // reorder_idx for groupwise quantization
+    int32_t block_size,          // quantization block size
+    bool,                        // columnwise quantization or row-wise
+    int32_t K,                   // number of rows in quantized input
+    int32_t N,                   // number of columns in quantized input
+    onnxruntime::concurrency::ThreadPool* pool) {
+  auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
+  constexpr int element_per_thread = 8;
+  int groups_per_threadblock = 256 * element_per_thread / block_size;
+  int groups_per_K = ceildiv(K, block_size);
+  int total_groups = N * groups_per_K;  // total number of quantization groups in quant_data
+  int blocks_per_grid = static_cast<int>(ceildiv(total_groups, groups_per_threadblock));
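+  // Emulate a CUDA-style launch on the thread pool: each task is one "thread block" of 256
+  // emulated threads, and each emulated thread handles element_per_thread packed 4-bit values.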
+  concurrency::ThreadPool::TrySimpleParallelFor(
+      pool, static_cast<std::ptrdiff_t>(blocks_per_grid),
+      [&](std::ptrdiff_t block_id) {
+        for (int j = 0; j < 256; j++) {
+          Dequantize4BitsKernelReOrder(output, quant_data, scales_data, zero_points,
+                                       reorder_idx, block_size, groups_per_threadblock,
+                                       total_groups, N, K, static_cast<int>(block_id), j);
+        }
+      });
+}
+
+template void DequantizeBlockwise<float, uint8_t>(
+    float* output, const uint8_t* quant_data, const float* scales_data,
+    const uint8_t* zero_points, const int32_t* reorder_idx, int32_t block_size,
+    bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);
+
+template void DequantizeBlockwise<float, float>(
+    float* output, const uint8_t* quant_data, const float* scales_data,
+    const float* zero_points, const int32_t* reorder_idx, int32_t block_size,
+    bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);
+
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h
new file mode 100644
index 0000000000000..5061ac5c800a6
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h
@@ -0,0 +1,23 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#include "core/providers/common.h"
+#include "core/platform/threadpool.h"
+
+namespace onnxruntime {
+namespace contrib {
+
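+// Reference blockwise dequantization of 4-bit quantized weights with optional zero points
+// (packed uint8 or the same type as the scales) and an optional group reorder index.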
+template <typename inputT, typename zeroT>
+void DequantizeBlockwise(
+    inputT* output,              // dequantized output
+    const uint8_t* quant_data,   // quantized input
+    const inputT* scales_data,   // quantization scales
+    const zeroT* zero_points,    // quantization zero points
+    const int32_t* reorder_idx,  // reorder index for groupwise quantization
+    int32_t block_size,          // quantization block size
+    bool,                        // columnwise quantization or row-wise
+    int32_t K,                   // number of rows in quantized input
+    int32_t N,                   // number of columns in quantized input
+    onnxruntime::concurrency::ThreadPool* thread_pool);
+
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h
index d3902f9bd68c7..e7df50408ef09 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h
+++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h
@@ -27,6 +27,7 @@
 #pragma warning(disable : 4244)
 #pragma warning(disable : 4267)
 #pragma warning(disable : 4702)
+#pragma warning(disable : 4127)
 #endif
 
 #include "bestla/bestla_prologue_a.h"
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h
index 72e6d3930a548..af0904b7d6e4b 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h
@@ -134,8 +134,8 @@ Status BeamSearchWhisper<T>::Execute(const FeedsFetchesManager& encoder_feeds_fe
     TensorShape no_speech_probs_shape{parameters->batch_size};
     Tensor* no_speech_probs = this->context_.Output(parameters->no_speech_probs_output_id, no_speech_probs_shape);
     if (no_speech_probs && no_speech_probs->MutableData<T>()) {
-      ORT_ENFORCE(parameters->no_speech_token >= 0 && parameters->no_speech_token < parameters->vocab_size,
-                  "no_speech_token id out of range, it is ", parameters->no_speech_token,
+      ORT_ENFORCE(parameters->no_speech_token_id >= 0 && parameters->no_speech_token_id < parameters->vocab_size,
+                  "no_speech_token_id is out of range, it is ", parameters->no_speech_token_id,
                   ", vocab_size is ", parameters->vocab_size);
       this->parameters_->no_speech_probs = (void*)no_speech_probs->MutableData<T>();
     }
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
index bb6885c3216bc..93837e785b4a4 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
@@ -153,7 +153,13 @@ void WhisperBeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info)
   model_type = static_cast<int>(info.GetAttrOrDefault<int64_t>("model_type", IGenerationParameters::kModelTypeWhisper));
   ORT_ENFORCE(model_type == IGenerationParameters::kModelTypeWhisper);
 
-  no_speech_token = static_cast<int>(info.GetAttrOrDefault<int64_t>("no_speech_token", -1LL));
+  // Token ids are defined below in the order that they appear in the tokenizer
+  translate_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("translate_token_id", -1LL));
+  transcribe_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("transcribe_token_id", -1LL));
+  start_of_lm_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("start_of_lm_token_id", -1LL));
+  no_speech_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("no_speech_token_id", -1LL));
+  no_timestamps_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("no_timestamps_token_id", -1LL));
+  beginning_timestamp_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("beginning_timestamp_token_id", -1LL));
   cross_qk_layer_head_input_id = 12;
   extra_decoding_ids_input_id = 13;
   cross_qk_output_id = 3;
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
index cb62e2f7bf4da..b1dd55eb20f34 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
@@ -183,7 +183,14 @@ struct IGenerationParameters {
   // Parameters for whisper model
   bool decoder_output_cross_qk = false;
   gsl::span<const int32_t> extra_decoding_ids;
-  int32_t no_speech_token = -1;
+
+  // Token ids are defined below in the order that they appear in the tokenizer
+  int32_t translate_token_id = -1;
+  int32_t transcribe_token_id = -1;
+  int32_t start_of_lm_token_id = -1;
+  int32_t no_speech_token_id = -1;
+  int32_t no_timestamps_token_id = -1;
+  int32_t beginning_timestamp_token_id = -1;
   void* no_speech_probs = nullptr;
 
   int cross_qk_layer_head_input_id = -1;
diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
index 03d4e89ac20fe..231eb17d1a947 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
@@ -10,6 +10,7 @@
 #include "contrib_ops/cpu/transformers/greedy_search_parameters.h"
 #include "contrib_ops/cpu/transformers/sampling_parameters.h"
 #include "contrib_ops/cpu/transformers/generation_shared.h"
+#include <iostream>
 
 namespace onnxruntime {
 namespace contrib {
@@ -34,6 +35,14 @@ struct NextTokenScores {
   }
 };
 
+#ifdef DEBUG_GENERATION
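+// Debug helper for logits processors; currently it only prints the processor name.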
+template <typename T>
+void DumpScores(const char* name, const NextTokenScores<T>& next_token_scores) {
+  std::cout << name << std::endl;
+  ORT_UNUSED_PARAMETER(next_token_scores);
+}
+#endif
+
 // Interface for all scorers for beam search or beam sample.
 template <typename T>
 class ILogitsProcessor {
@@ -150,19 +159,25 @@ class PresencePenaltyLogitsProcessor : public ILogitsProcessor<T> {
 template <typename T>
 class TimestampLogitsProcessor : public ILogitsProcessor<T> {
  public:
-  TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index)
-      : eos_token_id_(eos_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {}
+  TimestampLogitsProcessor(int end_of_text_token_id,          // <|endoftext|>
+                           int start_of_transcript_token_id,  // <|startoftranscript|>
+                           int translate_token_id,            // <|translate|>
+                           int transcribe_token_id,           // <|transcribe|>
+                           int start_of_lm_token_id,          // <|startoflm|>
+                           int no_timestamps_token_id,        // <|notimestamps|>
+                           int beginning_timestamp_token_id,  // <|0.00|>
+                           int max_initial_timestamp_index)
+      : end_of_text_token_id_(end_of_text_token_id),
+        start_of_transcript_token_id_(start_of_transcript_token_id),
+        translate_token_id_(translate_token_id),
+        transcribe_token_id_(transcribe_token_id),
+        start_of_lm_token_id_(start_of_lm_token_id),
+        no_timestamps_token_id_(no_timestamps_token_id),
+        beginning_timestamp_token_id_(beginning_timestamp_token_id),
+        max_initial_timestamp_index_(max_initial_timestamp_index) {}
 
   void Process(const ISequences* sequences,
                NextTokenScores<T>& next_token_scores) override {
-    // TODO: translate_token_id_ and transcribe_token_id_ need to support both multilingual and English-only models.
-    const int beg_token_id_ = eos_token_id_ + 107;
-    const int not_token_id_ = eos_token_id_ + 106;
-    const int solm_token_id_ = eos_token_id_ + 105;
-    const int sot_token_id_ = eos_token_id_ + 1;
-    constexpr int translate_token_id_ = 50358;
-    constexpr int transcribe_token_id_ = 50359;
-
     const int batch_beam_size = next_token_scores.batch_beam_size;
     const int vocab_size = next_token_scores.vocab_size;
     for (int i = 0; i < batch_beam_size; i++) {
@@ -174,7 +189,7 @@ class TimestampLogitsProcessor : public ILogitsProcessor<T> {
       size_t sample_begin = 0;
       for (size_t j = 0; j < seq_length; j++) {
         sample_begin++;
-        if (sequence[j] >= beg_token_id_) {
+        if (sequence[j] >= beginning_timestamp_token_id_) {
           break;
         }
       }
@@ -182,30 +197,30 @@ class TimestampLogitsProcessor : public ILogitsProcessor<T> {
       // Suppress tokens
       for (int j = 0; j < vocab_size; j++) {
         // Suppress notimestamps and solm tokens
-        if (j == not_token_id_ || j == solm_token_id_) {
+        if (j == no_timestamps_token_id_ || j == start_of_lm_token_id_) {
           beam_token_scores[j] = std::numeric_limits<T>::lowest();
         }
 
         // Suppress sot, translate and transcribe tokens
         if (seq_length > sample_begin) {
-          if (j == sot_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) {
+          if (j == start_of_transcript_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) {
             beam_token_scores[j] = std::numeric_limits<T>::lowest();
           }
         }
       }
 
       // Timestamps should be in pair except the first one
-      const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beg_token_id_;
-      const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beg_token_id_;
+      const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beginning_timestamp_token_id_;
+      const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beginning_timestamp_token_id_;
       if (last_was_timestamp) {
         if (penultimate_was_timestamp) {
           // If timestamps show up in pair, or it's the first timestamp, no more timestamp is generated
-          for (int j = beg_token_id_; j < vocab_size; j++) {
+          for (int j = beginning_timestamp_token_id_; j < vocab_size; j++) {
             beam_token_scores[j] = std::numeric_limits<T>::lowest();
           }
         } else {
           // If timestamp doesn't show up in pair, generate timestamp
-          for (int j = 0; j < eos_token_id_; j++) {
+          for (int j = 0; j < end_of_text_token_id_; j++) {
             beam_token_scores[j] = std::numeric_limits<T>::lowest();
           }
         }
@@ -214,7 +229,7 @@ class TimestampLogitsProcessor : public ILogitsProcessor<T> {
       // Find timestamp tokens
       std::vector<int32_t> timestamps;
       for (const auto& word_id : sequence) {
-        if (word_id >= beg_token_id_) {
+        if (word_id >= beginning_timestamp_token_id_) {
           timestamps.push_back(word_id);
         }
       }
@@ -231,13 +246,13 @@ class TimestampLogitsProcessor : public ILogitsProcessor<T> {
           timestamp_last = timestamps.back() + 1;
         }
 
-        for (int j = beg_token_id_; j < timestamp_last; j++) {
+        for (int j = beginning_timestamp_token_id_; j < timestamp_last; j++) {
           beam_token_scores[j] = std::numeric_limits<T>::lowest();
         }
       }
 
       if (seq_length == sample_begin) {
-        const int last_allowed = beg_token_id_ + max_initial_timestamp_index_;
+        const int last_allowed = beginning_timestamp_token_id_ + max_initial_timestamp_index_;
         for (int j = last_allowed + 1; j < vocab_size; j++) {
           beam_token_scores[j] = std::numeric_limits<T>::lowest();
         }
@@ -247,8 +262,8 @@ class TimestampLogitsProcessor : public ILogitsProcessor<T> {
       float timestamp_logprob = std::numeric_limits<T>::lowest();
       {
         float logsumexp = 0.0f;
-        const float logprob_max = *std::max_element(beam_token_scores.begin() + beg_token_id_, beam_token_scores.end());
-        for (int j = beg_token_id_; j < vocab_size; ++j) {
+        const float logprob_max = *std::max_element(beam_token_scores.begin() + beginning_timestamp_token_id_, beam_token_scores.end());
+        for (int j = beginning_timestamp_token_id_; j < vocab_size; ++j) {
           if (beam_token_scores[j] > std::numeric_limits<T>::lowest()) {
             logsumexp += expf(beam_token_scores[j] - logprob_max);
           }
@@ -258,9 +273,9 @@ class TimestampLogitsProcessor : public ILogitsProcessor<T> {
         }
       }
 
-      const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beg_token_id_);
+      const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beginning_timestamp_token_id_);
       if (timestamp_logprob > max_text_token_logprob) {
-        for (int j = 0; j < beg_token_id_; ++j) {
+        for (int j = 0; j < beginning_timestamp_token_id_; ++j) {
           beam_token_scores[j] = std::numeric_limits<T>::lowest();
         }
       }
@@ -268,7 +283,13 @@ class TimestampLogitsProcessor : public ILogitsProcessor<T> {
   }
 
  private:
-  int eos_token_id_;
+  int end_of_text_token_id_;
+  int start_of_transcript_token_id_;
+  int translate_token_id_;
+  int transcribe_token_id_;
+  int start_of_lm_token_id_;
+  int no_timestamps_token_id_;
+  int beginning_timestamp_token_id_;
   int max_initial_timestamp_index_;
 };
 
@@ -330,7 +351,15 @@ class LogitsProcessorList : public ILogitsProcessorList {
     // Add timestamp processor for whisper model
     if (parameters.model_type == IGenerationParameters::kModelTypeWhisper && parameters.logits_processor == IGenerationParameters::kLogitsProcessorTypeWhisper) {
       constexpr int max_initial_timestamp_index = 50;
-      timestamp_processor_ = std::make_unique<TimestampLogitsProcessor<float>>(parameters.eos_token_id, max_initial_timestamp_index);
+      // Token ids are passed below in the order that they appear in the tokenizer
+      timestamp_processor_ = std::make_unique<TimestampLogitsProcessor<float>>(parameters.eos_token_id,
+                                                                               parameters.decoder_start_token_id,
+                                                                               parameters.translate_token_id,
+                                                                               parameters.transcribe_token_id,
+                                                                               parameters.start_of_lm_token_id,
+                                                                               parameters.no_timestamps_token_id,
+                                                                               parameters.beginning_timestamp_token_id,
+                                                                               max_initial_timestamp_index);
       processor_list_.push_back(timestamp_processor_.get());
     }
 
diff --git a/onnxruntime/contrib_ops/cuda/activation/activations.cc b/onnxruntime/contrib_ops/cuda/activation/activations.cc
index 1a86c5dbece5a..6303858b9bd48 100644
--- a/onnxruntime/contrib_ops/cuda/activation/activations.cc
+++ b/onnxruntime/contrib_ops/cuda/activation/activations.cc
@@ -49,7 +49,6 @@ namespace cuda {
 UNARY_ACTIVATION_OP_HFD(Affine, 1, kOnnxDomain);
 UNARY_ACTIVATION_OP_HFD(ParametricSoftplus, 1, kOnnxDomain);
 UNARY_ACTIVATION_OP_HFD(ScaledTanh, 1, kOnnxDomain);
-UNARY_ACTIVATION_OP_HFD(Gelu, 1, kMSDomain);
 UNARY_ACTIVATION_OP_HFD(QuickGelu, 1, kMSDomain);
 
 REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, MLFloat16)
diff --git a/onnxruntime/contrib_ops/cuda/activation/activations.h b/onnxruntime/contrib_ops/cuda/activation/activations.h
index ab339f276c2bd..fc9a71b0b7fa1 100644
--- a/onnxruntime/contrib_ops/cuda/activation/activations.h
+++ b/onnxruntime/contrib_ops/cuda/activation/activations.h
@@ -66,17 +66,6 @@ class ScaledTanh final : public UnaryElementwise {
   float beta_;
 };
 
-template <typename T>
-class Gelu final : public UnaryElementwise {
- public:
-  Gelu(const OpKernelInfo& info) : UnaryElementwise(info) {}
-
-  Status ComputeInternal(OpKernelContext* context) const override;
-
- private:
-  MAKE_FUNC_CTX_NULL()
-};
-
 template <typename T>
 class QuickGelu final : public UnaryElementwise {
  public:
diff --git a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu
index 0c856815fd437..36f33fbb24c18 100644
--- a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu
@@ -36,20 +36,6 @@ struct OP_ScaledTanh : public CtxScaledTanh {
   }
 };
 
-template <typename T>
-struct OP_Gelu : public CtxGelu {
-  __device__ __inline__ T operator()(const T& a) const {
-    return _Gelu(a);
-  }
-};
-
-template <>
-struct OP_Gelu<half> : public CtxGelu {
-  __device__ __inline__ half operator()(const half& a) const {
-    return static_cast<half>(_Gelu(static_cast<float>(a)));
-  }
-};
-
 template <typename T>
 struct OP_QuickGelu : public CtxQuickGelu {
   __device__ __inline__ T operator()(const T& a) const {
diff --git a/onnxruntime/contrib_ops/cuda/activation/activations_impl.h b/onnxruntime/contrib_ops/cuda/activation/activations_impl.h
index 5d18283a395e3..782d4bf59a5ad 100644
--- a/onnxruntime/contrib_ops/cuda/activation/activations_impl.h
+++ b/onnxruntime/contrib_ops/cuda/activation/activations_impl.h
@@ -11,14 +11,12 @@ namespace cuda {
 typedef onnxruntime::cuda::CtxAlphaBeta CtxAffine;
 typedef onnxruntime::cuda::CtxAlphaBeta CtxParametricSoftplus;
 typedef onnxruntime::cuda::CtxAlphaBeta CtxScaledTanh;
-typedef onnxruntime::cuda::CtxNull CtxGelu;
 typedef onnxruntime::cuda::CtxAlpha CtxQuickGelu;
 
 #define UNARY_CONTRIB_ACTIVATION_OPS()         \
   UNARY_ACTIVATION_OP_NAME(ScaledTanh)         \
   UNARY_ACTIVATION_OP_NAME(Affine)             \
   UNARY_ACTIVATION_OP_NAME(ParametricSoftplus) \
-  UNARY_ACTIVATION_OP_NAME(Gelu)               \
   UNARY_ACTIVATION_OP_NAME(QuickGelu)
 
 #define UNARY_ACTIVATION_OP_NAME(name) UNARY_ACTIVATION_IMPL_DECLARATION(name);
diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
index 1ea2540db486f..9e6752b451868 100644
--- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
@@ -843,11 +843,11 @@ void InvokeAddBiasTransposeTrt(
 
 template <>
 void LaunchAddBiasTransposeTrt(
-    cudaStream_t stream, const int max_threads_per_block,
-    const int batch_size, const int sequence_length,
-    const int num_heads, const int head_size,
-    const float* biases, const float* query, const float* key, const float* value, float* output,
-    bool is_cross_attention, int kv_sequence_length) {
+    cudaStream_t /*stream*/, const int /*max_threads_per_block*/,
+    const int /*batch_size*/, const int /*sequence_length*/,
+    const int /*num_heads*/, const int /*head_size*/,
+    const float* /*biases*/, const float* /*query*/, const float* /*key*/, const float* /*value*/, float* /*output*/,
+    bool /*is_cross_attention*/, int /*kv_sequence_length*/) {
   ORT_ENFORCE(false, "Shall not call this since fused kernel does not support float input.");
 }
 
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index c20f42c4d06bc..a93fdf74dc28c 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -58,12 +58,12 @@ size_t AlignSize(size_t bytes) {
   return bytesAligned;
 }
 
-void CumulatedSequenceLengthCache::Initialize(int32_t sequence_length, cudaStream_t stream) {
-  if (this->sequence_length != sequence_length) {
+void CumulatedSequenceLengthCache::Initialize(int32_t seq_length, cudaStream_t stream) {
+  if (this->sequence_length != seq_length) {
     ORT_ENFORCE(buffer.get() != nullptr && this->max_batch_size > 0);
     LaunchTrtSequenceOffset(reinterpret_cast<int32_t*>(buffer.get()), nullptr,
-                            this->max_batch_size, sequence_length, stream);
-    this->sequence_length = sequence_length;
+                            this->max_batch_size, seq_length, stream);
+    this->sequence_length = seq_length;
   }
 }
 
@@ -213,9 +213,9 @@ Status FusedTrtCrossAttention(
 
 template <>
 Status FusedTrtCrossAttention<float>(
-    cudaStream_t stream,
-    contrib::AttentionParameters& parameters,
-    AttentionData<float>& data) {
+    cudaStream_t /*stream*/,
+    contrib::AttentionParameters& /*parameters*/,
+    AttentionData<float>& /*data*/) {
   return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
                          "Trt fused cross attention does not support float tensor");
 }
@@ -276,9 +276,9 @@ Status FusedTrtSelfAttention(
 // Template Specialization for float type
 template <>
 Status FusedTrtSelfAttention<float>(
-    cudaStream_t stream,
-    contrib::AttentionParameters& parameters,
-    AttentionData<float>& data) {
+    cudaStream_t /*stream*/,
+    contrib::AttentionParameters& /*parameters*/,
+    AttentionData<float>& /*data*/) {
   return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
                          "Trt fused attention does not support float tensor");
 }
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
index a513d9e8d2211..b843966d88e85 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
@@ -231,7 +231,7 @@ Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters,
                                 AttentionData<T>& data,
                                 cudaStream_t stream,
                                 int max_threads_per_block,
-                                T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
+                                T* /*q*/, T* /*k*/, T* /*v*/, AttentionQkvFormat& qkv_format) {
   const int batch_size = parameters.batch_size;
   const int sequence_length = parameters.sequence_length;
   const int num_heads = parameters.num_heads;
@@ -279,7 +279,7 @@ Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters,
                                AttentionData<T>& data,
                                cudaStream_t stream,
                                int max_threads_per_block,
-                               T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
+                               T* /*q*/, T* k, T* /*v*/, AttentionQkvFormat& qkv_format) {
   const int batch_size = parameters.batch_size;
   const int kv_sequence_length = parameters.kv_sequence_length;
   const int num_heads = parameters.num_heads;
diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
index db78722cc0e4c..c12cb374d9adf 100644
--- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
+++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
@@ -242,18 +242,18 @@ void DispatchIsAligned(const MemoryEfficientAttentionParams& params) {
   using AlignedAK = AttentionKernel<T, ArchTag, true, queries_per_block, keys_per_block, single_value_iteration>;
 #if defined(_MSC_VER) && !defined(__clang__)
 #pragma warning(push)
-#pragma warning(disable : 6287)
+#pragma warning(disable : 6287 4189)  // kAligned is used via capture so 4189 warning seems incorrect
 #endif
   // Run a more efficient kernel with `isAligned=True` when memory is correctly aligned.
   bool is_aligned = params.qk_head_size % AlignedAK::kAlignmentQ == 0 &&
                     params.qk_head_size % AlignedAK::kAlignmentK == 0 &&
                     params.v_head_size % AlignedAK::kAlignmentV == 0;
-#if defined(_MSC_VER) && !defined(__clang__)
-#pragma warning(pop)
-#endif
   DISPATCH_BOOL(is_aligned, kIsAligned, ([&]() {
                   LaunchCutlassFmha<T, ArchTag, kIsAligned, queries_per_block, keys_per_block, single_value_iteration>(params);
                 }));
+#if defined(_MSC_VER) && !defined(__clang__)
+#pragma warning(pop)
+#endif
 }
 
 template <typename T, typename ArchTag>
diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu
index e24d9da94c964..c0b1996789183 100644
--- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu
@@ -17,7 +17,7 @@ Status DecoderQkvToContext(
     const cudaDeviceProp& device_prop,
     Stream* ort_stream,
     cublasHandle_t& cublas,
-    const size_t element_size,
+    const size_t /*element_size*/,
     const int batch_size,
     const int sequence_length,
     const int kv_sequence_length,
diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
index 892f5c181a607..8b8e4e267f895 100644
--- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
@@ -4,9 +4,13 @@
 #include "core/providers/cuda/cuda_common.h"
 #include "core/providers/cuda/cudnn_common.h"
 #include "fast_gelu.h"
-#include "fast_gelu_impl.h"
+#include "core/providers/cuda/tensor/gelu_impl.h"
 #include "contrib_ops/cpu/bert/bias_gelu_helper.h"
-#include "transformer_common.h"
+#ifdef USE_ROCM
+#include "contrib_ops/rocm/bert/elementwise.h"
+#else
+#include "contrib_ops/cuda/bert/transformer_common.h"
+#endif
 
 namespace onnxruntime {
 namespace contrib {
@@ -31,8 +35,10 @@ using namespace ONNX_NAMESPACE;
 
 template <typename T>
 FastGelu<T>::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) {
+#ifndef USE_ROCM
   const TransformerOptions* options = TransformerOptions::GetInstance();
   use_half2_ = !options->DisableHalf2();
+#endif
 }
 
 template <typename T>
@@ -50,6 +56,13 @@ Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const {
   int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size();
   typedef typename ToCudaType<T>::MappedType CudaT;
 
+#ifdef USE_ROCM
+  return LaunchElementwiseKernel<functor::FastGeLU, CudaT>(
+      GetTuningContext(), context->GetComputeStream(),
+      reinterpret_cast<const CudaT*>(input->Data<T>()), static_cast<int>(input_length),
+      (nullptr != bias) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr, static_cast<int>(bias_length),
+      reinterpret_cast<CudaT*>(output->MutableData<T>()));
+#else
   return LaunchFastGeluKernel<CudaT>(GetDeviceProp(),
                                      Stream(context),
                                      static_cast<int>(input_length),
@@ -58,6 +71,7 @@ Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const {
                                      (nullptr != bias) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr,
                                      reinterpret_cast<CudaT*>(output->MutableData<T>()),
                                      use_half2_);
+#endif
 }
 
 }  // namespace cuda
diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h
index 3e642a70afef5..26f3bd5a03928 100644
--- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h
+++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h
@@ -18,7 +18,9 @@ class FastGelu final : public CudaKernel {
   Status ComputeInternal(OpKernelContext* ctx) const override;
 
  private:
+#ifndef USE_ROCM
   bool use_half2_;
+#endif
 };
 
 }  // namespace cuda
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
index 2c296bf4f8483..0f58a74c4d2fd 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
@@ -371,6 +371,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
                        int seqlen_q,
                        int seqlen_k,
                        int seqlen_k_new,
+                       int rotary_dim,
                        const float softmax_scale,
                        bool is_causal,
                        bool is_bf16,
@@ -448,7 +449,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
     params.rotary_cos_ptr = rotary_cos;
     params.rotary_sin_ptr = rotary_sin;
     params.is_rotary_interleaved = is_rotary_interleaved;
-    params.rotary_dim = (head_size / 16) * 16;
+    params.rotary_dim = rotary_dim;
   }
 
   params.num_splits = num_splits;
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
index 387d1cf9d84fe..24891bcc4d499 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
@@ -96,6 +96,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
                        int seqlen_q,
                        int seqlen_k,
                        int seqlen_k_new,
+                       int rotary_dim,
                        const float softmax_scale,
                        bool is_causal,
                        bool is_bf16,
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
index fe56f84f0a886..814aa1fb3c8f0 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -110,6 +110,11 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
   parameters.do_rotary = do_rotary_;
   parameters.rotary_interleaved = rotary_interleaved_;
 
+  if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1");
+  }
+
   TensorShapeVector output_shape(3);
   output_shape[0] = static_cast<int64_t>(parameters.batch_size);
   output_shape[1] = static_cast<int64_t>(sequence_length);
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
index 853e1a710cb24..1a7c3fcea3fa3 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
@@ -205,6 +205,7 @@ Status CheckInputs(const Tensor* query,
   int total_sequence_length = *((*total_seqlen).template Data<int32_t>());
   int present_sequence_length = std::max(total_sequence_length, past_sequence_length);
 
+  int rotary_dim = 0;
   if (cos_cache != nullptr && sin_cache != nullptr) {
     const auto& cos_dims = cos_cache->Shape().GetDims();
     const auto& sin_dims = sin_cache->Shape().GetDims();
@@ -214,22 +215,27 @@ Status CheckInputs(const Tensor* query,
                              "head_size shall be a multiple of 16. Got head_size % 16 == ",
                              head_size % 16);
     }
-    if (cos_dims[0] != present_sequence_length) {
+    if (cos_dims[0] < present_sequence_length) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                             "cos_cache dimension 0 must be of present_sequence_length.");
+                             "cos_cache dimension 0 should be of max_sequence_length.");
     }
-    if (sin_dims[0] != present_sequence_length) {
+    if (sin_dims[0] < present_sequence_length) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                             "sin_cache dimension 0 must be of present_sequence_length.");
+                             "sin_cache dimension 0 should be of max_sequence_length.");
     }
-    if (cos_dims[1] != (head_size / 16) * 8) {
+    if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                              "cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
     }
-    if (sin_dims[1] != (head_size / 16) * 8) {
+    if (sin_dims[1] > (head_size / 16) * 8 || sin_dims[1] % 8 != 0) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                              "sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
     }
+    if (cos_dims[1] != sin_dims[1]) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "cos_cache and sin_cache dimension 1 must be the same.");
+    }
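+    // cos_cache/sin_cache each hold rotary_dim / 2 values per position.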
+    rotary_dim = static_cast<int>(cos_dims[1] * 2);
   } else if (cos_cache != nullptr || sin_cache != nullptr) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                            "Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
@@ -248,6 +254,7 @@ Status CheckInputs(const Tensor* query,
     output_parameters->head_size = head_size;
     output_parameters->kv_hidden_size = kv_hidden_size;
     output_parameters->kv_num_heads = kv_num_heads;
+    output_parameters->rotary_dim = rotary_dim;
     output_parameters->is_packed_qkv = is_packed_qkv;
     output_parameters->is_unidirectional = true;
     output_parameters->is_prompt = is_prompt;
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
index d88e9a49fb5ee..afba83be34e2d 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -451,7 +451,7 @@ __global__ void PastToTotalSeqlen(int32_t* seqlens_k,
 // Convert Past to Total sequence length tensor
 Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k,
                            int32_t* seqlens_k_buff, bool is_total, cudaStream_t stream,
-                           const int threads_per_block) {
+                           const int /*threads_per_block*/) {
   if (parameters.is_prompt) {
     return Status::OK();
   }
@@ -530,7 +530,7 @@ Status FlashAttention(
       device_prop, stream, query, present_key, present_value, key, value, data.output,
       reinterpret_cast<void*>(data.softmax_lse), seqlens_k, cos_cache, sin_cache,
       batch_size, num_heads, kv_num_heads, head_size, sequence_length,
-      parameters.seqlen_present_kv_cache, kv_sequence_length,
+      parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim,
       scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
       reinterpret_cast<void*>(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved,
       parameters.is_packed_qkv));
@@ -655,7 +655,7 @@ Status EfficientAttention(
 template <typename T>
 Status QkvToContext(
     const cudaDeviceProp& device_prop,
-    cublasHandle_t& cublas,
+    cublasHandle_t& /*cublas*/,
     Stream* ort_stream,
     contrib::GroupQueryAttentionParameters& parameters,
     GroupQueryAttentionData<T>& data) {
diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu
index ce7ac3796dbe1..a84a310b46ca0 100644
--- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu
@@ -440,7 +440,7 @@ Status LaunchTransposeRemovePadding(
 
 template <typename T>
 Status FusedScaledDotProductAttention(
-    const cudaDeviceProp& device_prop,
+    const cudaDeviceProp& /*device_prop*/,
     cudaStream_t stream,
     PackedAttentionParameters& parameters,
     PackedAttentionData<T>& data) {
diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu
index 49029da12a308..982c7eaa2cb2c 100644
--- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu
@@ -381,7 +381,7 @@ void InvokeTranspose(
     const T* query, const T* key, const T* value, const T* bias, T* output,
     const int batch_size, const int sequence_length,
     const int num_heads, const int qk_head_size, const int v_head_size,
-    AttentionQkvFormat source_format, AttentionQkvFormat target_format,
+    [[maybe_unused]] AttentionQkvFormat source_format, AttentionQkvFormat target_format,
     const int32_t* token_offset, int32_t token_count,
     cudaStream_t stream) {
   if (key != nullptr && value != nullptr) {
@@ -551,7 +551,7 @@ void LaunchTranspose(
 
 template <typename T>
 Status FusedAttentionTrt(
-    const cudaDeviceProp& device_prop,
+    const cudaDeviceProp& /*device_prop*/,
     cudaStream_t stream,
     PackedAttentionParameters& parameters,
     PackedMultiHeadAttentionData<T>& data) {
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
index 9de7ba3885c3c..ab7479f2938fe 100644
--- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
@@ -82,8 +82,6 @@ Status RotaryEmbedding<T>::ComputeInternal(OpKernelContext* context) const {
       interleaved,
       device_prop.maxThreadsPerBlock,
       parameters.transposed);
-
-  return Status::OK();
 }
 
 }  // namespace cuda
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
index c6637041f05bd..3a14161f29e9f 100644
--- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
@@ -93,7 +93,7 @@ Status LaunchRotaryEmbeddingKernel(
     const int num_heads,
     const int head_size,
     const int rotary_embedding_dim,
-    const int max_sequence_length,
+    const int /*max_sequence_length*/,
     const int position_ids_format,
     const bool interleaved,
     const int max_threads_per_block,
diff --git a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu
index 8fb6575d27cc0..4a4e3eeecf642 100644
--- a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu
@@ -53,9 +53,9 @@ class FusedMHARunnerFP16v2::mhaImpl {
 
   ~mhaImpl() {}
 
-  void setup(const int S, const int B) {
+  void setup(const int seq_len, const int B) {
     // For bert and vit, use flash attention when sequence length is larger than the threshold.
-    use_flash_attention = is_flash_attention(S);
+    use_flash_attention = is_flash_attention(seq_len);
 
     params.force_unroll = use_flash_attention;
 
@@ -68,26 +68,26 @@ class FusedMHARunnerFP16v2::mhaImpl {
       warps_n = 1;
     } else {
       if (sm == 70) {
-        if (S == 64 || S == 96) {
+        if (seq_len == 64 || seq_len == 96) {
           warps_m = 2;
           warps_n = 2;
-        } else if (S == 128) {
+        } else if (seq_len == 128) {
           warps_m = 1;
           warps_n = 4;
-        } else if (S == 256 || S == 384) {
+        } else if (seq_len == 256 || seq_len == 384) {
           warps_m = 1;
           warps_n = 8;
         } else {
           ORT_ENFORCE(false, "Unsupported sequence length");
         }
       } else {
-        if (S == 32 || S == 64 || S == 96 || S == 128) {
+        if (seq_len == 32 || seq_len == 64 || seq_len == 96 || seq_len == 128) {
           warps_m = 2;
           warps_n = 2;
-        } else if (S == 192 || S == 256) {
+        } else if (seq_len == 192 || seq_len == 256) {
           warps_m = 1;
           warps_n = 4;
-        } else if (S == 384) {
+        } else if (seq_len == 384) {
           warps_m = 1;
           warps_n = 8;
         } else {
@@ -99,7 +99,7 @@ class FusedMHARunnerFP16v2::mhaImpl {
     // The number of threads per CTA.
     threads_per_cta = warps_m * warps_n * warps_k * 32;
     // The number of xmmas in the M dimension. We use one uint32_t per XMMA in the M dimension.
-    xmmas_m = (S + 16 * warps_m - 1) / (16 * warps_m);
+    xmmas_m = (seq_len + 16 * warps_m - 1) / (16 * warps_m);
 
     const float scale_bmm1 = interface->mScale;
     const float scale_softmax = 1.f;  // Seems to be only required for int8
@@ -111,7 +111,7 @@ class FusedMHARunnerFP16v2::mhaImpl {
 
     params.b = B;
     params.h = interface->mNumHeads;
-    params.s = S;
+    params.s = seq_len;
     params.d = interface->mHeadSize;
 
     params.qkv_stride_in_bytes = 3 * interface->mNumHeads * interface->mHeadSize * sizeof(half);
@@ -121,7 +121,7 @@ class FusedMHARunnerFP16v2::mhaImpl {
     has_causal_mask = false;
   }
 
-  void setup_causal_masked_fmha(const int S, const int B) {
+  void setup_causal_masked_fmha(const int seq_len, const int B) {
     const float scale_bmm1 = interface->mScale;
     const float scale_softmax = 1.f;  // Seems to be only required for int8
     const float scale_bmm2 = 1.f;
@@ -132,7 +132,7 @@ class FusedMHARunnerFP16v2::mhaImpl {
 
     params.b = B;
     params.h = interface->mNumHeads;
-    params.s = S;
+    params.s = seq_len;
     params.d = interface->mHeadSize;
 
     params.qkv_stride_in_bytes = 3 * interface->mNumHeads * interface->mHeadSize * sizeof(half);
@@ -182,30 +182,30 @@ class FusedMHARunnerFP16v2::mhaImpl {
       return max_seq_len;
     }
 
-    int S = max_seq_len;
+    int seq_len = max_seq_len;
     if (max_seq_len <= 32) {
-      S = (sm == 70) ? 64 : 32;
+      seq_len = (sm == 70) ? 64 : 32;
     } else if (max_seq_len <= 64) {
-      S = 64;
+      seq_len = 64;
     } else if (max_seq_len <= 96) {
-      S = 96;
+      seq_len = 96;
     } else if (max_seq_len <= 128) {
-      S = 128;
+      seq_len = 128;
     } else if (max_seq_len <= 192) {
-      S = (sm == 70) ? 256 : 192;
+      seq_len = (sm == 70) ? 256 : 192;
     } else if (max_seq_len <= 256) {
-      S = 256;
+      seq_len = 256;
     } else if (max_seq_len <= 384) {
-      S = 384;
+      seq_len = 384;
     }
 
-    return S;
+    return seq_len;
   }
 
  protected:
-  bool is_flash_attention(const int S) const {
+  bool is_flash_attention(const int seq_len) const {
     ORT_ENFORCE(interface->mHasCausalMask == false);
-    return interface->mEnableFlashAttention && S >= kMinSequenceLengthFlashAttention;
+    return interface->mEnableFlashAttention && seq_len >= kMinSequenceLengthFlashAttention;
   }
 
  private:
@@ -232,12 +232,12 @@ FusedMHARunnerFP16v2::FusedMHARunnerFP16v2(const int numHeads,
       pimpl(new mhaImpl(this)) {
 }
 
-void FusedMHARunnerFP16v2::setup(const int S, const int B) {
-  MHARunner::setup(S, B);
+void FusedMHARunnerFP16v2::setup(const int seq_len, const int B) {
+  MHARunner::setup(seq_len, B);
   if (mHasCausalMask) {
-    pimpl->setup_causal_masked_fmha(S, B);
+    pimpl->setup_causal_masked_fmha(seq_len, B);
   } else {
-    pimpl->setup(S, B);
+    pimpl->setup(seq_len, B);
   }
 }
 
diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
index 40a667ffd5d83..2efc37cf98010 100644
--- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
+++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
@@ -1,6 +1,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#include <utility>
+
 #include "core/common/safeint.h"
 #include "core/providers/cuda/cuda_common.h"
 #include "contrib_ops/cuda/bert/transformer_cuda_common.h"
@@ -35,6 +37,7 @@ using namespace ONNX_NAMESPACE;
 
 template <typename T>
 ShardedMoE<T>::ShardedMoE(const OpKernelInfo& op_kernel_info) : NcclKernel(op_kernel_info), MoEBase(op_kernel_info) {
+  ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("tensor_shards", &tensor_shards_).IsOK());
   ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("local_experts_start_index", &local_experts_start_index_).IsOK());
   rank_to_experts_start_index_.resize(nccl_->Size());
   // Initialize rank_to_experts_start_index_[0] to a value to convey that it is not initialized.
@@ -55,27 +58,36 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
   // Create a {Rank, ExpertsStartIndex} map on Host.
   AutoDestoryCudaEvent cuda_event;
   cudaEvent_t& copy_event = cuda_event.Get();
-  ORT_RETURN_IF_ERROR(SynchronizeExpertsStartIndex(allocator, context, copy_event));
 
   const Tensor* input = context->Input<Tensor>(0);
   const Tensor* router_probs = context->Input<Tensor>(1);
   const Tensor* fc1_experts_weights = context->Input<Tensor>(2);
-  const Tensor* fc2_experts_weights = context->Input<Tensor>(3);
-  const Tensor* fc1_experts_bias_optional = context->Input<Tensor>(4);
+  const Tensor* fc1_experts_bias_optional = context->Input<Tensor>(3);
+  const Tensor* fc2_experts_weights = context->Input<Tensor>(4);
   const Tensor* fc2_experts_bias_optional = context->Input<Tensor>(5);
+  const Tensor* fc3_experts_weights_optional = context->Input<Tensor>(6);
+  const Tensor* fc3_experts_bias_optional = context->Input<Tensor>(7);
+
+  MoEParameters moe_params(tensor_shards_);
+  ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc1_experts_bias_optional,
+                                  fc2_experts_weights, fc2_experts_bias_optional, fc3_experts_weights_optional,
+                                  fc3_experts_bias_optional));
 
-  MoEParameters moe_params;
-  ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc2_experts_weights,
-                                  fc1_experts_bias_optional, fc2_experts_bias_optional));
   ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0,
                     "num_experts should be divisible by world_size");
 
-  ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm);
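+  // Exchanging the experts-start-index across ranks is only needed when expert parallelism is used.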
+  if (moe_params.parallel_type == MoEParallelType::EP || moe_params.parallel_type == MoEParallelType::EPAndTP) {
+    ORT_RETURN_IF_ERROR(SynchronizeExpertsStartIndex(allocator, context, copy_event));
+  }
+
+  ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm,
+                                                                     fc3_experts_weights_optional != nullptr,
+                                                                     normalize_routing_weights_);
 
   size_t ws_size =
-      moe_runner.getWorkspaceSize(static_cast<int>(moe_params.num_rows), static_cast<int>(moe_params.hidden_size),
-                                  static_cast<int>(moe_params.inter_size), static_cast<int>(moe_params.num_experts),
-                                  static_cast<int>(k_));
+      moe_runner.getWorkspaceSize(static_cast<size_t>(moe_params.num_rows), static_cast<size_t>(moe_params.hidden_size),
+                                  static_cast<size_t>(moe_params.inter_size),
+                                  static_cast<size_t>(moe_params.num_experts), static_cast<size_t>(k_));
 
   size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT);
   size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT);
@@ -93,19 +105,25 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
   IAllocatorUniquePtr<void> expert_for_source_row =
       IAllocator::MakeUniquePtr<void>(allocator, expert_for_source_row_size, false, stream);
 
-  // fc1_scales and fc2_scales are used in quantized MoE
-  const CudaT* fc1_scales_ptr = nullptr;
-  const CudaT* fc2_scales_ptr = nullptr;
+  const CudaT* fc_scales_ptr = nullptr;
 
   moe_runner.run_moe_fc(reinterpret_cast<const CudaT*>(input->template Data<T>()),
                         reinterpret_cast<const CudaT*>(router_probs->template Data<T>()),
                         reinterpret_cast<const CudaT*>(fc1_experts_weights->template Data<T>()),
-                        std::move(fc1_scales_ptr),
+                        std::move(fc_scales_ptr),
                         fc1_experts_bias_optional == nullptr
                             ? nullptr
                             : reinterpret_cast<const CudaT*>(fc1_experts_bias_optional->template Data<T>()),
-                        activation_type_, reinterpret_cast<const CudaT*>(fc2_experts_weights->template Data<T>()),
-                        std::move(fc2_scales_ptr), static_cast<int>(moe_params.num_rows),
+                        activation_type_,
+                        fc3_experts_weights_optional == nullptr
+                            ? nullptr
+                            : reinterpret_cast<const CudaT*>(fc3_experts_weights_optional->template Data<T>()),
+                        std::move(fc_scales_ptr),
+                        fc3_experts_bias_optional == nullptr
+                            ? nullptr
+                            : reinterpret_cast<const CudaT*>(fc3_experts_bias_optional->template Data<T>()),
+                        reinterpret_cast<const CudaT*>(fc2_experts_weights->template Data<T>()),
+                        std::move(fc_scales_ptr), static_cast<int>(moe_params.num_rows),
                         static_cast<int>(moe_params.hidden_size),
                         static_cast<int>(moe_params.inter_size), static_cast<int>(moe_params.num_experts),
                         static_cast<int>(moe_params.local_num_experts), static_cast<int>(local_experts_start_index_),
@@ -116,31 +134,54 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
 
   Tensor* output = context->Output(0, input->Shape());
 
-  size_t stride_count = moe_params.hidden_size;
-  size_t stride_bytes = stride_count * sizeof(CudaT);
-  int64_t total_past_rows = 0;
-  int64_t total_covered_rows = 0;
-  if (copy_event != nullptr) {
-    CUDA_RETURN_IF_ERROR(cudaEventSynchronize(copy_event));
+  if (moe_params.parallel_type == MoEParallelType::None) {
+    fc2_output_bc = std::move(fc2_output);
   }
-  NCCL_RETURN_IF_ERROR(ncclGroupStart());
-  for (int rank = 0; rank < nccl_->Size(); ++rank) {
-    int64_t experts_start_index = rank_to_experts_start_index_[rank];
-    moe_runner.get_total_rows_info(experts_start_index,
-                                   moe_params.local_num_experts,
-                                   total_past_rows,
-                                   total_covered_rows);
-    const char* src = reinterpret_cast<const char*>(fc2_output.get()) + total_past_rows * stride_bytes;
-    char* dst = reinterpret_cast<char*>(fc2_output_bc.get()) + total_past_rows * stride_bytes;
-    NCCL_RETURN_IF_ERROR(ncclBroadcast(src,
-                                       dst,
-                                       total_covered_rows * stride_count,
+
+  if (moe_params.parallel_type == MoEParallelType::EPAndTP) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expert and Tensor Parallelism is not supported yet");
+  }
+
+  if (moe_params.parallel_type == MoEParallelType::TP) {
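+    // With tensor parallelism each shard holds a slice of inter_size, so every rank produced a
+    // partial FC2 result for all rows; an all-reduce (sum) combines the partial results.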
+    ORT_ENFORCE(moe_params.tensor_shards == nccl_->Size());
+    NCCL_RETURN_IF_ERROR(ncclGroupStart());
+    NCCL_RETURN_IF_ERROR(ncclAllReduce(reinterpret_cast<const char*>(fc2_output.get()),
+                                       reinterpret_cast<char*>(fc2_output_bc.get()),
+                                       fc2_output_size / sizeof(CudaT),
                                        GetNcclDataType(input->DataType()),
-                                       rank,
+                                       ncclSum,
                                        nccl_->Comm(),
                                        Stream(context)));
+    NCCL_RETURN_IF_ERROR(ncclGroupEnd());
+  }
+
+  if (moe_params.parallel_type == MoEParallelType::EP) {
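+    // With expert parallelism each rank computed FC2 output only for the rows routed to its
+    // local experts; broadcast those row ranges so every rank ends up with the full buffer.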
+    size_t stride_count = moe_params.hidden_size;
+    size_t stride_bytes = stride_count * sizeof(CudaT);
+    int64_t total_past_rows = 0;
+    int64_t total_covered_rows = 0;
+    if (copy_event != nullptr) {
+      CUDA_RETURN_IF_ERROR(cudaEventSynchronize(copy_event));
+    }
+    NCCL_RETURN_IF_ERROR(ncclGroupStart());
+    for (int rank = 0; rank < nccl_->Size(); ++rank) {
+      int64_t experts_start_index = rank_to_experts_start_index_[rank];
+      moe_runner.get_total_rows_info(experts_start_index,
+                                     moe_params.local_num_experts,
+                                     total_past_rows,
+                                     total_covered_rows);
+      const char* src = reinterpret_cast<const char*>(fc2_output.get()) + total_past_rows * stride_bytes;
+      char* dst = reinterpret_cast<char*>(fc2_output_bc.get()) + total_past_rows * stride_bytes;
+      NCCL_RETURN_IF_ERROR(ncclBroadcast(src,
+                                         dst,
+                                         total_covered_rows * stride_count,
+                                         GetNcclDataType(input->DataType()),
+                                         rank,
+                                         nccl_->Comm(),
+                                         Stream(context)));
+    }
+    NCCL_RETURN_IF_ERROR(ncclGroupEnd());
   }
-  NCCL_RETURN_IF_ERROR(ncclGroupEnd());
 
   ort_fastertransformer::finalize_moe_routing_kernelLauncher(
       reinterpret_cast<CudaT*>(fc2_output_bc.get()), reinterpret_cast<CudaT*>(output->template MutableData<T>()),
diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h
index 5ea4ae59c4020..827283a794dd6 100644
--- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h
+++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h
@@ -26,6 +26,7 @@ class ShardedMoE final : public NcclKernel, public MoEBase {
   Status SynchronizeExpertsStartIndex(AllocatorPtr& alloc, OpKernelContext* ctx, cudaEvent_t& cuda_event) const;
 
   int64_t local_experts_start_index_;
+  int64_t tensor_shards_;
   std::vector<int64_t> rank_to_experts_start_index_;
 };
 
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 8f368251f12c7..57e951d3a68ff 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -120,6 +120,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalization);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits);
@@ -202,6 +203,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSqueeze);
 #endif
 
+#ifdef ENABLE_CUDA_NHWC_OPS
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 16, float, GridSample);
+#endif
+
 template <>
 KernelCreateInfo BuildKernelCreateInfo<void>() {
   KernelCreateInfo info;
@@ -318,6 +323,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalization)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits)>,
@@ -406,6 +412,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSqueeze)>,
 #endif
 
+#ifdef ENABLE_CUDA_NHWC_OPS
+    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 16, float, GridSample)>,
+#endif
   };
 
   for (auto& function_table_entry : function_table) {
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h
index ea87d0c29111e..a80584d3293a0 100644
--- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h
+++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h
@@ -136,10 +136,10 @@ struct GroupNormNHWCParams {
                       bool use_silu,
                       bool broadcast_skip,
                       int channels_per_block) {
-    int32_t channels_per_group = num_channels / num_groups;
+    int32_t channels_per_group_in = num_channels / num_groups;
     // channels_per_block is computed in PrePack.
     // If the gamma is not an initializer, channels_per_block might be zero after PrePack. If that happens, compute it here.
-    if (channels_per_block < channels_per_group) {
+    if (channels_per_block < channels_per_group_in) {
       channels_per_block = GetChannelsPerBlock(num_channels, num_groups);
     }
 
@@ -167,7 +167,7 @@ struct GroupNormNHWCParams {
     this->hw_per_block = DivUp(this->hw, blocks_per_hw);
 
     this->channels_per_block = channels_per_block;
-    this->channels_per_group = channels_per_group;
+    this->channels_per_group = channels_per_group_in;
     this->hwc = this->hw * this->c;
     this->inv_hw_channels_per_group = 1.F / (float)(this->hw * this->channels_per_group);
     this->groups_per_block = channels_per_block / this->channels_per_group;
diff --git a/onnxruntime/contrib_ops/cuda/grid_sample.cc b/onnxruntime/contrib_ops/cuda/grid_sample.cc
index 4c2999c279e0a..2500de39d3536 100644
--- a/onnxruntime/contrib_ops/cuda/grid_sample.cc
+++ b/onnxruntime/contrib_ops/cuda/grid_sample.cc
@@ -9,22 +9,23 @@ namespace onnxruntime {
 namespace contrib {
 namespace cuda {
 
-#define REGISTER_KERNEL_TYPED(T)                                   \
+#define REGISTER_KERNEL_TYPED(T, VERSION, LAYOUT, DOMAIN)          \
   ONNX_OPERATOR_TYPED_KERNEL_EX(                                   \
       GridSample,                                                  \
-      kMSDomain,                                                   \
-      1,                                                           \
+      DOMAIN,                                                      \
+      VERSION,                                                     \
       T,                                                           \
       kCudaExecutionProvider,                                      \
       (*KernelDefBuilder::Create())                                \
           .TypeConstraint("T1", DataTypeImpl::GetTensorType<T>())  \
           .TypeConstraint("T2", DataTypeImpl::GetTensorType<T>()), \
-      GridSample<T>);
+      onnxruntime::contrib::cuda::GridSample<T, LAYOUT>);
 
-REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(float, 1, LAYOUT_NCHW, kMSDomain)
+REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NHWC, kMSInternalNHWCDomain)
 
-template <typename T>
-GridSample<T>::GridSample(const OpKernelInfo& info) : CudaKernel(info) {
+template <typename T, bool IsNHWC>
+GridSample<T, IsNHWC>::GridSample(const OpKernelInfo& info) : CudaKernel(info) {
   std::string mode_str = info.GetAttrOrDefault<std::string>("mode", "bilinear");
   std::string padding_mode_str = info.GetAttrOrDefault<std::string>("padding_mode", "zeros");
   align_corners_ = static_cast<bool>(info.GetAttrOrDefault<int64_t>("align_corners", 0));
@@ -48,8 +49,8 @@ GridSample<T>::GridSample(const OpKernelInfo& info) : CudaKernel(info) {
   }
 }
 
-template <typename T>
-Status GridSample<T>::ComputeInternal(OpKernelContext* context) const {
+template <typename T, bool IsNHWC>
+Status GridSample<T, IsNHWC>::ComputeInternal(OpKernelContext* context) const {
   const Tensor* X = context->Input<Tensor>(0);
   const auto& dims_input = X->Shape().GetDims();
   const Tensor* Grid = context->Input<Tensor>(1);
@@ -61,11 +62,13 @@ Status GridSample<T>::ComputeInternal(OpKernelContext* context) const {
   ORT_ENFORCE(dims_grid[0] == dims_input[0], "Grid batch size ", dims_grid[0], " does not match input batch size ", dims_input[0]);
   ORT_ENFORCE(dims_grid[3] == 2, "Last dimension of grid: ", dims_grid[3], ", expect 2");
 
+  using Ch = Channels<IsNHWC>;
+
   TensorShapeVector dims_output(4);
-  dims_output[0] = dims_input[0];
-  dims_output[1] = dims_input[1];
-  dims_output[2] = dims_grid[1];
-  dims_output[3] = dims_grid[2];
+  dims_output[Ch::N] = dims_input[Ch::N];
+  dims_output[Ch::C] = dims_input[Ch::C];
+  dims_output[Ch::H] = dims_grid[1 /* Grid::H */];
+  dims_output[Ch::W] = dims_grid[2 /* Grid::W */];
   Tensor* Y = context->Output(0, dims_output);
   // Return early if the output tensor is going to be of size 0
   if (Y->Shape().Size() == 0) {
@@ -74,7 +77,7 @@ Status GridSample<T>::ComputeInternal(OpKernelContext* context) const {
 
   typedef typename ToCudaType<T>::MappedType CudaT;
   CudaT* Y_data = reinterpret_cast<CudaT*>(Y->MutableData<T>());
-  GridSampleImpl<CudaT>(
+  GridSampleImpl<CudaT, IsNHWC>(
       Stream(context),
       reinterpret_cast<const CudaT*>(X->Data<T>()),
       reinterpret_cast<const CudaT*>(Grid->Data<T>()),
@@ -89,4 +92,8 @@ Status GridSample<T>::ComputeInternal(OpKernelContext* context) const {
 }
 }  // namespace cuda
 }  // namespace contrib
+
+namespace cuda {
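+// Also register the ONNX-domain (opset 16) GridSample kernel for the default NCHW layout.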
+REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NCHW, kOnnxDomain)
+}  // namespace cuda
 }  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/grid_sample.h b/onnxruntime/contrib_ops/cuda/grid_sample.h
index 08ca58c7cc458..16581bfe77482 100644
--- a/onnxruntime/contrib_ops/cuda/grid_sample.h
+++ b/onnxruntime/contrib_ops/cuda/grid_sample.h
@@ -12,7 +12,7 @@ namespace cuda {
 
 using namespace onnxruntime::cuda;
 
-template <typename T>
+template <typename T, bool IsNHWC>
 class GridSample final : public CudaKernel {
  public:
   explicit GridSample(const OpKernelInfo& info);
diff --git a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu
index 8a391eca7e86a..b23da635bc83d 100644
--- a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu
@@ -50,28 +50,34 @@ __device__ T GsReflect(T x, float x_min, float x_max) {
   return static_cast<T>(fx);
 }
 
-template <typename T>
+template <typename T, bool Layout>
 __device__ T PixelAtGrid(const T* input_data, int64_t bIdx, int64_t cIdx, int64_t y, int64_t x,
-    int64_t padding_mode, int64_t N, int64_t C, int64_t H, int64_t W, float border[4]) {
+                         int64_t padding_mode, int64_t N, int64_t C, int64_t H, int64_t W, float border[4]) {
   T pixel = 0.0f;
+
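+  // Flat element offset of (x, y) for the current batch/channel in either NCHW or NHWC layout.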
+  auto PixelOffset = [bIdx, cIdx, C, H, W](int64_t x, int64_t y) -> int64_t {
+    return Layout == LAYOUT_NCHW
+       ? (bIdx * C * H * W + cIdx * H * W + y * W + x)
+       : (bIdx * H * W * C + y * W * C + x * C + cIdx);
+  };
+
   if (padding_mode == 0) {  // zeros
     if (x >= 0 && x < W && y >= 0 && y < H) {
-      pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x];
+      pixel = input_data[PixelOffset(x, y)];
     }
-  } else if (padding_mode == 1) {  //border
+  } else if (padding_mode == 1) {  // border
     x = max((int64_t)0, min((int64_t)W - 1, (int64_t)x));
     y = max((int64_t)0, min((int64_t)H - 1, (int64_t)y));
-    pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x];
+    pixel = input_data[PixelOffset(x, y)];
   } else {  // Reflection
-    x = (int64_t) GsReflect<T>(x, border[0], border[2]);
-    y = (int64_t) GsReflect<T>(y, border[1], border[3]);
-    pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x];
+    x = (int64_t)GsReflect<T>(x, border[0], border[2]);
+    y = (int64_t)GsReflect<T>(y, border[1], border[3]);
+    pixel = input_data[PixelOffset(x, y)];
   }
   return pixel;
 }
 
-__device__ void GsGetCubicCoeffs(float x, float coeffs[4])
-{
+__device__ void GsGetCubicCoeffs(float x, float coeffs[4]) {
   float cubic_alpha = -0.75f;
   x = abs(x);
   coeffs[0] = (((cubic_alpha * (x + 1) - 5 * cubic_alpha) * (x + 1) + 8 * cubic_alpha) * (x + 1) - 4 * cubic_alpha);
@@ -93,7 +99,7 @@ __device__ T GsBicubicInterpolate(T p[4][4], float x, float y) {
   return pixel;
 }
 
-template <typename T>
+template <typename T, bool Layout>
 __global__ void _GridSampleKernel(
     const T* input_data,
     const T* grid_data,
@@ -110,16 +116,32 @@ __global__ void _GridSampleKernel(
 {
     CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(idx, N * C * H_out * W_out);
     // extract batch index, channel index, y index, x index for current thread
-    int BIdx = idx / (C * H_out * W_out );
-    int tmpBCnt = BIdx * (C * H_out * W_out);
+    int BIdx, yIdx, xIdx, cIdx;
+    if constexpr (Layout == LAYOUT_NCHW) {
+      BIdx = idx / (C * H_out * W_out);
+      int tmpBCnt = BIdx * (C * H_out * W_out);
+
+      cIdx = (idx - tmpBCnt) / (H_out * W_out);
+      int tmpCCnt = tmpBCnt + cIdx * (H_out * W_out);
 
-    int cIdx = (idx - tmpBCnt) / (H_out * W_out);
-    int tmpCCnt = tmpBCnt + cIdx * (H_out * W_out);
+      yIdx = (idx - tmpCCnt) / W_out;
+      int tmpHCnt = tmpCCnt + yIdx * W_out;
 
-    int yIdx = (idx - tmpCCnt) / W_out;
-    int tmpHCnt = tmpCCnt + yIdx * W_out;
+      xIdx = (idx - tmpHCnt);
+    } else {
+      static_assert(Layout == LAYOUT_NHWC, "Unsupported layout");
 
-    int xIdx = (idx - tmpHCnt);
+      BIdx = idx / (H_out * W_out * C);
+      int tmpBCnt = BIdx * (H_out * W_out * C);
+
+      yIdx = (idx - tmpBCnt) / (W_out * C);
+      int tmpHCnt = tmpBCnt + yIdx * (W_out * C);
+
+      xIdx = (idx - tmpHCnt) / C;
+      int tmpWCnt = tmpHCnt + xIdx * C;
+
+      cIdx = (idx - tmpWCnt);
+    }
 
     int grid_idx = BIdx * H_out * W_out + yIdx * W_out + xIdx;
     T grid_X = grid_data[grid_idx * 2 + 0];
@@ -147,8 +169,9 @@ __global__ void _GridSampleKernel(
     if (grid_x_imgSpace < x_min || grid_x_imgSpace > x_max ||
         grid_y_imgSpace < y_min || grid_y_imgSpace > y_max) { // out of bound
       if (padding_mode == 1) {  // border
-        grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f));
-        grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f));
+        // Clamping must not be done here, see #10607
+        // grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f));
+        // grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f));
       } else if (padding_mode == 2) {  // reflection
         grid_x_imgSpace = GsReflect(grid_x_imgSpace, x_min, x_max);
         grid_y_imgSpace = GsReflect(grid_y_imgSpace, y_min, y_max);
@@ -175,10 +198,10 @@ __global__ void _GridSampleKernel(
       w_lb = w_b * w_l;
       w_rb = w_b * w_r;
 
-      T lt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x1, padding_mode, N, C, H_in, W_in, border);
-      T rt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x2, padding_mode, N, C, H_in, W_in, border);
-      T lb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x1, padding_mode, N, C, H_in, W_in, border);
-      T rb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x2, padding_mode, N, C, H_in, W_in, border);
+      T lt_v = PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y1, x1, padding_mode, N, C, H_in, W_in, border);
+      T rt_v = PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y1, x2, padding_mode, N, C, H_in, W_in, border);
+      T lb_v = PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y2, x1, padding_mode, N, C, H_in, W_in, border);
+      T rb_v = PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y2, x2, padding_mode, N, C, H_in, W_in, border);
       T interpoV = w_lt * lt_v + w_rt * rt_v + w_lb * lb_v + w_rb * rb_v;
       output_data[outIdx] = interpoV;
       return;
@@ -186,7 +209,8 @@ __global__ void _GridSampleKernel(
     if (mode == 1) {  // nearest
       int x_n = grid_x_imgSpace;
       int y_n = grid_y_imgSpace;
-      output_data[outIdx] = PixelAtGrid(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border);
+      output_data[outIdx] =
+        PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border);
       return;
     }
     if (mode == 2) {  // bicubic
@@ -195,7 +219,8 @@ __global__ void _GridSampleKernel(
       T p[4][4] = {};  // [H][W]
       for (int64_t h = 0; h < 4; h++) {
         for (int64_t w = 0; w < 4; w++) {
-          p[h][w] = PixelAtGrid(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border);
+          p[h][w] =
+            PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border);
         }
       }
       T dx = grid_x_imgSpace - x0 - 1;
@@ -204,7 +229,7 @@ __global__ void _GridSampleKernel(
     }
 }
 
-template <typename T>
+template <typename T, bool IsNHWC>
 void GridSampleImpl(
     cudaStream_t stream,
     const T* input_data,
@@ -216,17 +241,23 @@ void GridSampleImpl(
     const int64_t H_out,
     const int64_t W_out,
     T* output_data) {
-  int blocksPerGrid = (int)(ceil(static_cast<T>(dims[0] * dims[1] * H_out * W_out) / GridDim::maxThreadsPerBlock));
-  _GridSampleKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
-      input_data, grid_data, mode, padding_mode, align_corners, dims[0], dims[1], dims[2], dims[3], H_out, W_out, output_data);
+  using Ch = Channels<IsNHWC>;
+
+  int blocksPerGrid = static_cast<int>(
+    ceil(static_cast<T>(dims[Ch::N] * dims[Ch::C] * H_out * W_out) / GridDim::maxThreadsPerBlock));
+  _GridSampleKernel<T, IsNHWC><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+      input_data, grid_data, mode, padding_mode, align_corners,
+      dims[Ch::N], dims[Ch::C], dims[Ch::H], dims[Ch::W],
+      H_out, W_out, output_data);
 }
 
-#define SPECIALIZED_IMPL(T) \
-  template void GridSampleImpl<T>(cudaStream_t stream, const T* input_data, const T* grid_data, \
-                                  const int64_t mode, const int64_t padding_mode, const int64_t align_corners, \
-                                  const int64_t[4], const int64_t H_out, const int64_t W_out, T* output_data);
+#define SPECIALIZED_IMPL(T, IsNHWC)                                                                                    \
+  template void GridSampleImpl<T, IsNHWC>(cudaStream_t stream, const T* input_data, const T* grid_data,                \
+                                          const int64_t mode, const int64_t padding_mode, const int64_t align_corners, \
+                                          const int64_t[4], const int64_t H_out, const int64_t W_out, T* output_data);
 
-SPECIALIZED_IMPL(float)
+SPECIALIZED_IMPL(float, false)  // NCHW
+SPECIALIZED_IMPL(float, true)   // NHWC
 
 }  // namespace cuda
 }  // namespace contrib
diff --git a/onnxruntime/contrib_ops/cuda/grid_sample_impl.h b/onnxruntime/contrib_ops/cuda/grid_sample_impl.h
index 6df86ce161908..62cd66a48fa84 100644
--- a/onnxruntime/contrib_ops/cuda/grid_sample_impl.h
+++ b/onnxruntime/contrib_ops/cuda/grid_sample_impl.h
@@ -8,7 +8,7 @@ namespace onnxruntime {
 namespace contrib {
 namespace cuda {
 
-template <typename T>
+template <typename T, bool IsNHWC>
 void GridSampleImpl(
     cudaStream_t stream,
     const T* input_data,
diff --git a/onnxruntime/contrib_ops/cuda/inverse.cc b/onnxruntime/contrib_ops/cuda/inverse.cc
index 81e161e60642c..9075dda26f86b 100644
--- a/onnxruntime/contrib_ops/cuda/inverse.cc
+++ b/onnxruntime/contrib_ops/cuda/inverse.cc
@@ -78,9 +78,9 @@ struct Inverse::ComputeImpl {
     cudaStream_t stream = ort_stream ? static_cast<cudaStream_t>(ort_stream->GetHandle()) : nullptr;
 
     // Make a copy of the input which will serve as a workspace as well.
-    if (std::is_same<T, float>::value || std::is_same<T, MLFloat16>::value) {
+    if constexpr (std::is_same<T, float>::value || std::is_same<T, MLFloat16>::value) {
       IAllocatorUniquePtr<float> input_workspace = inst->GetScratchBuffer<float>(input_count, ort_stream);
-      if (std::is_same<T, MLFloat16>::value) {
+      if constexpr (std::is_same<T, MLFloat16>::value) {
         // Convert from MLFloat16(half) to float
         Impl_Cast<CudaT, float>(stream, reinterpret_cast<const CudaT*>(input.Data<MLFloat16>()), input_workspace.get(), input_count);
       } else {
@@ -96,7 +96,7 @@ struct Inverse::ComputeImpl {
       // Need to compute ptrs for output buffers
       // Output for MLFloat
       IAllocatorUniquePtr<float*> output_ptrs = inst->GetScratchBuffer<float*>(n_batches, ort_stream);
-      if (std::is_same<T, MLFloat16>::value) {
+      if constexpr (std::is_same<T, MLFloat16>::value) {
         IAllocatorUniquePtr<float> ml_float_output = inst->GetScratchBuffer<float>(input_count, ort_stream);
         ORT_RETURN_IF_ERROR(ComputeMatrixOffsets<float>(stream, ml_float_output.get(), num_batches, rows, output_ptrs));
         // Do the inverse
@@ -112,7 +112,7 @@ struct Inverse::ComputeImpl {
         ORT_RETURN_IF_ERROR(CheckForSingularity(stream, info, info_cpu, num_batches));
         // We are done here
       }
-    } else if (std::is_same<T, double>::value) {
+    } else if constexpr (std::is_same<T, double>::value) {
       IAllocatorUniquePtr<double> input_workspace = inst->GetScratchBuffer<double>(static_cast<int>(input_count), ort_stream);
       CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_workspace.get(), input.Data<double>(), sizeof(double) * input_count,
                                            cudaMemcpyDeviceToDevice, stream));
diff --git a/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu b/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu
index ca94477114ee2..47a64502b3480 100644
--- a/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu
@@ -97,8 +97,8 @@ void ComplexMul_Impl(
     const TArray<int64_t>* rhs_padded_strides,
     const T* rhs_data,
     const TArray<onnxruntime::cuda::fast_divmod>* fdm_output_strides,
-    const onnxruntime::cuda::fast_divmod& fdm_H,
-    const onnxruntime::cuda::fast_divmod& fdm_C,
+    const onnxruntime::cuda::fast_divmod& /*fdm_H*/,
+    const onnxruntime::cuda::fast_divmod& /*fdm_C*/,
     T* output_data,
     int64_t count,
     int64_t lhs_size,
diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu
index 064b6dd392437..28ab27ee33d10 100644
--- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu
+++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu
@@ -174,7 +174,7 @@ Status GemmFloat8::ComputeGemm(
     int32_t dtype_A, int32_t dtype_B,
     int32_t dtype_C, int32_t dtype_Y,
     const TensorShape& shape_A, const TensorShape& shape_B,
-    const TensorShape& shape_C, const TensorShape& shape_Y,
+    const TensorShape& shape_C, const TensorShape& /*shape_Y*/,
     bool trans_A, bool trans_B, const void* p_input_a, const void* p_input_b,
     const void* p_input_c, const void* p_scale_a, const void* p_scale_b,
     const void* p_scale_y, void* p_output_y, int M, int N, int K, int lda,
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h
index 78d206bf1d9bc..b18a70e899d1c 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h
@@ -83,10 +83,16 @@ namespace ort_fastertransformer {
 
 struct EpilogueOpBiasSilu {};
 
+struct EpilogueOpNoBiasSilu {};
+
 struct EpilogueOpBiasReLU {};
 
+struct EpilogueOpNoBiasReLU {};
+
 struct EpilogueOpBiasFtGelu {};
 
+struct EpilogueOpNoBiasFtGelu {};
+
 struct EpilogueOpBias {};
 
 struct EpilogueOpNoBias {};
@@ -101,6 +107,13 @@ struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, Epilog
                                                               cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
 };
 
+template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
+struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpNoBiasSilu> {
+  using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess, ElementAccumulator,
+                                                              ElementAccumulator,
+                                                              cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;
+};
+
 template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
 struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasReLU> {
   using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess, ElementAccumulator,
@@ -108,6 +121,13 @@ struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, Epilog
                                                               cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
 };
 
+template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
+struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpNoBiasReLU> {
+  using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess, ElementAccumulator,
+                                                              ElementAccumulator,
+                                                              cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;
+};
+
 template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
 struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasFtGelu> {
   using Op = cutlass::epilogue::thread::LinearCombinationGeneric<
@@ -116,6 +136,14 @@ struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, Epilog
       cutlass::FloatRoundStyle::round_to_nearest, true>;
 };
 
+template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
+struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpNoBiasFtGelu> {
+  using Op = cutlass::epilogue::thread::LinearCombinationGeneric<
+      cutlass::epilogue::thread::GELU_taylor, ElementType, ElementsPerVectorAccess, ElementAccumulator,
+      ElementAccumulator, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling,
+      cutlass::FloatRoundStyle::round_to_nearest, true>;
+};
+
 template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
 struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBias> {
   using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator,
@@ -126,8 +154,9 @@ struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, Epilog
 template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
 struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpNoBias> {
   using Op =
-      cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator,
-                                                   ElementAccumulator, cutlass::epilogue::thread::ScaleType::Default>;
+      cutlass::epilogue::thread::LinearCombination<
+          ElementType, ElementsPerVectorAccess, ElementAccumulator,
+          ElementAccumulator, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;
 };
 
 }  // namespace ort_fastertransformer
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h
index bfe30b71170d8..cfe306c2482a5 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h
@@ -202,7 +202,7 @@ struct MoeFCGemm {
           total_rows_before_expert(total_rows_before_expert),
           gemm_n(gemm_n),
           gemm_k(gemm_k),
-          host_problem_sizes(nullptr) {
+          host_problem_sizes(host_problem_sizes) {
       if (platform::is_same<uint8_t, ElementB>::value || platform::is_same<uint4b_t, ElementB>::value) {
         assert(weight_scales);
       }
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h
index 60608f462fde5..e0f91ab806c85 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h
@@ -42,8 +42,13 @@ class MoeGemmRunner {
                          int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k,
                          int num_experts, ActivationType activation_type, cudaStream_t stream);
 
-  void moe_gemm(const T* A, const WeightType* B, const T* weight_scales, T* C, int64_t* total_rows_before_expert,
-                int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cudaStream_t stream);
+  void moe_gemm_act(const T* A, const WeightType* B, const T* weight_scales, T* C, int64_t* total_rows_before_expert,
+                    int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
+                    ActivationType activation_type, cudaStream_t stream);
+
+  void moe_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C,
+                int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k,
+                int num_experts, cudaStream_t stream);
 
  private:
   template <typename EpilogueTag>
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h
index 66950c9b65970..2a15fdfd1cc1a 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h
@@ -20,6 +20,12 @@
 #pragma GCC diagnostic ignored "-Wstrict-aliasing"
 #endif
 
+// Ignore CUTLASS warning C4100: unreferenced formal parameter
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4100)
+#endif
+
 #include "cutlass/array.h"
 #include "cutlass/numeric_conversion.h"
 #include "cutlass/layout/matrix.h"
@@ -36,6 +42,10 @@
 #include "layout_traits_helper.h"
 #include "moe_cutlass_kernel.h"
 
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+
 #ifdef __GNUC__
 #pragma GCC diagnostic pop
 #endif
@@ -149,10 +159,10 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w
 template <typename T, typename WeightType, typename arch, typename EpilogueTag, typename ThreadblockShape,
           typename WarpShape, int Stages, typename Enable = void>
 struct dispatch_stages {
-  static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C,
-                       int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts,
-                       CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream,
-                       int* occupancy = nullptr) {
+  static void dispatch(const T* /*A*/, const WeightType* /*B*/, const T* /*weight_scales*/, const T* /*biases*/,
+                       T* /*C*/, int64_t* /*total_rows_before_expert*/, int64_t /*gemm_n*/, int64_t /*gemm_k*/,
+                       int /*num_experts*/, CutlassGemmConfig /*gemm_config*/, int /*multi_processor_count*/,
+                       cudaStream_t /*stream*/, [[maybe_unused]] int* occupancy = nullptr) {
     std::string err_msg = "Cutlass fpA_intB gemm. Not instantiates for arch " +
                           std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages);
     ORT_THROW("[FT Error][dispatch_stages::dispatch] " + err_msg);
@@ -221,9 +231,10 @@ template <
     typename T, typename WeightType, typename arch, typename EpilogueTag,
     typename std::enable_if<!std::is_same<T, float>::value && std::is_same<T, WeightType>::value>::type* = nullptr>
 void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C,
-                                  int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k,
-                                  int num_experts, CutlassGemmConfig gemm_config, int sm_version,
-                                  int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) {
+                                  int64_t* total_rows_before_expert, int64_t /*total_rows*/,
+                                  int64_t gemm_n, int64_t gemm_k, int num_experts, CutlassGemmConfig gemm_config,
+                                  int /*sm_version*/, int multi_processor_count, cudaStream_t stream,
+                                  int* occupancy = nullptr) {
   switch (gemm_config.tile_config) {
     case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
       dispatch_gemm_config<T, WeightType, arch, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
@@ -300,8 +311,8 @@ void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weig
 template <typename T, typename WeightType, typename arch, typename EpilogueTag,
           typename std::enable_if<std::is_same<T, float>::value>::type* = nullptr>
 void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C,
-                                  int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k,
-                                  int num_experts, CutlassGemmConfig gemm_config, int sm_version,
+                                  int64_t* total_rows_before_expert, int64_t /*total_rows*/, int64_t gemm_n,
+                                  int64_t gemm_k, int num_experts, CutlassGemmConfig gemm_config, int /*sm_version*/,
                                   int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) {
   switch (gemm_config.tile_config) {
     case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8:
@@ -418,11 +429,47 @@ void MoeGemmRunner<T, WeightType>::moe_gemm_bias_act(const T* A, const WeightTyp
 }
 
 template <typename T, typename WeightType>
-void MoeGemmRunner<T, WeightType>::moe_gemm(const T* A, const WeightType* B, const T* weight_scales, T* C,
-                                            int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n,
-                                            int64_t gemm_k, int num_experts, cudaStream_t stream) {
-  run_gemm<EpilogueOpNoBias>(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, gemm_k,
+void MoeGemmRunner<T, WeightType>::moe_gemm_act(const T* A, const WeightType* B, const T* weight_scales,
+                                                T* C, int64_t* total_rows_before_expert, int64_t total_rows,
+                                                int64_t gemm_n, int64_t gemm_k, int num_experts,
+                                                ActivationType activation_type, cudaStream_t stream) {
+  switch (activation_type) {
+    case ActivationType::Relu:
+      run_gemm<EpilogueOpNoBiasReLU>(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n,
+                                     gemm_k, num_experts, stream);
+      break;
+    case ActivationType::Gelu:
+      run_gemm<EpilogueOpNoBiasFtGelu>(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n,
+                                       gemm_k, num_experts, stream);
+      break;
+    case ActivationType::Silu:
+      run_gemm<EpilogueOpNoBiasSilu>(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n,
+                                     gemm_k, num_experts, stream);
+      break;
+    case ActivationType::Identity:
+      run_gemm<EpilogueOpNoBias>(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, gemm_k,
+                                 num_experts, stream);
+      break;
+    case ActivationType::InvalidType:
+      ORT_THROW("[FT Error][MoE Runner] Invalid activation type for MoE GEMM");
+      break;
+    default: {
+      ORT_THROW("[FT Error][MoE Runner] Invalid activation type for MoE GEMM");
+    }
+  }
+}
+
+template <typename T, typename WeightType>
+void MoeGemmRunner<T, WeightType>::moe_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* biases,
+                                            T* C, int64_t* total_rows_before_expert, int64_t total_rows,
+                                            int64_t gemm_n, int64_t gemm_k, int num_experts, cudaStream_t stream) {
+  if (biases != nullptr) {
+    run_gemm<EpilogueOpBias>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k,
                              num_experts, stream);
+  } else {
+    run_gemm<EpilogueOpNoBias>(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, gemm_k,
+                               num_experts, stream);
+  }
 }
 
 }  // namespace ort_fastertransformer
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
index f4f2b49032d23..5e6e484567988 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
@@ -30,7 +30,6 @@
 
 #include "cutlass/array.h"
 #include "cutlass/numeric_conversion.h"
-#include "cutlass/numeric_types.h"
 
 #ifdef __GNUC__
 #pragma GCC diagnostic pop
@@ -49,15 +48,14 @@
 #endif
 
 namespace ort_fastertransformer {
-
 static constexpr int WARP_SIZE = 32;
 
 // ====================== Softmax things ===============================
 // We have our own implementation of softmax here so we can support transposing the output
 // in the softmax kernel when we extend this module to support expert-choice routing.
 template <typename T, int TPB>
-__launch_bounds__(TPB) __global__
-    void moe_softmax(const T* input, const bool* finished, T* output, const int num_cols) {
+__launch_bounds__(TPB) __global__ void moe_softmax(const T* input, const bool* finished, T* output,
+                                                   const int num_cols) {
   using BlockReduce = cub::BlockReduce<float, TPB>;
   __shared__ typename BlockReduce::TempStorage tmpStorage;
 
@@ -108,14 +106,15 @@ __launch_bounds__(TPB) __global__
 
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530
 template <typename T, int TPB>
-__launch_bounds__(TPB) __global__ void moe_top_k(const T*, const bool*, T*, int*, int*, int, const int) {
+__launch_bounds__(TPB) __global__ void moe_top_k(const T*, const bool*, T*, int*, int*, int, int, bool) {
   // Not supported on architectures older than sm_53
   ;
 }
 #else
 template <typename T, int TPB>
 __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, const bool* finished, T* output,
-                                                 int* indices, int* source_rows, int num_experts, int k) {
+                                                 int* indices, int* source_rows, int num_experts, int k,
+                                                 bool normalize_routing_weights) {
   using cub_kvp = cub::KeyValuePair<int, T>;
   using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
   __shared__ typename BlockReduce::TempStorage tmpStorage;
@@ -128,6 +127,7 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
 
   const bool should_process_row = finished ? !finished[block_row] : true;
   const int thread_read_offset = blockIdx.x * num_experts;
+  float output_row_sum = 0.f;
   for (int k_idx = 0; k_idx < k; ++k_idx) {
     thread_kvp.key = 0;
     thread_kvp.value = T(-1.f);  // This is OK because inputs are probabilities
@@ -155,6 +155,13 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
       output[idx] = result_kvp.value;
+      output_row_sum = output_row_sum + static_cast<float>(result_kvp.value);
       indices[idx] = should_process_row ? result_kvp.key : num_experts;
       source_rows[idx] = k_idx * num_rows + block_row;
+
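+      // Optionally renormalize the selected routing weights so they sum to one for this row.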
+      if (normalize_routing_weights && k_idx == k - 1) {
+#pragma unroll
+        for (int ki = 0; ki < k; ++ki) {
+          output[idx - ki] = T(static_cast<float>(output[idx - ki]) / output_row_sum);
+        }
+      }
     }
     __syncthreads();
   }
@@ -178,7 +185,7 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
 template <typename T, int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
 __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
     void topk_gating_softmax(const T* input, const bool* finished, T* output, int num_rows, int* indices,
-                             int* source_rows, int k) {
+                             int* source_rows, int k, bool normalize_routing_weights) {
   // We begin by enforcing compile time assertions and setting up compile time constants.
   static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
   static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
@@ -296,6 +303,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
   int start_col = first_elt_read_by_thread;
   static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
 
+  float output_row_sum = 0.f;
   for (int k_idx = 0; k_idx < k; ++k_idx) {
     // First, each thread does the local argmax
     float max_val = row_chunk[0];
@@ -336,8 +344,16 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
       // single) thread per row of the input/output matrices.
       const int idx = k * thread_row + k_idx;
       output[idx] = T(max_val);
+      output_row_sum = output_row_sum + static_cast<float>(max_val);
       indices[idx] = should_process_row ? expert : NUM_EXPERTS;
       source_rows[idx] = k_idx * num_rows + thread_row;
+
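+      // After the final selection, scale the k stored weights by their running sum when requested.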
+      if (normalize_routing_weights && k_idx == k - 1) {
+#pragma unroll
+        for (int ki = 0; ki < k; ++ki) {
+          output[idx - ki] = T(static_cast<float>(output[idx - ki]) / output_row_sum);
+        }
+      }
     }
 
     // Finally, we clear the value in the thread with the current max if there is another iteration to run.
@@ -370,7 +386,8 @@ struct TopkConstants {
 
 template <typename T, int EXPERTS, int WARPS_PER_TB>
 void topk_gating_softmax_launcher_helper(const T* input, const bool* finished, T* output, int* indices, int* source_row,
-                                         int num_rows, int num_experts, int k, cudaStream_t stream) {
+                                         int num_rows, int /*num_experts*/, int k, bool normalize_routing_weights,
+                                         cudaStream_t stream) {
   static constexpr unsigned long MAX_BYTES_PER_LDG = 16;
 
   static constexpr int BYTES_PER_LDG = std::min((int)MAX_BYTES_PER_LDG, (int)sizeof(T) * EXPERTS);
@@ -382,61 +399,63 @@ void topk_gating_softmax_launcher_helper(const T* input, const bool* finished, T
 
   dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
   topk_gating_softmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG>
-      <<<num_blocks, block_dim, 0, stream>>>(input, finished, output, num_rows, indices, source_row, k);
+      <<<num_blocks, block_dim, 0, stream>>>(input, finished, output, num_rows, indices, source_row, k,
+                                             normalize_routing_weights);
 }
 
 template <typename T>
 void topk_gating_softmax_kernelLauncher(const T* input, const bool* finished, T* output, T* softmax_temp_output,
                                         int* indices, int* source_row, int num_rows, int num_experts,
-                                        int k, cudaStream_t stream) {
+                                        int k, bool normalize_routing_weights, cudaStream_t stream) {
   static constexpr int WARPS_PER_TB = 4;
 
   switch (num_experts) {
     case 2: {
       topk_gating_softmax_launcher_helper<T, 2, WARPS_PER_TB>(input, finished, output, indices, source_row, num_rows,
-                                                              num_experts, k, stream);
+                                                              num_experts, k, normalize_routing_weights, stream);
       break;
     }
     case 4: {
       topk_gating_softmax_launcher_helper<T, 4, WARPS_PER_TB>(input, finished, output, indices, source_row, num_rows,
-                                                              num_experts, k, stream);
+                                                              num_experts, k, normalize_routing_weights, stream);
       break;
     }
     case 8: {
       topk_gating_softmax_launcher_helper<T, 8, WARPS_PER_TB>(input, finished, output, indices, source_row, num_rows,
-                                                              num_experts, k, stream);
+                                                              num_experts, k, normalize_routing_weights, stream);
       break;
     }
     case 16: {
       topk_gating_softmax_launcher_helper<T, 16, WARPS_PER_TB>(input, finished, output, indices, source_row, num_rows,
-                                                               num_experts, k, stream);
+                                                               num_experts, k, normalize_routing_weights, stream);
       break;
     }
     case 32: {
       topk_gating_softmax_launcher_helper<T, 32, WARPS_PER_TB>(input, finished, output, indices, source_row, num_rows,
-                                                               num_experts, k, stream);
+                                                               num_experts, k, normalize_routing_weights, stream);
       break;
     }
     case 64: {
       topk_gating_softmax_launcher_helper<T, 64, WARPS_PER_TB>(input, finished, output, indices, source_row, num_rows,
-                                                               num_experts, k, stream);
+                                                               num_experts, k, normalize_routing_weights, stream);
       break;
     }
     case 128: {
       topk_gating_softmax_launcher_helper<T, 128, WARPS_PER_TB>(input, finished, output, indices, source_row, num_rows,
-                                                                num_experts, k, stream);
+                                                                num_experts, k, normalize_routing_weights, stream);
       break;
     }
     case 256: {
       topk_gating_softmax_launcher_helper<T, 256, WARPS_PER_TB>(input, finished, output, indices, source_row, num_rows,
-                                                                num_experts, k, stream);
+                                                                num_experts, k, normalize_routing_weights, stream);
       break;
     }
     default: {
       static constexpr int TPB = 256;
       moe_softmax<T, TPB><<<num_rows, TPB, 0, stream>>>(input, finished, softmax_temp_output, num_experts);
       moe_top_k<T, TPB>
-          <<<num_rows, TPB, 0, stream>>>(softmax_temp_output, finished, output, indices, source_row, num_experts, k);
+          <<<num_rows, TPB, 0, stream>>>(softmax_temp_output, finished, output, indices, source_row, num_experts, k,
+                                         normalize_routing_weights);
     }
   }
 }
@@ -521,25 +540,31 @@ __global__ void dispatch_activations_kernel(int64_t* total_rows_before_expert, i
 }
 
 template <typename T, typename WeightType, typename Enable>
-CutlassMoeFCRunner<T, WeightType, Enable>::CutlassMoeFCRunner(int sm_version) {
-  total_past_rows_ = 0;
-  total_covered_rows_ = 0;
+CutlassMoeFCRunner<T, WeightType, Enable>::CutlassMoeFCRunner(int sm_version,
+                                                              bool has_fc3,
+                                                              bool normalize_routing_weights)
+    : has_fc3_(has_fc3),
+      total_past_rows_(0),
+      total_covered_rows_(0),
+      normalize_routing_weights_(normalize_routing_weights) {
   moe_gemm_runner_.initialize(sm_version);
 }
 
 template <typename T, typename WeightType, typename Enable>
-size_t CutlassMoeFCRunner<T, WeightType, Enable>::getWorkspaceSize(int num_rows, const int hidden_size,
-                                                                   const int inter_size, int num_experts,
-                                                                   int k) {
-  const int buf_size = static_cast<int>(pad_to_multiple_of_16(k * num_rows * hidden_size));
-  const int interbuf_size = static_cast<int>(pad_to_multiple_of_16(k * num_rows * inter_size));
-  const int padded_experts = static_cast<int>(pad_to_multiple_of_16(num_experts));
-  const int num_moe_inputs = static_cast<int>(pad_to_multiple_of_16(k * num_rows));
-  int num_softmax_outs = 0;
+size_t CutlassMoeFCRunner<T, WeightType, Enable>::getWorkspaceSize(size_t num_rows, const size_t hidden_size,
+                                                                   const size_t inter_size, size_t num_experts,
+                                                                   size_t k) {
+  total_covered_rows_ = k * num_rows;
+
+  const size_t buf_size = pad_to_multiple_of_16(k * num_rows * hidden_size);
+  const size_t interbuf_size = pad_to_multiple_of_16(k * num_rows * inter_size);
+  const size_t padded_experts = pad_to_multiple_of_16(num_experts);
+  const size_t num_moe_inputs = pad_to_multiple_of_16(k * num_rows);
+  size_t num_softmax_outs = 0;
 
   const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
   if (!is_pow_2 || num_experts > 256) {
-    num_softmax_outs = static_cast<int>(pad_to_multiple_of_16(num_rows * num_experts));
+    num_softmax_outs = pad_to_multiple_of_16(num_rows * num_experts);
   }
 
   // softmax output, permuted_rows and permuted_experts have moved to outside of moe kernel, allocate them
@@ -548,13 +573,13 @@ size_t CutlassMoeFCRunner<T, WeightType, Enable>::getWorkspaceSize(int num_rows,
   total_ws_bytes += buf_size * sizeof(T);                    // permuted_data
   total_ws_bytes += padded_experts * sizeof(int64_t);        // Hold total_rows_before_expert_
   total_ws_bytes += num_softmax_outs * sizeof(T);
-  const int bytes_for_fc1_result = interbuf_size * sizeof(T);
-  const int sorter_ws_size_bytes = static_cast<int>(pad_to_multiple_of_16(sorter_.getWorkspaceSize(num_rows)));
-  sorter_.update_num_experts(num_experts);
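+  // When FC3 is present its intermediate result gets a second buffer of the same size as FC1's.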
+  const size_t bytes_for_fc1_result = has_fc3_ ? 2 * interbuf_size * sizeof(T) : interbuf_size * sizeof(T);
+  const size_t sorter_ws_size_bytes = pad_to_multiple_of_16(sorter_.getWorkspaceSize(num_rows));
+  sorter_.update_num_experts(static_cast<int>(num_experts));
 
-  int bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
+  size_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
   if (sorter_ws_size_bytes > bytes_for_fc1_result) {
-    int remaining_bytes = static_cast<int>(pad_to_multiple_of_16(sorter_ws_size_bytes - bytes_for_fc1_result));
+    size_t remaining_bytes = pad_to_multiple_of_16(sorter_ws_size_bytes - bytes_for_fc1_result);
     bytes_for_intermediate_and_sorting += remaining_bytes;
   }
 
@@ -563,13 +588,13 @@ size_t CutlassMoeFCRunner<T, WeightType, Enable>::getWorkspaceSize(int num_rows,
 }
 
 template <typename T, typename WeightType, typename Enable>
-void CutlassMoeFCRunner<T, WeightType, Enable>::configure_ws_ptrs(char* ws_ptr, int num_rows,
-                                                                  const int hidden_size, const int inter_size,
-                                                                  int num_experts, int k) {
-  const int buf_size = static_cast<int>(pad_to_multiple_of_16(k * num_rows * hidden_size));
-  const int interbuf_size = static_cast<int>(pad_to_multiple_of_16(k * num_rows * inter_size));
-  const int padded_experts = static_cast<int>(pad_to_multiple_of_16(num_experts));
-  const int num_moe_inputs = static_cast<int>(pad_to_multiple_of_16(k * num_rows));
+void CutlassMoeFCRunner<T, WeightType, Enable>::configure_ws_ptrs(char* ws_ptr, size_t num_rows,
+                                                                  const size_t hidden_size, const size_t inter_size,
+                                                                  size_t num_experts, size_t k) {
+  const size_t buf_size = pad_to_multiple_of_16(k * num_rows * hidden_size);
+  const size_t interbuf_size = pad_to_multiple_of_16(k * num_rows * inter_size);
+  const size_t padded_experts = pad_to_multiple_of_16(num_experts);
+  const size_t num_moe_inputs = pad_to_multiple_of_16(k * num_rows);
 
   source_rows_ = (int*)ws_ptr;
   permuted_rows_ = source_rows_ + num_moe_inputs;
@@ -578,24 +603,126 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::configure_ws_ptrs(char* ws_ptr,
 
   total_rows_before_expert_ = (int64_t*)(permuted_data_ + buf_size);
 
-  fc1_result_ = (T*)(total_rows_before_expert_ + padded_experts);
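+  // Workspace layout with FC3: total_rows_before_expert_ is followed by fc3_result_, then fc1_result_.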
+  if (has_fc3_) {
+    fc3_result_ = reinterpret_cast<T*>(total_rows_before_expert_ + padded_experts);
+    fc1_result_ = reinterpret_cast<T*>(fc3_result_ + interbuf_size);
+  } else {
+    fc1_result_ = reinterpret_cast<T*>(total_rows_before_expert_ + padded_experts);
+  }
 
   const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
   if (!is_pow_2 || num_experts > 256) {
-    softmax_out_ = (T*)(fc1_result_ + interbuf_size);
+    softmax_out_ = reinterpret_cast<T*>(fc1_result_ + interbuf_size);
   } else {
     softmax_out_ = nullptr;
   }
 }
 
+namespace {
+
+struct __align__(8) Half4 {
+  half2 x;
+  half2 y;
+};
+
+// TODO(wy): move to common header
+template <typename T>
+struct T4;
+template <>
+struct T4<float> {
+  using Type = float4;
+};
+template <>
+struct T4<half> {
+  using Type = Half4;
+};
+
+template <typename T>
+struct T2;
+template <>
+struct T2<float> {
+  using Type = float2;
+};
+template <>
+struct T2<half> {
+  using Type = half2;
+};
+
+inline __device__ float2 operator*(const float2 a, const float2 b) {
+  return make_float2(a.x * b.x, a.y * b.y);
+}
+
+inline __device__ float4 operator*(const float4 a, const float4 b) {
+  return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
+}
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530
+inline __device__ half operator*(const half a, const half b) {
+  return __float2half(__half2float(a) * __half2float(b));
+}
+
+inline __device__ half2 operator*(const half2 a, const half2 b) {
+  return make_half2(a.x * b.x, a.y * b.y);
+}
+#endif
+
+inline __device__ Half4 operator*(const Half4 a, const Half4 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530
+  Half4 result;
+  result.x = a.x * b.x;
+  result.y = a.y * b.y;
+  return result;
+#else
+  return Half4{__hmul2(a.x, b.x), __hmul2(a.y, b.y)};
+#endif
+}
+
+}  // anonymous namespace
+
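+// Element-wise multiply of two [num_tokens, inter_size] buffers; output[i] *= input[i], one CUDA block per token.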
+template <typename T>
+__global__ void elementWiseMulKernel(T* output, T const* input, size_t inter_size) {
+  int const tid = threadIdx.x;
+  int const token = blockIdx.x;
+
+  output = output + token * inter_size;
+  input = input + token * inter_size;
+  for (int i = tid; i < inter_size; i += blockDim.x) {
+    T fc1_value = input[i];
+    output[i] = fc1_value * output[i];
+  }
+}
+
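+// Host-side dispatcher: uses a 4-wide vector type when inter_size is a multiple of 4, a 2-wide type for multiples
+// of 2, and falls back to scalar loads otherwise.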
+template <typename T>
+void elementWiseMul(T* output, T const* input, int inter_size, int num_tokens, cudaStream_t stream) {
+  int const blocks = num_tokens;
+
+  if ((inter_size & 3) == 0) {
+    using vec_type = typename T4<T>::Type;
+    int const threads = std::min(inter_size / 4, 1024);
+    elementWiseMulKernel<vec_type><<<blocks, threads, 0, stream>>>(reinterpret_cast<vec_type*>(output),
+                                                                   reinterpret_cast<vec_type const*>(input),
+                                                                   inter_size / 4);
+  } else if ((inter_size & 1) == 0) {
+    using vec_type = typename T2<T>::Type;
+    int const threads = std::min(inter_size / 2, 1024);
+    elementWiseMulKernel<vec_type><<<blocks, threads, 0, stream>>>(reinterpret_cast<vec_type*>(output),
+                                                                   reinterpret_cast<vec_type const*>(input),
+                                                                   inter_size / 2);
+  } else {
+    int const threads = std::min(inter_size, 1024);
+    elementWiseMulKernel<T><<<blocks, threads, 0, stream>>>(output, input, inter_size);
+  }
+}
+
 template <typename T, typename WeightType, typename Enable>
 void CutlassMoeFCRunner<T, WeightType, Enable>::run_moe_fc(
     const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales,
-    const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc2_expert_weights,
-    const T* fc2_scales, int num_rows, const int hidden_size, const int inter_size, int num_experts,
-    int local_num_experts, int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result,
-    const bool* finished, int active_rows, T* expert_scales, int* expanded_source_row_to_expanded_dest_row,
-    int* expert_for_source_row, cudaStream_t stream) {
+    const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc3_expert_weights,
+    const T* fc3_scales, const T* fc3_expert_biases, const WeightType* fc2_expert_weights, const T* fc2_scales,
+    int num_rows, const int hidden_size, const int inter_size, int num_experts, int local_num_experts,
+    int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, const bool* finished, int active_rows,
+    T* expert_scales, int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row,
+    cudaStream_t stream) {
   static constexpr bool scales_required =
       std::is_same<WeightType, uint8_t>::value || std::is_same<WeightType, cutlass::uint4b_t>::value;
 
@@ -613,9 +740,10 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::run_moe_fc(
     }
   }
 
-  configure_ws_ptrs(workspace_ptr, num_rows, hidden_size, inter_size, num_experts, k);
+  configure_ws_ptrs(workspace_ptr, static_cast<size_t>(num_rows), static_cast<size_t>(hidden_size),
+                    static_cast<size_t>(inter_size), static_cast<size_t>(num_experts), static_cast<size_t>(k));
   topk_gating_softmax_kernelLauncher<T>(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row,
-                                        source_rows_, num_rows, num_experts, k, stream);
+                                        source_rows_, num_rows, num_experts, k, normalize_routing_weights_, stream);
 
   const int sorter_ws_size_bytes = static_cast<int>(pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows)));
   sorter_.run((void*)fc1_result_, sorter_ws_size_bytes, expert_for_source_row, permuted_experts_, source_rows_,
@@ -634,15 +762,48 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::run_moe_fc(
   }
 
   // expanded_active_expert_rows is not used
-  moe_gemm_runner_.moe_gemm_bias_act(permuted_data_ + total_past_rows_ * hidden_size,
-                                     fc1_expert_weights, fc1_scales, fc1_expert_biases,
-                                     fc1_result_ + total_past_rows_ * inter_size,
-                                     total_rows_before_expert_ + local_experts_start_index,
-                                     expanded_active_expert_rows, inter_size, hidden_size,
-                                     local_num_experts, fc1_activation_type, stream);
+  if (fc1_expert_biases != nullptr) {
+    moe_gemm_runner_.moe_gemm_bias_act(permuted_data_ + total_past_rows_ * hidden_size,
+                                       fc1_expert_weights, fc1_scales, fc1_expert_biases,
+                                       fc1_result_ + total_past_rows_ * inter_size,
+                                       total_rows_before_expert_ + local_experts_start_index,
+                                       expanded_active_expert_rows, inter_size, hidden_size,
+                                       local_num_experts, fc1_activation_type, stream);
+  } else {
+    moe_gemm_runner_.moe_gemm_act(permuted_data_ + total_past_rows_ * hidden_size,
+                                  fc1_expert_weights, fc1_scales,
+                                  fc1_result_ + total_past_rows_ * inter_size,
+                                  total_rows_before_expert_ + local_experts_start_index,
+                                  expanded_active_expert_rows, inter_size, hidden_size,
+                                  local_num_experts, fc1_activation_type, stream);
+  }
+
+  if (has_fc3_) {
+    if (scales_required) {
+      if (fc3_scales == nullptr) {
+        ORT_THROW("[FT Error][Run MoE FC] Scales expected but scale for third matmul is a null pointer");
+      }
+    } else {
+      if (fc3_scales != nullptr) {
+        ORT_THROW("[FT Error][Run MoE FC] Scales are ignored for fp32/fp16/bf16 but received scale for FC3");
+      }
+    }
+    if (fc3_expert_weights == nullptr) {
+      ORT_THROW("[FT Error][Run MoE FC] FC3 weights are null");
+    }
+    moe_gemm_runner_.moe_gemm(permuted_data_ + total_past_rows_ * hidden_size,
+                              fc3_expert_weights, fc3_scales, fc3_expert_biases,
+                              fc3_result_ + total_past_rows_ * inter_size,
+                              total_rows_before_expert_ + local_experts_start_index,
+                              expanded_active_expert_rows, inter_size, hidden_size,
+                              local_num_experts, stream);
+
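+    // Gated FFN: combine the activated FC1 output with the FC3 output element-wise before running FC2.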
+    elementWiseMul(fc1_result_ + total_past_rows_ * inter_size, fc3_result_ + total_past_rows_ * inter_size,
+                   static_cast<int>(inter_size), static_cast<int>(total_covered_rows_), stream);
+  }
 
   moe_gemm_runner_.moe_gemm(fc1_result_ + total_past_rows_ * inter_size,
-                            fc2_expert_weights, fc2_scales,
+                            fc2_expert_weights, fc2_scales, nullptr,
                             fc2_result + total_past_rows_ * hidden_size,
                             total_rows_before_expert_ + local_experts_start_index,
                             expanded_active_expert_rows, hidden_size, inter_size, local_num_experts, stream);
@@ -651,14 +812,16 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::run_moe_fc(
 template <typename T, typename WeightType, typename Enable>
 void CutlassMoeFCRunner<T, WeightType, Enable>::run_moe_fc(
     const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales,
-    const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc2_expert_weights,
-    const T* fc2_scales, int num_rows, const int hidden_size, const int inter_size, int num_experts,
-    int local_num_experts, int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, T* expert_scales,
+    const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc3_expert_weights,
+    const T* fc3_scales, const T* fc3_expert_biases, const WeightType* fc2_expert_weights, const T* fc2_scales,
+    int num_rows, const int hidden_size, const int inter_size, int num_experts, int local_num_experts,
+    int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, T* expert_scales,
     int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, cudaStream_t stream) {
   run_moe_fc(input_activations, gating_output, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_activation_type,
-             fc2_expert_weights, fc2_scales, num_rows, hidden_size, inter_size, num_experts, local_num_experts,
-             local_experts_start_index, k, workspace_ptr, fc2_result, nullptr, num_rows, expert_scales,
-             expanded_source_row_to_expanded_dest_row, expert_for_source_row, stream);
+             fc3_expert_weights, fc3_scales, fc3_expert_biases, fc2_expert_weights, fc2_scales, num_rows, hidden_size,
+             inter_size, num_experts, local_num_experts, local_experts_start_index, k, workspace_ptr, fc2_result,
+             nullptr, num_rows, expert_scales, expanded_source_row_to_expanded_dest_row, expert_for_source_row,
+             stream);
 }
 
 template <typename T, typename WeightType, typename Enable>
@@ -811,9 +974,10 @@ __global__ void finalize_moe_routing_kernel(const T* expanded_permuted_rows, T*
       const T* expanded_permuted_rows_row_ptr = expanded_permuted_rows + expanded_permuted_row * cols;
 
       const int expert_idx = expert_for_source_row[k_offset];
-      const T* bias_ptr = bias + expert_idx * cols;
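+      // Bias is now optional; a null bias contributes zero.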
+      const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr;
 
-      thread_output = thread_output + row_scale * (expanded_permuted_rows_row_ptr[tid] + bias_ptr[tid]);
+      thread_output = thread_output + row_scale * (expanded_permuted_rows_row_ptr[tid] +
+                                                   (bias_ptr ? bias_ptr[tid] : T(0)));
     }
     reduced_row_ptr[tid] = thread_output;
   }
@@ -866,9 +1030,9 @@ void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* red
 
 // ========================= TopK Softmax specializations ===========================
 template void topk_gating_softmax_kernelLauncher(const float*, const bool*, float*, float*, int*, int*, int,
-                                                 int, int, cudaStream_t);
+                                                 int, int, bool, cudaStream_t);
 template void topk_gating_softmax_kernelLauncher(const half*, const bool*, half*, half*, int*, int*, int,
-                                                 int, int, cudaStream_t);
+                                                 int, int, bool, cudaStream_t);
 
 // ==================== Variable batched GEMM specializations ==================================
 template class CutlassMoeFCRunner<float, float>;
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
index 5cc2a3f79f003..5eef6f95f4820 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
@@ -24,6 +24,8 @@
 #include "core/common/common.h"
 #include "contrib_ops/cuda/bert/transformer_cuda_common.h"
 
+#include "cutlass/numeric_types.h"
+
 using namespace onnxruntime;
 
 namespace ort_fastertransformer {
@@ -107,12 +109,13 @@ template <typename T,          /*The type used for activations/scales/compute*/
           typename Enable = void>
 class CutlassMoeFCRunner {
  public:
-  CutlassMoeFCRunner(int sm_version);
+  CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights);
 
-  size_t getWorkspaceSize(int num_rows, int hidden_size, int inter_size, int num_experts, int k);
+  size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k);
 
   void run_moe_fc(const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights,
                   const T* fc1_scales, const T* fc1_expert_biases, ActivationType fc1_activation_type,
+                  const WeightType* fc3_expert_weights, const T* fc3_scales, const T* fc3_expert_biases,
                   const WeightType* fc2_expert_weights, const T* fc2_scales, int num_rows, int hidden_size,
                   int inter_size, int num_experts, int local_num_experts, int local_experts_start_index, int k,
                   char* workspace_ptr, T* fc2_result, T* expert_scales, int* expanded_source_row_to_expanded_dest_row,
@@ -120,6 +123,7 @@ class CutlassMoeFCRunner {
 
   void run_moe_fc(const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights,
                   const T* fc1_scales, const T* fc1_expert_biases, ActivationType fc1_activation_type,
+                  const WeightType* fc3_expert_weights, const T* fc3_scales, const T* fc3_expert_biases,
                   const WeightType* fc2_expert_weights, const T* fc2_scales, int num_rows, int hidden_size,
                   int inter_size, int num_experts, int local_num_experts, int local_experts_start_index, int k,
                   char* workspace_ptr, T* fc2_result, const bool* finished, int active_rows, T* expert_scales,
@@ -135,7 +139,8 @@ class CutlassMoeFCRunner {
                            int64_t& total_covered_rows);
 
  private:
-  void configure_ws_ptrs(char* ws_ptr, int num_rows, int hidden_size, int inter_size, int num_experts, int k);
+  void configure_ws_ptrs(char* ws_ptr, size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts,
+                         size_t k);
 
  private:
   CubKeyValueSorter sorter_;
@@ -152,12 +157,17 @@ class CutlassMoeFCRunner {
   int64_t* total_rows_before_expert_;
 
   T* fc1_result_;
+  T* fc3_result_;
+
+  bool has_fc3_;
+  bool normalize_routing_weights_;
 
   // Cuda events
   contrib::cuda::AutoDestoryCudaEvent cuda_event_;
 
   int64_t total_past_rows_;
   int64_t total_covered_rows_;
+
   // TODO: use pinned memory
   std::vector<int64_t> total_rows_before_expert_host_;
 };
@@ -165,11 +175,11 @@ class CutlassMoeFCRunner {
 template <typename WeightType>
 class CutlassMoeFCRunner<float, WeightType, typename std::enable_if_t<!std::is_same<float, WeightType>::value>> {
  public:
-  CutlassMoeFCRunner(int sm_version);
+  CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights);
 
-  size_t getWorkspaceSize(int num_rows, int hidden_size, int inter_size, int num_experts, int k) {
+  size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k) {
     return 0;
   }
 };
 
-}  // namespace ort_fastertransformer
\ No newline at end of file
+}  // namespace ort_fastertransformer
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h
index 00f977c615df6..1de8f6b69642c 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h
@@ -276,13 +276,13 @@ struct MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode:
     return true;
   }
 
-  static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count,
-                                   int32_t block_count) {
+  static size_t get_workspace_size(const cutlass::gemm::GemmCoord* /*host_problem_sizes_ptr*/,
+                                   int32_t /*problem_count*/, int32_t /*block_count*/) {
     return 0;
   }
 
-  static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count,
-                              int32_t block_count, void* host_workspace_ptr) {}
+  static void host_precompute(const cutlass::gemm::GemmCoord* /*host_problem_sizes_ptr*/, int32_t /*problem_count*/,
+                              int32_t /*block_count*/, void* /*host_workspace_ptr*/) {}
 };
 
 }  // namespace kernel
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc
index 3f26a274109ad..b13aab959fc48 100644
--- a/onnxruntime/contrib_ops/cuda/moe/moe.cc
+++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc
@@ -39,13 +39,16 @@ Status MoE<T>::ComputeInternal(OpKernelContext* context) const {
   const Tensor* input = context->Input<Tensor>(0);
   const Tensor* router_probs = context->Input<Tensor>(1);
   const Tensor* fc1_experts_weights = context->Input<Tensor>(2);
-  const Tensor* fc2_experts_weights = context->Input<Tensor>(3);
-  const Tensor* fc1_experts_bias_optional = context->Input<Tensor>(4);
+  const Tensor* fc1_experts_bias_optional = context->Input<Tensor>(3);
+  const Tensor* fc2_experts_weights = context->Input<Tensor>(4);
   const Tensor* fc2_experts_bias_optional = context->Input<Tensor>(5);
+  const Tensor* fc3_experts_weights_optional = context->Input<Tensor>(6);
+  const Tensor* fc3_experts_bias_optional = context->Input<Tensor>(7);
 
   MoEParameters moe_params;
-  ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc2_experts_weights,
-                                  fc1_experts_bias_optional, fc2_experts_bias_optional));
+  ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc1_experts_bias_optional,
+                                  fc2_experts_weights, fc2_experts_bias_optional, fc3_experts_weights_optional,
+                                  fc3_experts_bias_optional));
 
   typedef typename ToCudaType<T>::MappedType CudaT;
   auto stream = context->GetComputeStream();
@@ -53,12 +56,14 @@ Status MoE<T>::ComputeInternal(OpKernelContext* context) const {
   auto& device_prop = GetDeviceProp();
   const int sm = device_prop.major * 10 + device_prop.minor;
 
-  ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm);
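+  // has_fc3 is inferred from the presence of the optional FC3 weights input.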
+  ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm,
+                                                                     fc3_experts_weights_optional != nullptr,
+                                                                     normalize_routing_weights_);
 
   size_t ws_size =
-      moe_runner.getWorkspaceSize(static_cast<int>(moe_params.num_rows), static_cast<int>(moe_params.hidden_size),
-                                  static_cast<int>(moe_params.inter_size), static_cast<int>(moe_params.num_experts),
-                                  static_cast<int>(k_));
+      moe_runner.getWorkspaceSize(static_cast<size_t>(moe_params.num_rows), static_cast<size_t>(moe_params.hidden_size),
+                                  static_cast<size_t>(moe_params.inter_size),
+                                  static_cast<size_t>(moe_params.num_experts), static_cast<size_t>(k_));
   size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT);
   size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT);
   size_t expanded_source_row_to_expanded_dest_row_size = k_ * moe_params.num_rows * sizeof(int);
@@ -77,26 +82,37 @@ Status MoE<T>::ComputeInternal(OpKernelContext* context) const {
   IAllocatorUniquePtr<void> expert_for_source_row =
       IAllocator::MakeUniquePtr<void>(allocator, expert_for_source_row_size, false, stream);
 
-  // fc1_scales and fc2_scales are used in quantized MoE
-  const CudaT* fc1_scales_ptr = nullptr;
-  const CudaT* fc2_scales_ptr = nullptr;
-
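+  // Scales are only used by quantized MoE; the same null pointer is passed for FC1, FC2 and FC3 here.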
+  const CudaT* fc_scales_ptr = nullptr;
   moe_runner.run_moe_fc(reinterpret_cast<const CudaT*>(input->template Data<T>()),
                         reinterpret_cast<const CudaT*>(router_probs->template Data<T>()),
-                        reinterpret_cast<const CudaT*>(fc1_experts_weights->template Data<T>()),
-                        std::move(fc1_scales_ptr),
+                        reinterpret_cast<const CudaT*>(fc1_experts_weights->DataRaw()),
+                        fc_scales_ptr,
                         fc1_experts_bias_optional == nullptr
                             ? nullptr
                             : reinterpret_cast<const CudaT*>(fc1_experts_bias_optional->template Data<T>()),
-                        activation_type_, reinterpret_cast<const CudaT*>(fc2_experts_weights->template Data<T>()),
-                        std::move(fc2_scales_ptr), static_cast<int>(moe_params.num_rows),
-                        static_cast<int>(moe_params.hidden_size), static_cast<int>(moe_params.inter_size),
-                        static_cast<int>(moe_params.num_experts), static_cast<int>(moe_params.local_num_experts),
-                        0 /*local_experts_start_index_ used in sharded MoE*/, static_cast<int>(k_),
-                        reinterpret_cast<char*>(work_space.get()), reinterpret_cast<CudaT*>(fc2_output.get()),
+                        activation_type_,
+                        fc3_experts_weights_optional == nullptr
+                            ? nullptr
+                            : reinterpret_cast<const CudaT*>(fc3_experts_weights_optional->DataRaw()),
+                        fc_scales_ptr,
+                        fc3_experts_bias_optional == nullptr
+                            ? nullptr
+                            : reinterpret_cast<const CudaT*>(fc3_experts_bias_optional->template Data<T>()),
+                        reinterpret_cast<const CudaT*>(fc2_experts_weights->DataRaw()),
+                        fc_scales_ptr,
+                        static_cast<int>(moe_params.num_rows),
+                        static_cast<int>(moe_params.hidden_size),
+                        static_cast<int>(moe_params.inter_size),
+                        static_cast<int>(moe_params.num_experts),
+                        static_cast<int>(moe_params.local_num_experts),
+                        0 /*local_experts_start_index_ used in sharded MoE*/,
+                        static_cast<int>(k_),
+                        reinterpret_cast<char*>(work_space.get()),
+                        reinterpret_cast<CudaT*>(fc2_output.get()),
                         reinterpret_cast<CudaT*>(expert_scales.get()),
                         reinterpret_cast<int*>(expanded_source_row_to_expanded_dest_row.get()),
-                        reinterpret_cast<int*>(expert_for_source_row.get()), Stream(context));
+                        reinterpret_cast<int*>(expert_for_source_row.get()),
+                        Stream(context));
 
   Tensor* output = context->Output(0, input->Shape());
 
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_base.h b/onnxruntime/contrib_ops/cuda/moe/moe_base.h
index f55a7cde2e208..84a5e8c7c120d 100644
--- a/onnxruntime/contrib_ops/cuda/moe/moe_base.h
+++ b/onnxruntime/contrib_ops/cuda/moe/moe_base.h
@@ -13,16 +13,22 @@ namespace cuda {
 
 enum class MoEParallelType {
   None = 0,
-  ExpertSlicing = 1,
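+  // EP = expert parallelism, TP = tensor parallelism.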
+  EP = 1,
+  TP = 2,
+  EPAndTP = 3,
 };
 
 struct MoEParameters {
+  MoEParameters() {}
+  explicit MoEParameters(int64_t tensor_shards) : tensor_shards(tensor_shards) {}
   int64_t num_rows;
   int64_t num_experts;
   int64_t local_num_experts;
   int64_t hidden_size;
   int64_t inter_size;
+
   MoEParallelType parallel_type;
+  int64_t tensor_shards{1};
 };
 
 class MoEBase {
@@ -31,9 +37,11 @@ class MoEBase {
                      const Tensor* input,
                      const Tensor* router_probs,
                      const Tensor* fc1_experts_weights,
-                     const Tensor* fc2_experts_weights,
                      const Tensor* fc1_experts_bias_optional,
-                     const Tensor* fc2_experts_bias_optional) const {
+                     const Tensor* fc2_experts_weights,
+                     const Tensor* fc2_experts_bias_optional,
+                     const Tensor* fc3_experts_weights_optional,
+                     const Tensor* fc3_experts_bias_optional) const {
     const auto& input_dims = input->Shape().GetDims();
     const auto& router_probs_dims = router_probs->Shape().GetDims();
     const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims();
@@ -83,12 +91,6 @@ class MoEBase {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims[0] must be equal to num_rows, got ",
                              router_probs_dims[0], " and ", num_rows);
     }
-    if (fc1_experts_bias_optional != nullptr && fc2_experts_bias_optional == nullptr) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias is set but fc2_experts_bias is not set");
-    }
-    if (fc1_experts_bias_optional == nullptr && fc2_experts_bias_optional != nullptr) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias is not set but fc2_experts_bias is set");
-    }
     if (fc1_experts_bias_optional != nullptr && fc2_experts_bias_optional != nullptr) {
       const auto& fc1_experts_bias_dims = fc1_experts_bias_optional->Shape().GetDims();
       const auto& fc2_experts_bias_dims = fc2_experts_bias_optional->Shape().GetDims();
@@ -126,15 +128,38 @@ class MoEBase {
       }
     }
 
+    if (fc3_experts_weights_optional != nullptr &&
+        fc3_experts_weights_optional->Shape().GetDims() != fc1_experts_weights_dims) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "fc3_experts_weights_dims must be equal to fc1_experts_weights_dims, got ",
+                             fc3_experts_weights_optional->Shape().GetDims(), " and ", fc1_experts_weights_dims);
+    }
+
+    if (fc3_experts_bias_optional != nullptr && fc1_experts_bias_optional != nullptr &&
+        fc3_experts_bias_optional->Shape().GetDims() != fc1_experts_bias_optional->Shape().GetDims()) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "fc3_experts_bias_dims must be equal to fc1_experts_bias_dims, got ",
+                             fc3_experts_bias_optional->Shape().GetDims(), " and ",
+                             fc1_experts_bias_optional->Shape().GetDims());
+    }
+
     parameters.num_rows = num_rows;
     parameters.num_experts = num_experts;
     parameters.local_num_experts = local_num_experts;
     parameters.hidden_size = hidden_size;
     parameters.inter_size = inter_size;
     if (num_experts == local_num_experts) {
-      parameters.parallel_type = MoEParallelType::None;
+      if (parameters.tensor_shards == 1) {
+        parameters.parallel_type = MoEParallelType::None;
+      } else {
+        parameters.parallel_type = MoEParallelType::TP;
+      }
     } else if (num_experts > local_num_experts) {
-      parameters.parallel_type = MoEParallelType::ExpertSlicing;
+      if (parameters.tensor_shards == 1) {
+        parameters.parallel_type = MoEParallelType::EP;
+      } else {
+        parameters.parallel_type = MoEParallelType::EPAndTP;
+      }
     } else {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                              "num_experts must be greater than or equal to local_num_experts, got ",
@@ -161,8 +186,11 @@ class MoEBase {
     } else {
       ORT_THROW("Unsupported MoE activation type: ", activation_type_str);
     }
+
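+    // Read once at construction; forwarded to the top-k gating softmax kernel through the MoE runner.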
+    normalize_routing_weights_ = op_kernel_info.GetAttrOrDefault<int64_t>("normalize_routing_weights", 0) == 1;
   }
 
+  bool normalize_routing_weights_;
   int64_t k_;
   ort_fastertransformer::ActivationType activation_type_;
 };
diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc
index 001b6070d5e1a..168c69c69f003 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc
+++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc
@@ -154,7 +154,7 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
   CudaT dequant_scale;
   CudaT input_scale = *(reinterpret_cast<const CudaT*>(input_scale_tensor->Data<T>()));
   CudaT weight_scale = *(reinterpret_cast<const CudaT*>(weight_scale_tensor->Data<T>()));
-  if (sizeof(T) == 2) {
+  if constexpr (sizeof(T) == 2) {
     dequant_scale = __float2half(__half2float(input_scale) * __half2float(weight_scale));
   } else {
     dequant_scale = input_scale * weight_scale;
diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu
index 6b66f1d84e221..265adf22eeb61 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu
+++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu
@@ -2,10 +2,12 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#include <cstdint>
 #include <cub/cub.cuh>
 #include <cublas_v2.h>
 #include <cuda_fp16.h>
 #include <cmath>
+#include <type_traits>
 #include <math_constants.h>
 #include "core/providers/cuda/cu_inc/common.cuh"
 #include "core/providers/cuda/cuda_common.h"
@@ -21,7 +23,7 @@ namespace cuda {
 
 __device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, half scale, half zp, half* output) {
   half2 scale_half2 = {scale, scale};
-  half zp_adjust = -scale * __short2half_rn(zp);
+  half zp_adjust = -scale * zp;
   half2 zp_adjust2 = {zp_adjust, zp_adjust};
 
   alignas(16) half2 results[4];
@@ -56,41 +58,95 @@ __device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, f
 }
 
 template <class T>
-__global__ void Dequantize4BitsKernel(
+__global__ void Dequantize4BitsKernelReOrder(
     T* output,
     const uint8_t* quant_data,
     const T* scale_data,
     const uint8_t* zero_points,
+    const int32_t* reorder_idx,
     int block_size,
-    int blocks_per_K,
-    int blocks_per_threadblock,
-    int total_blks,
-    int shift) {
-  int block_id = blockIdx.x * blocks_per_threadblock + ((threadIdx.x * 8) >> shift);
-  if (block_id >= total_blks) {
+    int groups_per_K,
+    int groups_per_threadblock,
+    int total_groups) {
+  int group_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * 8) / block_size);
+  if (group_id >= total_groups) {
     return;
   }
-  int n_idx = block_id / blocks_per_K;
-  int kb_idx = block_id % blocks_per_K;
-  int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1));
+  // T __shared__ zero_points_after_reorder[];//K
+  // T __shared__ scales_after_reorder[];     // K
+  // const int num_r_per_thread = k / 256;
+
+  const int zero_point_shape_x = (groups_per_K + 1) / 2;
+  const int scales_shape_x = groups_per_K;
+  int n_idx = group_id / scales_shape_x;
+  int kb_idx = group_id % scales_shape_x;
+  int element_offset = group_id * block_size + ((threadIdx.x * 8) & (block_size - 1));
+  T* output_i = output + element_offset;
+  uint32_t quant_value = *(reinterpret_cast<const uint32_t*>(quant_data + element_offset / 2));
+  const int32_t* reorder_idx_with_off = reorder_idx + kb_idx * block_size + ((threadIdx.x * 8) & (block_size - 1));
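+  // Each thread dequantizes 8 packed int4 values; reorder_idx maps every element to its (possibly permuted)
+  // quantization group.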
+  for (int i = 0; i < 8; i++) {
+    int32_t rid = reorder_idx_with_off[i];
+    T scale = *(scale_data + n_idx * scales_shape_x + rid);
+    uint8_t zp = 8;
+    if (zero_points) {
+      zp = zero_points[n_idx * zero_point_shape_x + rid / 2];
+      zp = (rid & 0x01) ? (zp >> 4) : (zp & 0x0f);
+    }
+
+    if constexpr (std::is_same_v<T, half>) {
+      T zp_adjust = -scale * __short2half_rn(zp);
+      output_i[i] = __uint2half_rn((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust;
+    } else {
+      T zp_adjust = -scale * T(zp);
+      output_i[i] = T((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust;
+    }
+  }
+}
+
+template <class T, typename ZeroT = uint8_t>
+__global__ void Dequantize4BitsKernel(
+    T* output,
+    const uint8_t* quant_data,
+    const T* scale_data,
+    const ZeroT* zero_points,
+    int block_size,
+    int groups_per_K,
+    int groups_per_threadblock,
+    int total_groups) {
+  int block_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * 8) / block_size);
+  if (block_id >= total_groups) {
+    return;
+  }
+  int element_offset = block_id * block_size + ((threadIdx.x * 8) & (block_size - 1));
   uint32_t quant_value = *(reinterpret_cast<const uint32_t*>(quant_data + element_offset / 2));
   T scale = *(scale_data + block_id);
-  uint8_t zp = 8;
-  if (zero_points) {
-    zp = zero_points[n_idx * ((blocks_per_K + 1)/2) + kb_idx / 2];
-    zp = (kb_idx & 0x01) ? (zp >> 4) : (zp & 0x0f);
+  T zero_point_value;
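+  // Zero points are either packed 4-bit values (uint8_t, two groups per byte) or stored per group in the same
+  // type as the scales.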
+  if constexpr (std::is_same_v<ZeroT, uint8_t>) {
+    const int scales_shape_x = groups_per_K;
+    const int zero_point_shape_x = (groups_per_K + 1) / 2;
+    int kb_idx = block_id % scales_shape_x;
+    int n_idx = block_id / scales_shape_x;
+    uint8_t zp = 8;
+    if (zero_points) {
+      zp = zero_points[n_idx * zero_point_shape_x + kb_idx / 2];
+      zp = (kb_idx & 0x01) ? (zp >> 4) : (zp & 0x0f);
+    }
+    zero_point_value = static_cast<T>(zp);
+  } else {
+    zero_point_value = zero_points ? *(zero_points + block_id) : static_cast<T>(8);
   }
 
   output = output + element_offset;
-  DequantizeEightElements(quant_value, scale, static_cast<T>(zp), output);
+  DequantizeEightElements(quant_value, scale, zero_point_value, output);
 }
 
-template <class T>
+template <class T, typename ZeroT>
 Status Dequantize4Bits(
     T* output,
     const uint8_t* quant_data,
     const T* scales_data,
-    const uint8_t* zero_points,  // shape: [N, (block_per_K + 1)/2]
+    const ZeroT* zero_points,  // shape: [N, (block_per_K + 1)/2] for packed uint8_t, [N, block_per_K] otherwise
+    const int32_t* reorder_idx,
     int k,
     int n,
     int block_size,
@@ -98,47 +154,79 @@ Status Dequantize4Bits(
   // k is padded and equal to block_per_K * block_size
   ORT_ENFORCE(k % block_size == 0, "k must be a multiplier of block_size");
   constexpr int element_per_thread = 8;
-  int blocks_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size;
-  int blocks_per_K = k / block_size;
-  int total_blks = n * blocks_per_K;
-  int blocks_per_grid = static_cast<int>(CeilDiv(n * blocks_per_K, blocks_per_threadblock));
-  int shift = static_cast<int>(log2f(float(block_size)));
-
-  Dequantize4BitsKernel<<<blocks_per_grid, GridDim::maxThreadsPerBlock, 0, stream>>>(
-      output,
-      quant_data,
-      scales_data,
-      zero_points,
-      block_size,
-      blocks_per_K,
-      blocks_per_threadblock,
-      total_blks,
-      shift);
+  int groups_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size;
+  int groups_per_K = k / block_size;
+  int total_groups = n * groups_per_K;  // total number of quantization groups in the weight matrix
+  int groups_per_grid = static_cast<int>(CeilDiv(total_groups, groups_per_threadblock));
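+  // Without a reorder index (or when zero points share the activation type) use the plain kernel;
+  // otherwise use the reorder-aware kernel with packed uint8 zero points.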
+  if (!reorder_idx || std::is_same_v<ZeroT, T>) {
+    Dequantize4BitsKernel<T, ZeroT><<<groups_per_grid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+        output,
+        quant_data,
+        scales_data,
+        zero_points,
+        block_size,
+        groups_per_K,
+        groups_per_threadblock,
+        total_groups);
+  } else {
+    // static_assert(std::is_same_v<ZeroT, uint8_t>, "ZeroT must be uint8_t");
+    Dequantize4BitsKernelReOrder<<<groups_per_grid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+        output,
+        quant_data,
+        scales_data,
+        (const uint8_t*)zero_points,
+        reorder_idx,
+        block_size,
+        groups_per_K,
+        groups_per_threadblock,
+        total_groups);
+  }
 
   return Status::OK();
 }
 
-template Status Dequantize4Bits<float>(
+template Status Dequantize4Bits<float, uint8_t>(
     float* output,
     const uint8_t* quant_data,
     const float* scales_data,
     const uint8_t* zero_points,
+    const int32_t* reorder_idx,
     int k,
     int n,
     int block_size,
     cudaStream_t stream);
 
-template Status Dequantize4Bits<half>(
+template Status Dequantize4Bits<half, uint8_t>(
     half* output,
     const uint8_t* quant_data,
     const half* scales_data,
     const uint8_t* zero_points,
+    const int32_t* reorder_idx,
+    int k,
+    int n,
+    int block_size,
+    cudaStream_t stream);
+template Status Dequantize4Bits<float, float>(
+    float* output,
+    const uint8_t* quant_data,
+    const float* scales_data,
+    const float* zero_points,
+    const int32_t* reorder_idx,
     int k,
     int n,
     int block_size,
     cudaStream_t stream);
 
-
+template Status Dequantize4Bits<half, half>(
+    half* output,
+    const uint8_t* quant_data,
+    const half* scales_data,
+    const half* zero_points,
+    const int32_t* reorder_idx,
+    int k,
+    int n,
+    int block_size,
+    cudaStream_t stream);
 ///////////////////////////////////////////////////////////////////////////////
 // A more general block-wise dequantization implementation that supports
 // different block sizes and block orientations (row-wise/column-wise).
diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh
index f9c09c55fd893..580b5087f3fa3 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh
+++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh
@@ -7,18 +7,18 @@
 namespace onnxruntime {
 namespace contrib {
 namespace cuda {
-template <class T>
+template <class T, typename ZeroT>
 Status Dequantize4Bits(
     T* output,
     const uint8_t* quant_data,
     const T* scales_data,
-    const uint8_t* zero_points,
+    const ZeroT* zero_points,
+    const int32_t* reorder_idx,
     int k,
     int n,
     int block_size,
     cudaStream_t stream);
 
-
 /**
  * @brief Dequantize a block-wise quantized matrix, and store the result in a
  *        column major matrix for use in subsequent GEMM. This implementation supports
diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc
index 015df70c8ec3c..1cec6f6a12f1c 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc
@@ -1,15 +1,12 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-//
-// This module define MatMulFp32Q4 operator, it is basically
-// matmul float32 with right hand side being a 2-D matrix
-// pre-packed and block-compacted into int4
-//
-
-#include "core/common/safeint.h"
-#include "core/providers/cuda/cuda_kernel.h"
-#include "core/providers/cuda/shared_inc/fpgeneric.h"
+#include "contrib_ops/cuda/quantization/matmul_nbits.h"
+
+#include <cstdint>
+
+#include "core/common/status.h"
+#include "core/framework/float16.h"
 #include "core/providers/cpu/math/matmul_helper.h"
 #include "matmul_nbits.cuh"
 #include "dequantize_blockwise.cuh"
@@ -19,40 +16,19 @@ namespace contrib {
 namespace cuda {
 using namespace onnxruntime::cuda;
 
-template <typename T>
-class MatMulNBits final : public CudaKernel {
- public:
-  MatMulNBits(const OpKernelInfo& info) : CudaKernel(info) {
-    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("K", &K_));
-    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("N", &N_));
-    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("block_size", &block_size_));
-    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("bits", &nbits_));
-    ORT_ENFORCE(nbits_ == 4,
-                "Only 4b quantization is supported for MatMulNBits op,"
-                " additional bits support is planned.");
-  }
-
-  Status ComputeInternal(OpKernelContext* context) const override;
-
- private:
-  int64_t K_;
-  int64_t N_;
-  int64_t block_size_;
-  int64_t nbits_;
-  bool column_wise_quant_blk_{true};
-};
-
 template <typename T>
 Status MatMulNBits<T>::ComputeInternal(OpKernelContext* ctx) const {
   const Tensor* a = ctx->Input<Tensor>(0);
   const Tensor* b = ctx->Input<Tensor>(1);
   const Tensor* scales = ctx->Input<Tensor>(2);
   const Tensor* zero_points = ctx->Input<Tensor>(3);
+  const Tensor* reorder_idx = ctx->Input<Tensor>(4);
 
   const auto* a_data = a->Data<T>();
   const uint8_t* blob_data = b->Data<uint8_t>();
   const auto* scales_data = scales->Data<T>();
-  const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data<uint8_t>();
+  const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw();
+  const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data<int32_t>();
 
   typedef typename ToCudaType<T>::MappedType CudaT;
 
@@ -67,77 +43,99 @@ Status MatMulNBits<T>::ComputeInternal(OpKernelContext* ctx) const {
   // Bail out early if the output is going to be empty
   if (Y->Shape().Size() == 0) return Status::OK();
 
-  bool is_4bit_done = TryMatMul4Bits(
-      reinterpret_cast<CudaT*>(Y->MutableData<T>()),
-      reinterpret_cast<const CudaT*>(a_data),
-      blob_data,
-      reinterpret_cast<const CudaT*>(scales_data),
-      zero_points_data,
-      SafeInt<int>(helper.M()),
-      SafeInt<int>(helper.N()),
-      SafeInt<int>(helper.K()),
-      SafeInt<int>(block_size_),
-      SafeInt<int>(GetDeviceProp().sharedMemPerBlock),
-      static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle()));
-  if (!is_4bit_done) {
-    int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_;
-    IAllocatorUniquePtr<T> b_data_ptr = GetScratchBuffer<T>(N_ * K_padded, ctx->GetComputeStream());
-    auto* b_data = b_data_ptr.get();
-    if (column_wise_quant_blk_) {
-      // column-wise block
+  bool is_4bit_done = (reorder_idx_data == nullptr) &&
+                      (!zero_points || !zero_points->IsDataType<T>()) &&
+                      TryMatMul4Bits(
+                          reinterpret_cast<CudaT*>(Y->MutableData<T>()),
+                          reinterpret_cast<const CudaT*>(a_data),
+                          blob_data,
+                          reinterpret_cast<const CudaT*>(scales_data),
+                          static_cast<const uint8_t*>(zero_points_data),
+                          SafeInt<int>(helper.M()),
+                          SafeInt<int>(helper.N()),
+                          SafeInt<int>(helper.K()),
+                          SafeInt<int>(block_size_),
+                          SafeInt<int>(GetDeviceProp().sharedMemPerBlock),
+                          static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle()));
+
+  if (is_4bit_done) {
+    return Status::OK();
+  }
+
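+  // Fallback path: dequantize the packed 4-bit weights into a temporary buffer, then run a regular GEMM below.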
+  int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_;
+  IAllocatorUniquePtr<T> b_data_ptr = GetScratchBuffer<T>(N_ * K_padded, ctx->GetComputeStream());
+  auto* b_data = b_data_ptr.get();
+  if (column_wise_quant_blk_) {
+    if (reorder_idx) {
+      ORT_ENFORCE(K_padded == reorder_idx->Shape()[0], "K_padded != g_idx->Shape()[0]");
+    }
+    // column-wise block
+    if ((zero_points && zero_points->IsDataType<T>())) {
       ORT_RETURN_IF_ERROR(Dequantize4Bits(
           reinterpret_cast<CudaT*>(b_data),
           blob_data,
           reinterpret_cast<const CudaT*>(scales_data),
-          zero_points_data,
+          (const CudaT*)zero_points_data,
+          reorder_idx_data,
           SafeInt<int>(K_padded),
           SafeInt<int>(N_),
           SafeInt<int>(block_size_),
           static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle())));
     } else {
-      // row-wise block
-      K_padded = K_;
-
-      ORT_RETURN_IF_ERROR(DequantizeBlockwise4b(
+      ORT_RETURN_IF_ERROR(Dequantize4Bits(
           reinterpret_cast<CudaT*>(b_data),
           blob_data,
           reinterpret_cast<const CudaT*>(scales_data),
-          zero_points_data,
-          SafeInt<int>(block_size_),
-          column_wise_quant_blk_,
-          SafeInt<int>(K_),
+          (const uint8_t*)zero_points_data,
+          reorder_idx_data,
+          SafeInt<int>(K_padded),
           SafeInt<int>(N_),
+          SafeInt<int>(block_size_),
           static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle())));
     }
+  } else {
+    // row-wise block
+    K_padded = K_;
+
+    ORT_RETURN_IF_ERROR(DequantizeBlockwise4b(
+        reinterpret_cast<CudaT*>(b_data),
+        blob_data,
+        reinterpret_cast<const CudaT*>(scales_data),
+        (const uint8_t*)zero_points_data,
+        SafeInt<int>(block_size_),
+        column_wise_quant_blk_,
+        SafeInt<int>(K_),
+        SafeInt<int>(N_),
+        static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle())));
+  }
 #if 0
-  cudaStreamSynchronize(static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle()));
-  T* b_data_cpu = new T[K_ * N_];
-  cudaMemcpy(b_data_cpu, b_data, K_ * N_ * sizeof(T), cudaMemcpyDeviceToHost);
-  delete[] b_data_cpu;
+cudaStreamSynchronize(static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle()));
+T* b_data_cpu = new T[K_ * N_];
+cudaMemcpy(b_data_cpu, b_data, K_ * N_ * sizeof(T), cudaMemcpyDeviceToHost);
+delete[] b_data_cpu;
 #endif
 
-    const CudaT alpha = ToCudaType<T>::FromFloat(1.f);
-    const CudaT zero = ToCudaType<T>::FromFloat(0.f);
-
-    if (helper.OutputOffsets().size() == 1) {
-      CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
-          GetCublasHandle(ctx),
-          CUBLAS_OP_T,
-          CUBLAS_OP_N,
-          SafeInt<int>(helper.N()),
-          SafeInt<int>(helper.M()),
-          SafeInt<int>(helper.K()),
-          &alpha,
-          reinterpret_cast<const CudaT*>(b_data),
-          SafeInt<int>(K_padded),
-          reinterpret_cast<const CudaT*>(a_data),
-          helper.Lda(transa),
-          &zero,
-          reinterpret_cast<CudaT*>(Y->MutableData<T>()),
-          helper.Ldc(),
-          GetDeviceProp(),
-          UseTF32()));
-    }
+  const CudaT alpha = ToCudaType<T>::FromFloat(1.f);
+  const CudaT zero = ToCudaType<T>::FromFloat(0.f);
+
+  if (helper.OutputOffsets().size() == 1) {
+    CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
+        GetCublasHandle(ctx),
+        CUBLAS_OP_T,
+        CUBLAS_OP_N,
+        SafeInt<int>(helper.N()),
+        SafeInt<int>(helper.M()),
+        SafeInt<int>(helper.K()),
+        &alpha,
+        reinterpret_cast<const CudaT*>(b_data),
+        SafeInt<int>(K_padded),
+        reinterpret_cast<const CudaT*>(a_data),
+        helper.Lda(transa),
+        &zero,
+        reinterpret_cast<CudaT*>(Y->MutableData<T>()),
+        helper.Ldc(),
+        GetDeviceProp(),
+        UseTF32()));
   }
 
   return Status::OK();
diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h
new file mode 100644
index 0000000000000..f5c2c6c4e4fdf
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h
@@ -0,0 +1,41 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+//
+// This module define MatMulNBits operator, it is basically
+// matmul float with right hand side being a 2-D matrix
+// pre-packed and block-compacted into int4
+//
+#pragma once
+#include "core/common/safeint.h"
+#include "core/providers/cuda/cuda_kernel.h"
+#include "core/providers/cuda/shared_inc/fpgeneric.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+using namespace onnxruntime::cuda;
+
+template <typename T>
+class MatMulNBits final : public CudaKernel {
+ public:
+  MatMulNBits(const OpKernelInfo& info) : CudaKernel(info) {
+    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("K", &K_));
+    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("N", &N_));
+    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("block_size", &block_size_));
+    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("bits", &nbits_));
+  }
+
+  Status ComputeInternal(OpKernelContext* context) const override;
+
+ private:
+  int64_t K_;
+  int64_t N_;
+  int64_t block_size_;
+  int64_t nbits_;
+  bool column_wise_quant_blk_{true};
+};
+
+}  // namespace cuda
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc
index 3cecebedae2f0..12835978536e1 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc
@@ -142,7 +142,7 @@ inline void debug_print([[maybe_unused]] const T* arr,
   std::cout << "========" << name << std::endl;
   for (size_t i = 0; i < sz; i++) {
     if (i % w == 0) std::cout << std::endl;
-    if (std::is_same<T, int8_t>().value) {
+    if constexpr (std::is_same<T, int8_t>::value) {
       std::cout << (int)buf[i] << ", ";
     } else {
       std::cout << buf[i] << ", ";
diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_impl.cu b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_impl.cu
index f4d5a7b404a62..fd4b51f40fb4f 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_impl.cu
@@ -151,7 +151,7 @@ QOrderBatchInt8MatrixTransposeKernel(const int8_t* src, const int8_t* dst, const
   }
 }
 
-Status QOrderBatchTransposeInt8Matrix(cudaStream_t stream, const cudaDeviceProp& device_prop,
+Status QOrderBatchTransposeInt8Matrix(cudaStream_t stream, const cudaDeviceProp& /*device_prop*/,
                                       const int batch_size, const int rows, const int cols,
                                       const int8_t* input, int8_t* output) {
   ORT_ENFORCE(rows % 4 == 0 && cols % 4 == 0, "Matrix rows and cols must be divisible by 4!");
diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq_impl.cu b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq_impl.cu
index baff8e76ec73b..e6ac0bc8a5171 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq_impl.cu
@@ -389,7 +389,7 @@ QOrderDequantizeKernel_Strict(const int8_t* __restrict__ src, const __half* __re
   }
 }
 
-Status QOrderDequantize_Strict(cudaStream_t stream, const cudaDeviceProp& device_prop,
+Status QOrderDequantize_Strict(cudaStream_t stream, const cudaDeviceProp& /*device_prop*/,
                                const int8_t* src, __half* dst, float scale, size_t N) {
   ORT_RETURN_IF(N & 0x3LL, "N can not divide by 4!");
 
diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu
index a39abefed9cd0..eb1943b59d976 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu
@@ -1,11 +1,22 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+
+// cub.cuh includes device/dispatch_radix_sort.cuh which has assignment in conditional expressions
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4706)
+#endif
+#include <cub/cub.cuh>
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+
+#include <cub/util_type.cuh>
+
 #include "core/providers/cuda/cuda_common.h"
 #include "core/providers/cuda/cu_inc/common.cuh"
-#include "cub/util_type.cuh"
-#include <cub/cub.cuh>
-#include <cub/device/device_segmented_radix_sort.cuh>
+
 #include "contrib_ops/cuda/bert/utils.cuh"
 #include "contrib_ops/cuda/transformers/generation_cuda_impl.h"
 
diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
index bba30805ae1be..7adc2fe0a67ea 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
+++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
@@ -424,7 +424,7 @@ Status ProcessLogits(const OrtValue& logits,                                 //
   const bool is_whisper_model = (parameters->model_type == onnxruntime::contrib::transformers::IGenerationParameters::kModelTypeWhisper);
   if (step == 1 && is_whisper_model && parameters->no_speech_probs) {
     cuda::LaunchSaveNoSpeechProbs<T>(
-        (T*)parameters->no_speech_probs, Y_data, batch_size, num_beams, vocab_size, parameters->no_speech_token, cuda_stream);
+        (T*)parameters->no_speech_probs, Y_data, batch_size, num_beams, vocab_size, parameters->no_speech_token_id, cuda_stream);
   }
 
   // NOTE: currently we treat extra decoding ids are same
@@ -469,7 +469,15 @@ Status ProcessLogits(const OrtValue& logits,                                 //
                                          cudaMemcpyDeviceToHost,
                                          cuda_stream));
     constexpr int max_initial_timestamp_index = 50;
-    onnxruntime::contrib::transformers::TimestampLogitsProcessor<float> time_logit_processor(parameters->eos_token_id, max_initial_timestamp_index);
+    // Token ids are passed below in the order that they appear in the tokenizer
+    onnxruntime::contrib::transformers::TimestampLogitsProcessor<float> time_logit_processor(parameters->eos_token_id,
+                                                                                             parameters->decoder_start_token_id,
+                                                                                             parameters->translate_token_id,
+                                                                                             parameters->transcribe_token_id,
+                                                                                             parameters->start_of_lm_token_id,
+                                                                                             parameters->no_timestamps_token_id,
+                                                                                             parameters->beginning_timestamp_token_id,
+                                                                                             max_initial_timestamp_index);
     onnxruntime::contrib::transformers::NextTokenScores<float> next_token_scores_timestamp({cpu_next_token_scores_span, batch_beam_size, vocab_size});
 
     CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream));
diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc
index bd58dded026a6..25e7567a2e9fc 100644
--- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc
@@ -8,13 +8,14 @@ namespace contrib {
 namespace js {
 
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Attention);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGelu);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv);
 
 template <>
 KernelCreateInfo BuildKernelCreateInfo<void>() {
@@ -25,14 +26,15 @@ KernelCreateInfo BuildKernelCreateInfo<void>() {
 Status RegisterJsContribKernels(KernelRegistry& kernel_registry) {
   static const BuildKernelCreateInfoFn function_table[] = {
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Attention)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGelu)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1,
-                                                            SkipLayerNormalization)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv)>};
+                                                            SkipLayerNormalization)>};
 
   for (auto& function_table_entry : function_table) {
     KernelCreateInfo info = function_table_entry();
diff --git a/onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc
new file mode 100644
index 0000000000000..888db0fd161f2
--- /dev/null
+++ b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc
@@ -0,0 +1,25 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "contrib_ops/js/quantization/matmul_nbits.h"
+#include "core/providers/js/js_data_types.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace js {
+
+using onnxruntime::js::JsepSupportedFloatTypes;
+
+ONNX_OPERATOR_KERNEL_EX(
+    MatMulNBits,
+    kMSDomain,
+    1,
+    kJsExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T1", JsepSupportedFloatTypes())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>()),
+    MatMulNBits);
+
+}  // namespace js
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h
new file mode 100644
index 0000000000000..cca2c4757765b
--- /dev/null
+++ b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h
@@ -0,0 +1,48 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/js/js_kernel.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace js {
+
+using onnxruntime::js::JsKernel;
+
+class MatMulNBits final : public JsKernel {
+ public:
+  MatMulNBits(const OpKernelInfo& info) : JsKernel(info),
+                                          K_{narrow<size_t>(info.GetAttr<int64_t>("K"))},
+                                          N_{narrow<size_t>(info.GetAttr<int64_t>("N"))},
+                                          accuracy_level_{info.GetAttrOrDefault<int64_t>("accuracy_level", 0)},
+                                          nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))},
+                                          block_size_{narrow<size_t>(info.GetAttr<int64_t>("block_size"))} {
+    ORT_ENFORCE(nbits_ == 4,
+                "Only 4-bit quantization is supported for the MatMulNBits op; support for additional bit widths is planned.");
+    ORT_ENFORCE(block_size_ >= 16 && !(block_size_ & (block_size_ - 1)),
+                "Block size must be a power of 2 and greater than or equal to 16.");
+    JSEP_INIT_KERNEL_ATTRIBUTE(MatMulNBits, ({
+                                 "k" : $1,
+                                 "n" : $2,
+                                 "accuracyLevel" : $3,
+                                 "bits" : $4,
+                                 "blockSize" : $5
+                               }),
+                               static_cast<int32_t>(K_),
+                               static_cast<int32_t>(N_),
+                               static_cast<int32_t>(accuracy_level_),
+                               static_cast<int32_t>(nbits_),
+                               static_cast<int32_t>(block_size_));
+  }
+
+ private:
+  const size_t K_;
+  const size_t N_;
+  const int64_t accuracy_level_;
+  const size_t nbits_;
+  const size_t block_size_;
+};
+
+}  // namespace js
+}  // namespace contrib
+}  // namespace onnxruntime
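
For context, the MatMulNBits JSEP kernel above accepts only 4-bit quantization and requires block_size to be a power of two no smaller than 16 (checked with the block_size & (block_size - 1) trick). A minimal Python sketch of the same validation, for illustration only (the helper name is invented and not part of this change):

def validate_matmul_nbits_attrs(bits: int, block_size: int) -> None:
    # Mirrors the ORT_ENFORCE checks in the MatMulNBits constructor.
    if bits != 4:
        raise ValueError("Only 4-bit quantization is supported for MatMulNBits.")
    # A power of two has exactly one bit set, so x & (x - 1) == 0.
    if block_size < 16 or (block_size & (block_size - 1)) != 0:
        raise ValueError("Block size must be a power of 2 and >= 16.")

validate_matmul_nbits_attrs(bits=4, block_size=32)    # passes
# validate_matmul_nbits_attrs(bits=4, block_size=24)  # would raise ValueError
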
diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc b/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc
deleted file mode 100644
index 9cb414e4e8980..0000000000000
--- a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc
+++ /dev/null
@@ -1,59 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "contrib_ops/rocm/bert/fast_gelu.h"
-
-#include "core/providers/rocm/rocm_common.h"
-#include "core/providers/rocm/miopen_common.h"
-#include "contrib_ops/cpu/bert/bias_gelu_helper.h"
-#include "contrib_ops/rocm/bert/elementwise.h"
-#include "contrib_ops/rocm/bert/transformer_common.h"
-
-namespace onnxruntime {
-namespace contrib {
-namespace rocm {
-
-#define REGISTER_KERNEL_TYPED(T)                                  \
-  ONNX_OPERATOR_TYPED_KERNEL_EX(                                  \
-      FastGelu,                                                   \
-      kMSDomain,                                                  \
-      1,                                                          \
-      T,                                                          \
-      kRocmExecutionProvider,                                     \
-      (*KernelDefBuilder::Create())                               \
-          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
-      FastGelu<T>);
-
-REGISTER_KERNEL_TYPED(float)
-REGISTER_KERNEL_TYPED(MLFloat16)
-REGISTER_KERNEL_TYPED(BFloat16)
-
-using namespace ONNX_NAMESPACE;
-
-template <typename T>
-Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const {
-  ORT_RETURN_IF_ERROR(bias_gelu_helper::CheckInputs(context));
-
-  const Tensor* input = context->Input<Tensor>(0);
-  const Tensor* bias = context->Input<Tensor>(1);
-  Tensor* output = context->Output(0, input->Shape());
-
-  int64_t input_length = input->Shape().Size();
-  if (input_length == 0) {
-    return Status::OK();
-  }
-  int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size();
-  typedef typename ToHipType<T>::MappedType HipT;
-
-  const HipT* input_buffer = reinterpret_cast<const HipT*>(input->Data<T>());
-  const HipT* bias_buffer = (nullptr != bias) ? reinterpret_cast<const HipT*>(bias->Data<T>()) : nullptr;
-  return LaunchElementwiseKernel<functor::FastGeLU, HipT>(
-      GetTuningContext(), context->GetComputeStream(),
-      input_buffer, static_cast<int>(input_length),
-      bias_buffer, static_cast<int>(bias_length),
-      reinterpret_cast<HipT*>(output->MutableData<T>()));
-}
-
-}  // namespace rocm
-}  // namespace contrib
-}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.h b/onnxruntime/contrib_ops/rocm/bert/fast_gelu.h
deleted file mode 100644
index 42bfe5a0b0246..0000000000000
--- a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.h
+++ /dev/null
@@ -1,24 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#pragma once
-
-#include "core/common/common.h"
-#include "core/providers/rocm/rocm_kernel.h"
-
-namespace onnxruntime {
-namespace contrib {
-namespace rocm {
-
-using namespace onnxruntime::rocm;
-
-template <typename T>
-class FastGelu final : public RocmKernel {
- public:
-  FastGelu(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {}
-  Status ComputeInternal(OpKernelContext* ctx) const override;
-};
-
-}  // namespace rocm
-}  // namespace contrib
-}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc b/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc
deleted file mode 100644
index e82e15a304f4c..0000000000000
--- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc
+++ /dev/null
@@ -1,152 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "core/providers/rocm/rocm_common.h"
-#include "contrib_ops/rocm/diffusion/group_norm.h"
-#include "contrib_ops/rocm/diffusion/group_norm_impl.h"
-
-namespace onnxruntime {
-namespace contrib {
-namespace rocm {
-
-#define GROUP_NORM_TYPES float, MLFloat16
-
-ONNX_OPERATOR_KERNEL_EX(
-    GroupNorm, kMSDomain, 1, kRocmExecutionProvider,
-    (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints<GROUP_NORM_TYPES>()), GroupNorm);
-
-using namespace ONNX_NAMESPACE;
-
-namespace {
-template <typename T>
-struct DispatchGroupNorm {
-  Status operator()(RocmTuningContext* tuning_ctx,
-                    Stream* stream,
-                    Tensor* output,
-                    const Tensor* input,
-                    const Tensor* gamma,
-                    const Tensor* beta,
-                    void* workspace,
-                    float epsilon,
-                    int batch_size,
-                    int num_channels,
-                    int height,
-                    int width,
-                    int num_groups,
-                    bool use_swish_activation) {
-    typedef typename ToHipType<T>::MappedType HipT;
-    return LaunchGroupNormKernel<HipT>(
-        tuning_ctx,
-        stream,
-        reinterpret_cast<HipT*>(output->MutableData<T>()),
-        reinterpret_cast<const HipT*>(input->Data<T>()),
-        gamma->Data<float>(),
-        beta->Data<float>(),
-        workspace,
-        epsilon,
-        batch_size,
-        num_channels,
-        height,
-        width,
-        num_groups,
-        use_swish_activation);
-  }
-};
-
-}  // namespace
-
-GroupNorm::GroupNorm(const OpKernelInfo& op_info) : RocmKernel(op_info) {
-  epsilon_ = op_info.GetAttrOrDefault<float>("epsilon", 1e-5f);
-  ORT_ENFORCE(epsilon_ >= 0);
-
-  int64_t num_groups;
-  ORT_ENFORCE(op_info.GetAttr("groups", &num_groups).IsOK());
-  ORT_ENFORCE(num_groups >= 0);
-  num_groups_ = static_cast<int>(num_groups);
-
-  int64_t activation;
-  ORT_ENFORCE(op_info.GetAttr("activation", &activation).IsOK());
-  ORT_ENFORCE(activation == 0 || activation == 1);  // 0 is None, 1 is Swish
-  use_swish_activation_ = (activation == 1);
-
-  channels_last_ = (op_info.GetAttrOrDefault<int64_t>("channels_last", static_cast<int64_t>(1)) != 0);
-}
-
-Status GroupNorm::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/,
-                          bool& is_packed, PrePackedWeights* /*prepacked_weights*/) {
-  is_packed = false;
-  return Status::OK();
-}
-
-Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
-  const Tensor* input = context->Input<Tensor>(0);
-  const Tensor* gamma = context->Input<Tensor>(1);
-  const Tensor* beta = context->Input<Tensor>(2);
-  Tensor* output = context->Output(0, input->Shape());
-
-  if (!channels_last_) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "only the channels_last layout is supported");
-  }
-
-  const auto& input_dims = input->Shape().GetDims();
-  if (input_dims.size() != 4) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "input is expected to have 4 dimensions, got ", input_dims.size());
-  }
-
-  const auto& gamma_dims = gamma->Shape().GetDims();
-  if (gamma_dims.size() != 1) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "gamma is expected to have 1 dimension, got ", gamma_dims.size());
-  }
-  if (gamma_dims[0] != input_dims[3]) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "Number of channels in gamma and input does not match");
-  }
-
-  const auto& beta_dims = beta->Shape().GetDims();
-  if (beta_dims.size() != 1) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "beta is expected to have 1 dimension, got ", beta_dims.size());
-  }
-  if (beta_dims[0] != input_dims[3]) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "Number of channels in beta and input does not match");
-  }
-
-  // Input and output format is NHWC
-  int batch_size = static_cast<int>(input_dims[0]);
-  int num_channels = static_cast<int>(input_dims[3]);
-  int height = static_cast<int>(input_dims[1]);
-  int width = static_cast<int>(input_dims[2]);
-
-  if (num_channels % num_groups_ != 0) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "number of channels should be divisible by num_groups");
-  }
-
-  if (context->GetUseDeterministicCompute()) {
-    static std::once_flag log_warning;
-    std::call_once(log_warning, []() {
-      LOGS_DEFAULT(WARNING) << "GroupNorm has no deterministic GPU kernel, its outputs may still be nondeterministic.";
-    });
-  }
-
-  auto workspace = GetScratchBuffer<void>(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream());
-
-  utils::MLTypeCallDispatcher<GROUP_NORM_TYPES> dispatcher(input->GetElementType());
-  return dispatcher.InvokeRet<Status, DispatchGroupNorm>(GetTuningContext(), context->GetComputeStream(),
-                                                         output, input, gamma, beta, workspace.get(),
-                                                         epsilon_,
-                                                         batch_size,
-                                                         num_channels,
-                                                         height,
-                                                         width,
-                                                         num_groups_,
-                                                         use_swish_activation_);
-}
-
-}  // namespace rocm
-}  // namespace contrib
-}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh
index fb7091592c16e..d0a0d09fcbae3 100644
--- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh
+++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh
@@ -26,13 +26,18 @@ namespace rocm {
 
 using onnxruntime::rocm::CKDataTypeAdaptor;
 
-using Swish = ck::tensor_operation::element_wise::Swish;
+// The SiLU function is a special case of the Swish function, which is parametrized by b; b is set to 1.0 for SiLU.
+// The two functions are defined as:
+//   SiLU(x) = x * sigmoid(x)
+//   Swish(x) = x * sigmoid(b * x)
+// The default value of b in ck::tensor_operation::element_wise::Swish is 1.0, so we treat them as the same function here.
+using Silu = ck::tensor_operation::element_wise::Swish;
 using Pass = ck::tensor_operation::element_wise::PassThrough;
 
 constexpr int Rank = 5;
 constexpr int NumReduceDim = 3;
 
-template <typename T, typename AccT, bool WithSwish>
+template <typename T, typename AccT, bool WithSilu>
 auto GetCKGroupNormNHWCTypeStringAndOps() {
   using XDataType = typename CKDataTypeAdaptor<T>::type;
   using YDataType = typename CKDataTypeAdaptor<T>::type;
@@ -40,26 +45,30 @@ auto GetCKGroupNormNHWCTypeStringAndOps() {
   using GammaDataType = float;
   using BetaDataType = float;
 
-  using Activation = std::conditional_t<WithSwish, Swish, Pass>;
+  using Activation = std::conditional_t<WithSilu, Silu, Pass>;
 
-  std::vector<std::pair<std::string, onnxruntime::rocm::tunable::Op<GroupNormNHWCParams<T>>>> ret;
+  std::vector<std::pair<std::string, onnxruntime::rocm::tunable::Op<GroupNormNHWCTunableParams<T>>>> ret;
   for (auto&& impl : internal::GetDeviceGroupNormInstances<XDataType, GammaDataType, BetaDataType, YDataType,
                                                            SaveMeanInvStdDataType, Activation, Rank, NumReduceDim>()) {
-    std::string swish_suffix = WithSwish ? "_Swish" : "_Pass";
-    auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + swish_suffix;
+    std::string silu_suffix = WithSilu ? "_Silu" : "_Pass";
+    auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + silu_suffix;
     auto invoker = impl->MakeInvokerPointer();
 
-    auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GroupNormNHWCParams<T>* params) -> Status {
-      if constexpr (WithSwish) {
+    auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)](
+                                const GroupNormNHWCTunableParams<T>* params) -> Status {
+      TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr),
+                                                "Input skip or bias is not supported by composable kernel.");
+      if constexpr (WithSilu) {
         TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
-            !params->withSwish, "Swish version only support groupnorm with swish");
+            !params->use_silu, "Silu version only supports GroupNorm with SiLU");
       } else {
         TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
-            params->withSwish, "Pass version only support groupnorm without swish");
+            params->use_silu, "Pass version only supports GroupNorm without SiLU");
       }
-      std::vector<ck::index_t> in_lengths{params->n, params->h, params->w, params->groups, params->cPerGroup};
-      std::vector<ck::index_t> in_out_strides{params->h * params->w * params->c, params->w * params->c, params->c, params->cPerGroup, 1};
-      std::vector<ck::index_t> gamma_beta_strides{0, 0, 0, params->cPerGroup, 1};
+      std::vector<ck::index_t> in_lengths{params->n, params->h, params->w, params->groups, params->channels_per_group};
+      std::vector<ck::index_t> in_out_strides{params->h * params->w * params->c, params->w * params->c,
+                                              params->c, params->channels_per_group, 1};
+      std::vector<ck::index_t> gamma_beta_strides{0, 0, 0, params->channels_per_group, 1};
       std::vector<ck::index_t> reduce_dims{1, 2, 4};
 
       auto activation = Activation{};
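
As the comment in group_norm_ck.cuh notes, SiLU is simply Swish with b = 1.0, which is why the composable-kernel Swish functor can be reused under the Silu alias. A quick numerical sanity check of that identity in plain Python (illustration only):

import math

def swish(x: float, b: float = 1.0) -> float:
    # Swish(x) = x * sigmoid(b * x)
    return x / (1.0 + math.exp(-b * x))

def silu(x: float) -> float:
    # SiLU(x) = x * sigmoid(x)
    return x / (1.0 + math.exp(-x))

for x in (-3.0, -0.5, 0.0, 1.25, 4.0):
    assert math.isclose(silu(x), swish(x, b=1.0))
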
diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh
index 19b081881dcec..4cb371fdcf960 100644
--- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh
+++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh
@@ -18,7 +18,7 @@ namespace internal {
 using F16 = ck::half_t;
 using F32 = float;
 
-using Swish = ck::tensor_operation::element_wise::Swish;
+using Silu = ck::tensor_operation::element_wise::Swish;
 using Pass = ck::tensor_operation::element_wise::PassThrough;
 
 using ck::tensor_operation::device::DeviceNormalizationFwd;      // the interface
@@ -101,9 +101,9 @@ GetDeviceGroupNormInstances() {
 
 template <>
 std::vector<std::unique_ptr<DeviceNormalizationFwd<
-    F16, F32, F32, F16, F32, Swish, 5, 3>>>
+    F16, F32, F32, F16, F32, Silu, 5, 3>>>
 GetDeviceGroupNormInstances<
-    F16, F32, F32, F16, F32, Swish, 5, 3>();
+    F16, F32, F32, F16, F32, Silu, 5, 3>();
 
 template <>
 std::vector<std::unique_ptr<DeviceNormalizationFwd<
@@ -113,9 +113,9 @@ GetDeviceGroupNormInstances<
 
 template <>
 std::vector<std::unique_ptr<DeviceNormalizationFwd<
-    F32, F32, F32, F32, F32, Swish, 5, 3>>>
+    F32, F32, F32, F32, F32, Silu, 5, 3>>>
 GetDeviceGroupNormInstances<
-    F32, F32, F32, F32, F32, Swish, 5, 3>();
+    F32, F32, F32, F32, F32, Silu, 5, 3>();
 
 template <>
 std::vector<std::unique_ptr<DeviceNormalizationFwd<
diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu
index 6718f29268031..ad191314e5e4c 100644
--- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu
+++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp16.cu
@@ -11,12 +11,12 @@ namespace rocm {
 namespace internal {
 
 template <>
-std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F32, F32, F16, F32, Swish, 5, 3>>>
-GetDeviceGroupNormInstances<F16, F32, F32, F16, F32, Swish, 5, 3>() {
-  std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F32, F32, F16, F32, Swish, 5, 3>>> instances;
+std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F32, F32, F16, F32, Silu, 5, 3>>>
+GetDeviceGroupNormInstances<F16, F32, F32, F16, F32, Silu, 5, 3>() {
+  std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F32, F32, F16, F32, Silu, 5, 3>>> instances;
   ck::tensor_operation::device::instance::add_device_operation_instances(
       instances,
-      device_normalization_f16_instances<Swish, 5, 3>{});
+      device_normalization_f16_instances<Silu, 5, 3>{});
 
   return instances;
 }
diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu
index 9b0ccab17b4c1..ceb53ed442abc 100644
--- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu
+++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu
@@ -11,12 +11,12 @@ namespace rocm {
 namespace internal {
 
 template <>
-std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Swish, 5, 3>>>
-GetDeviceGroupNormInstances<F32, F32, F32, F32, F32, Swish, 5, 3>() {
-  std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Swish, 5, 3>>> instances;
+std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Silu, 5, 3>>>
+GetDeviceGroupNormInstances<F32, F32, F32, F32, F32, Silu, 5, 3>() {
+  std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Silu, 5, 3>>> instances;
   ck::tensor_operation::device::instance::add_device_operation_instances(
       instances,
-      device_normalization_f32_instances<Swish, 5, 3>{});
+      device_normalization_f32_instances<Silu, 5, 3>{});
 
   return instances;
 }
diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h
index 008ae20b0561f..7cff640db2f34 100644
--- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h
+++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h
@@ -8,110 +8,47 @@
 #include "core/providers/rocm/cu_inc/common.cuh"
 #include "core/providers/rocm/rocm_common.h"
 #include "core/providers/rocm/tunable/rocm_tunable.h"
+#include "contrib_ops/rocm/diffusion/group_norm_common_base.h"
 
 namespace onnxruntime {
 namespace contrib {
 namespace rocm {
 
-using onnxruntime::rocm::CeilDiv;
-
-int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) {
-  int32_t maxDivisor = -1;
-  for (int32_t i = 1; i <= std::sqrt(n); i++) {
-    if (n % i == 0) {
-      int32_t divisor1 = n / i;
-      int32_t divisor2 = i;
-
-      if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) {
-        maxDivisor = divisor1;
-      }
-      if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) {
-        maxDivisor = divisor2;
-      }
-    }
-  }
-  return maxDivisor;
-}
-
 template <typename T>
-struct GroupNormNHWCParams : OpParams {
-  GroupNormNHWCParams(RocmTuningContext* tuning_ctx, onnxruntime::Stream* stream, T* dst, float* redBuffer, const T* src, const float* gamma,
-                      const float* beta, int32_t n, int32_t h, int32_t w, int32_t c, int32_t groups, float epsilon, bool withSwish)
-      : OpParams(tuning_ctx, stream), dst(dst), src(src), gamma(gamma), beta(beta), redBuffer(redBuffer), epsilon(epsilon), n(n), h(h), w(w), c(c), groups(groups), withSwish(withSwish) {
-    int32_t maxBlocksPerHW = 1024;
-    switch (c) {
-      case 960:
-      case 1920:
-        cPerBlock = 480;
-        break;
-      case 512:
-      case 256:
-        cPerBlock = 256;
-        break;
-      case 128:
-        cPerBlock = 128;
-        break;
-      default:
-        cPerBlock = 320;
-    }
-
-    hw = h * w;
-    const int32_t blocksPerHW = findMaxDivisor(hw, maxBlocksPerHW);
-    hwPerBlock = CeilDiv(hw, blocksPerHW);
-    cPerGroup = c / groups;
-    hwc = hw * c;
-    invHWC = 1.F / (float)(hw * cPerGroup);
-    groupsPerBlock = cPerBlock / cPerGroup;
-  }
+struct GroupNormNHWCTunableParams : OpParams, GroupNormNHWCParams<T> {
+  GroupNormNHWCTunableParams(RocmTuningContext* tuning_ctx,
+                             onnxruntime::Stream* ort_stream,
+                             T* output,
+                             T* add_out,
+                             const T* input,
+                             const T* skip,
+                             const T* bias,
+                             const float* gamma,
+                             const float* beta,
+                             float* workspace,
+                             float epsilon,
+                             int batch_size,
+                             int num_channels,
+                             int height,
+                             int width,
+                             int num_groups,
+                             bool use_silu,
+                             bool broadcast_skip,
+                             int channels_per_block)
+      : OpParams(tuning_ctx, ort_stream),
+        GroupNormNHWCParams<T>(output, add_out, input, skip, bias, gamma, beta, workspace, epsilon, batch_size,
+                               num_channels, height, width, num_groups, use_silu, broadcast_skip, channels_per_block) {}
 
   std::string Signature() const override {
-    std::string swish_suffix = withSwish ? "_Swish" : "_Pass";
-    std::string sig = std::to_string(n) + "_" + std::to_string(h * w) + "_" + std::to_string(c) + "_" + std::to_string(groups) + swish_suffix;
+    std::string silu_suffix = this->use_silu ? "_silu" : "_pass";
+    std::string skip_suffix = this->skip != nullptr ? "_skip" : "_noskip";
+    std::string broadcast_suffix = this->broadcast_skip ? "_broadcast" : "_nobroadcast";
+    std::string bias_suffix = this->bias != nullptr ? "_bias" : "_nobias";
+    std::string sig = std::to_string(this->n) + "_" + std::to_string(this->h * this->w) + "_" +
+                      std::to_string(this->c) + "_" + std::to_string(this->groups) + silu_suffix +
+                      skip_suffix + broadcast_suffix + bias_suffix;
     return sig;
   }
-
-  // The output buffer. Layout NHWC.
-  T* dst;
-  // The input buffer. Layout NHWC.
-  T const* src;
-  // The gamma scaling factor.
-  float const* gamma;
-  // The beta term to add in GN.
-  float const* beta;
-  // The temporary buffer to do the global parallel reduction. Size:
-  // BLOCKS_PER_BATCH x C x 2.
-  float* redBuffer;
-  float epsilon;
-
-  // The number of instances in the batch.
-  int32_t n;
-  // The height and width of each activation map.
-  int32_t h;
-  int32_t w;
-  // The number of channels.
-  int32_t c;
-  // The number of groups.
-  int32_t groups;
-  // Do we apply the Swish activation function?
-  bool withSwish;
-
-  // Precomputed values and parameters to control the execution of the kernels.
-
-  // The number of activations per instance (h * w) and the number of
-  // activations per block.
-  int32_t hw;
-  int32_t hwPerBlock;
-  // The number of channels per group and blocks per activation in the C
-  // dimension.
-  int32_t cPerBlock;
-  int32_t cPerGroup;
-
-  // The precomputed stride between instances.
-  int32_t hwc;
-  // The inverse of hwc in floats (to compute mean/var).
-  float invHWC;
-  // The precomputed number of groups per block.
-  int32_t groupsPerBlock;
 };
 
 }  // namespace rocm
diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu
index dbd5009e63676..142aaf14e8d2d 100644
--- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu
+++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu
@@ -15,9 +15,12 @@ namespace rocm {
 template <typename T>
 Status LaunchGroupNormKernel(
     RocmTuningContext* tuning_ctx,
-    Stream* stream,
+    Stream* ort_stream,
     T* output,
+    T* add_out,
     const T* input,
+    const T* skip,
+    const T* bias,
     const float* gamma,
     const float* beta,
     void* workspace,
@@ -27,19 +30,26 @@ Status LaunchGroupNormKernel(
     int height,
     int width,
     int num_groups,
-    bool use_swish_activation) {
-  if (batch_size > static_cast<int>(kMaxGroupNormBatchSize)) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
-                           "only support batch_size <= 32. Got", batch_size);
-  }
+    bool use_silu,
+    bool broadcast_skip,
+    int channels_per_block) {
+  GroupNormNHWCTunableParams<T> params(tuning_ctx, ort_stream, output, add_out, input, skip, bias, gamma, beta,
+                                       reinterpret_cast<float*>(workspace), epsilon, batch_size, num_channels,
+                                       height, width, num_groups, use_silu, broadcast_skip, channels_per_block);
 
-  if (num_groups != static_cast<int>(kGroupNormNumberOfGroups)) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
-                           "only num_groups=32 is supported. Got", num_groups);
+  if (params.channels_per_block % params.channels_per_group != 0 ||
+      params.channels_per_block > kMaxSize ||
+      (params.channels_per_group % CHANNELS_PER_THREAD != 0)) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
+                           "GroupNorm in ROCM does not support the input: n=", batch_size,
+                           " h=", height,
+                           " w=", width,
+                           " c=", num_channels,
+                           " groups=", num_groups);
   }
 
-  GroupNormNHWCParams<T> params(tuning_ctx, stream, output, reinterpret_cast<float*>(workspace), input, gamma, beta,
-                                batch_size, height, width, num_channels, num_groups, epsilon, use_swish_activation);
+  HIP_RETURN_IF_ERROR(hipMemsetAsync(
+      params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), params.StreamHandle()));
 
   if (tuning_ctx->IsTunableOpEnabled()) {
     static GroupNormNHWCTunableOp<T> op;
@@ -50,14 +60,17 @@ Status LaunchGroupNormKernel(
 }
 
 template Status LaunchGroupNormKernel<half>(RocmTuningContext* tuning_ctx, Stream* stream, half* output,
-                                            const half* input, const float* gamma, const float* beta, void* workspace,
-                                            float epsilon, int batch_size, int num_channels,
-                                            int height, int width, int num_groups, bool swish);
+                                            half* add_out, const half* input, const half* skip, const half* bias,
+                                            const float* gamma, const float* beta, void* workspace, float epsilon,
+                                            int batch_size, int num_channels, int height, int width, int num_groups,
+                                            bool use_silu, bool broadcast_skip, int channels_per_block);
 
 template Status LaunchGroupNormKernel<float>(RocmTuningContext* tuning_ctx, Stream* stream, float* output,
-                                             const float* input, const float* gamma, const float* beta, void* workspace,
-                                             float epsilon, int batch_size, int num_channels,
-                                             int height, int width, int num_groups, bool swish);
+                                             float* add_out, const float* input, const float* skip, const float* bias,
+                                             const float* gamma, const float* beta, void* workspace, float epsilon,
+                                             int batch_size, int num_channels, int height, int width, int num_groups,
+                                             bool use_silu, bool broadcast_skip, int channels_per_block);
+
 }  // namespace rocm
 }  // namespace contrib
 }  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h
deleted file mode 100644
index a0f7e0aca5def..0000000000000
--- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h
+++ /dev/null
@@ -1,47 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#pragma once
-
-#include <cstdint>
-#include <hip/hip_runtime.h>
-
-#include "core/common/common.h"
-#include "core/common/status.h"
-#include "core/providers/rocm/tunable/rocm_tunable.h"
-
-using onnxruntime::rocm::tunable::RocmTuningContext;
-
-namespace onnxruntime {
-namespace contrib {
-namespace rocm {
-
-constexpr size_t kMaxGroupNormBatchSize = 32;
-constexpr size_t kGroupNormNumberOfGroups = 32;
-
-constexpr size_t GetGroupNormWorkspaceSizeInBytes() {
-  // Two buffers for sum and squared sum
-  return (sizeof(float) * 2) * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups;
-}
-
-template <typename T>
-Status LaunchGroupNormKernel(
-    RocmTuningContext* tuning_ctx,
-    Stream* stream,
-    T* output,                 // normalized output tensor
-    const T* input,            // input tensor
-    const float* gamma,        // gamma (also known as weight or scale)
-    const float* beta,         // beta (also known as bias)
-    void* workspace,           // Work space
-    float epsilon,             // epsilon used normalization
-    int batch_size,            // N
-    int num_channels,          // C
-    int height,                // H
-    int width,                 // W
-    int num_groups,            // number of groups
-    bool use_swish_activation  // Whether there is Swish activation after group normalization
-);
-
-}  // namespace rocm
-}  // namespace contrib
-}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh
deleted file mode 100644
index d6322a12a9363..0000000000000
--- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh
+++ /dev/null
@@ -1,213 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-// The ROCm kernel is modified from TensorRT 8.5.
-/*
- * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- * SPDX-License-Identifier: Apache-2.0
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <hip/hip_fp16.h>
-#include <hip/hip_runtime_api.h>
-#include <hipcub/hipcub.hpp>
-#include "core/providers/rocm/cu_inc/common.cuh"
-#include "core/providers/rocm/rocm_common.h"
-
-namespace onnxruntime {
-namespace contrib {
-namespace rocm {
-
-static inline __device__ __host__ float sigmoid(float x) {
-  return 1.F / (1.F + expf(-x));
-}
-
-struct GroupSums {
-  // Is it the 1st element of the group?
-  int32_t flag;
-  // The sum.
-  float sum;
-  // The sum of squares.
-  float sumSq;
-};
-
-struct GroupSumsOp {
-  inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) {
-    GroupSums dst;
-    dst.sum = b.flag ? b.sum : (a.sum + b.sum);
-    dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq);
-    dst.flag = a.flag + b.flag;
-    return dst;
-  }
-};
-
-template <typename T, typename U, int ILP>
-inline __device__ void UpdateSum(const T* src, int64_t offset, U& sum, U& sumSq) {
-  using VecT = onnxruntime::rocm::aligned_vector<T, ILP>;
-  const VecT input_v = *reinterpret_cast<const VecT*>(src + offset);
-
-#pragma unroll
-  for (int i = 0; i < ILP; i++) {
-    const U val = static_cast<U>(input_v.val[i]);
-    sum += val;
-    sumSq += val * val;
-  }
-}
-
-template <typename T, int ThreadsPerBlock, int ILP>
-__global__ void groupNormNHWCSumKernel(const T* src, float* redBuffer, int32_t cPerBlock, int32_t hwPerBlock, int32_t hw,
-                                       int32_t hwc, int32_t c, int32_t cPerGroup, int32_t groups, int32_t groupsPerBlock) {
-  // The object in charge of doing the sums for the different blocks.
-  typedef hipcub::BlockScan<GroupSums, ThreadsPerBlock> BlockScan;
-
-  // Allocate shared memory for BlockScan.
-  __shared__ typename BlockScan::TempStorage tempStorage;
-  // Allocate shared memory for the groups. We could reduce the amount of shared
-  // memory reserved.
-  __shared__ float2 smem[ThreadsPerBlock];
-
-  // The instance in the batch.
-  int32_t ni = blockIdx.z;
-  // The channel loaded by that thread (ILP channels per thread).
-  int32_t ci = blockIdx.x * cPerBlock + threadIdx.x * ILP;
-
-  // The first activation loaded by that block.
-  int32_t hwBegin = blockIdx.y * hwPerBlock;
-  // The last activation loaded by that block.
-  int32_t hwEnd = min(hwBegin + hwPerBlock, hw);
-
-  // The sums.
-  float sum = 0.F;
-  float sumSq = 0.F;
-
-  // Iterate over the activations to compute the sums.
-  if (ci < c) {
-    for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
-      // The offset.
-      int64_t offset = static_cast<int64_t>(ni) * hwc + static_cast<int64_t>(hwi) * c + ci;
-      UpdateSum<T, float, ILP>(src, offset, sum, sumSq);
-    }
-  }
-
-  // The group that thread works on and the channel in the group (modulus).
-  int32_t gi = threadIdx.x * ILP / cPerGroup;
-  int32_t cj = threadIdx.x * ILP - cPerGroup * gi;
-
-  // The data for the summations.
-  GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq};
-
-  // Do the segmented scan.
-  GroupSums out;
-  BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp());
-
-  // Store the results for the groups in shared memory (to produce coalesced
-  // stores later).
-  if (cj == cPerGroup - ILP) {  // ILP channels per thread
-    smem[gi] = make_float2(out.sum, out.sumSq);
-  }
-
-  // Make sure the data is in shared memory.
-  __syncthreads();
-
-  // The global group index.
-  int32_t gj = blockIdx.x * groupsPerBlock + threadIdx.x;
-
-  // Threads that have nothing left to do, exit.
-  if (threadIdx.x >= groupsPerBlock || gj >= groups) {
-    return;
-  }
-
-  // The first threads (those storing to global memory, load the values).
-  float2 sums = smem[threadIdx.x];
-
-  // Store to global memory.
-  atomicAdd(&redBuffer[(2 * ni + 0) * groups + gj], sums.x);
-  atomicAdd(&redBuffer[(2 * ni + 1) * groups + gj], sums.y);
-}
-
-template <typename T, typename U, int32_t ILP>
-__device__ void computeGroupNorm(const T* src, T* dst, int64_t offset, U mean, U invStdDev,
-                                 const U* gamma_v, const U* beta_v, bool swish) {
-  using VecT = onnxruntime::rocm::aligned_vector<T, ILP>;
-  const VecT input_v = *reinterpret_cast<const VecT*>(src + offset);
-  VecT output_v;
-
-#pragma unroll
-  for (int i = 0; i < ILP; i++) {
-    U val = static_cast<U>(input_v.val[i]);
-    val = (val - mean) * invStdDev;
-    val = gamma_v[i] * val + beta_v[i];
-
-    if (swish) {
-      val = val * sigmoid(val);
-    }
-    output_v.val[i] = static_cast<T>(val);
-  }
-  *(reinterpret_cast<VecT*>(dst + offset)) = output_v;
-}
-
-template <typename T, int ThreadsPerBlock, int ILP>
-__global__ void groupNormNHWCScaleKernel(T* dst, const T* src, const float* gamma, const float* beta, const float* redBuffer, float epsilon, int32_t c, int32_t cPerBlock,
-                                         int32_t cPerGroup, int32_t groups, int32_t hwc, float invHWC, int32_t hw, int32_t hwPerBlock, bool withSwish) {
-  // The channel loaded by that thread (ILP channels per thread for F16x2).
-  int32_t ci = blockIdx.x * cPerBlock + threadIdx.x * ILP;
-  if (ci >= c) {
-    return;
-  }
-
-  // The instance in the batch.
-  int32_t ni = blockIdx.z;
-
-  // The group that thread works on and the channel in the group (modulus).
-  int32_t gi = ci / cPerGroup;
-
-  // Load the sum and sum of squares for the group.
-  float sum = 0.F, sumSq = 0.F;
-  if (gi < groups) {
-    sum = redBuffer[(2 * ni + 0) * groups + gi];
-    sumSq = redBuffer[(2 * ni + 1) * groups + gi];
-  }
-
-  using VecF = onnxruntime::rocm::aligned_vector<float, ILP>;
-
-  const VecF gamma_v = *reinterpret_cast<const VecF*>(gamma + ci);
-  const VecF beta_v = *reinterpret_cast<const VecF*>(beta + ci);
-
-  // Compute the mean.
-  float mean = sum * invHWC;
-  // Compute the variance.
-  float var = sumSq * invHWC - (mean * mean);
-  // Compute the inverse of the stddev.
-  float invStdDev = var <= 0.F ? 1.F : rsqrtf(var + epsilon);
-
-  // The first activation loaded by that block.
-  int32_t hwBegin = blockIdx.y * hwPerBlock;
-  // The last activation loaded by that block.
-  int32_t hwEnd = min(hwBegin + hwPerBlock, hw);
-
-  // Iterate over the activations to compute the sums.
-  for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
-    // The src/dst offset.
-    int64_t offset = (int64_t)ni * hwc + hwi * c + ci;
-
-    // Fetch ILP channels per thread.
-    computeGroupNorm<T, float, ILP>(src, dst, offset, mean, invStdDev, gamma_v.val, beta_v.val, withSwish);
-  }
-}
-
-}  // namespace rocm
-}  // namespace contrib
-}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh
index b7b9441ac997d..c6ca16bfdfc80 100644
--- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh
+++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh
@@ -20,21 +20,21 @@ namespace rocm {
 
 namespace {
 
-template <typename T, bool WithSwish>
+template <typename T, bool WithSilu>
 std::string GetGroupNormTritonGroupName() {
   std::string ret = "GroupNormTriton_";
-  std::string swish_suffix = WithSwish ? "Swish_" : "Pass_";
-  ret += swish_suffix;
+  std::string silu_suffix = WithSilu ? "Silu_" : "Pass_";
+  ret += silu_suffix;
   ret += GetDataTypeName<T>();
   return ret;
 }
 
 }  // namespace
 
-template <typename T, bool WithSwish>
+template <typename T, bool WithSilu>
 auto GetTritonGroupNormNHWCTypeStringAndOps() {
-  std::vector<std::pair<std::string, tunable::Op<GroupNormNHWCParams<T>>>> ret;
-  auto group_name = GetGroupNormTritonGroupName<T, WithSwish>();
+  std::vector<std::pair<std::string, tunable::Op<GroupNormNHWCTunableParams<T>>>> ret;
+  auto group_name = GetGroupNormTritonGroupName<T, WithSilu>();
   auto* kernel_list = GetOrtTritonKernelByGroup(group_name);
   if (kernel_list == nullptr) {
     return ret;
@@ -45,36 +45,50 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() {
     auto* metadata = GetOrtTritonKernelMetadata(i);
     auto block_size = metadata->constants.at("BLOCK_SIZE");
     auto hw_size = metadata->constants.at("HW_SIZE");
-    auto impl = [i, block_size, hw_size](const GroupNormNHWCParams<T>* params) -> Status {
+    auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams<T>* params) -> Status {
       TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
-          params->cPerGroup > block_size || params->cPerGroup * 2 <= block_size,
-          "Arg block_size (", block_size, ") is not the next power of 2 of cPerGroup (", params->cPerGroup, ").");
+          params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size,
+          "Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (",
+          params->channels_per_group, ").");
       TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
           params->hw % hw_size != 0, "Arg hw_size (", hw_size, ") is not a divisor of hw (", params->hw, ").");
-      if constexpr (WithSwish) {
-        TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->withSwish, "Swish version does not support GN w/o swish.");
+      if constexpr (WithSilu) {
+        TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->use_silu, "Silu version does not support GN w/o silu.");
       } else {
-        TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->withSwish, "Pass version does not support GN w/ swish.");
+        TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->use_silu, "Pass version does not support GN w/ silu.");
       }
       // Construct args for launch kernel
       struct {
-        void* X;
-        void* Y;
+        const void* src;
+        const void* skip;
+        const void* bias;
+        void* out;
+        void* add_out;
         const void* gamma;
         const void* beta;
         int hw;
         int c;
         int c_per_group;
         float eps;
+        bool has_skip;
+        bool has_bias;
+        bool broadcast_skip;
       } args = {
-          (void*)params->src,
+          (const void*)params->src,
+          (const void*)params->skip,
+          (const void*)params->bias,
           (void*)params->dst,
+          (void*)params->skip_workspace,
           (const void*)params->gamma,
           (const void*)params->beta,
           params->hw,
           params->c,
-          params->cPerGroup,
-          params->epsilon};
+          params->channels_per_group,
+          params->epsilon,
+          params->skip != nullptr,
+          params->bias != nullptr,
+          params->broadcast_skip,
+      };
 
       // Grid dim is (batch_count, groups, 1)
       return LaunchTritonKernel(params->StreamHandle(), i, params->n, params->groups, 1, &args, sizeof(args));
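
The two TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF checks above filter Triton variants: BLOCK_SIZE must be the next power of two of channels_per_group (i.e. channels_per_group <= BLOCK_SIZE < 2 * channels_per_group), and HW_SIZE must evenly divide hw. A small Python restatement of that filter, for illustration only (the helper name is invented):

def is_supported_variant(block_size: int, hw_size: int, channels_per_group: int, hw: int) -> bool:
    # block_size must be the next power of two of channels_per_group.
    if channels_per_group > block_size or channels_per_group * 2 <= block_size:
        return False
    # hw_size must evenly divide the number of activations per image.
    return hw % hw_size == 0

assert is_supported_variant(block_size=16, hw_size=8, channels_per_group=10, hw=64)
assert not is_supported_variant(block_size=32, hw_size=8, channels_per_group=10, hw=64)
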
diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py
index 56b3a030b289e..5ba96ebc117f0 100644
--- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py
+++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py
@@ -12,16 +12,22 @@
 @triton.jit
 def group_norm_kernel(
     input_ptr,
+    skip_ptr,
+    bias_ptr,
     output_ptr,
+    add_out_ptr,
     gamma_ptr,
     beta_ptr,
     img_size,
     c,
     c_per_group,
     eps,
+    has_skip,
+    has_bias,
+    broadcast_skip,
     BLOCK_SIZE: tl.constexpr,
     HW_SIZE: tl.constexpr,
-    ACTIVATION_SWISH: tl.constexpr,
+    ACTIVATION_SILU: tl.constexpr,
 ):
     row_x = tl.program_id(0)
     row_y = tl.program_id(1)
@@ -36,14 +42,35 @@ def group_norm_kernel(
     offsets = hw[:, None] * c + cols[None, :]
     mask = (cols < c_per_group)[None, :]
 
+    bias = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    if has_skip:
+        add_out_ptr += row_x * stride + row_y * c_per_group
+        if broadcast_skip:
+            broadcast_skip_ptr = skip_ptr + row_x * c + row_y * c_per_group
+            bias += tl.load(broadcast_skip_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32)
+        else:
+            skip_ptr += row_x * stride + row_y * c_per_group
+    if has_bias:
+        bias_ptr += row_y * c_per_group
+        bias += tl.load(bias_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32)
+
     # Calculate mean and variance
     _sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32)
     _square_sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32)
     for i in range(tl.cdiv(img_size, HW_SIZE)):
         x_ptr = input_ptr + i * HW_SIZE * c
         a = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+        if has_skip and not broadcast_skip:
+            s_ptr = skip_ptr + i * HW_SIZE * c
+            s = tl.load(s_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+            a += s
+        if has_bias or broadcast_skip:
+            a += bias
         _sum += a
         _square_sum += a * a
+        if has_skip:
+            add_y_ptr = add_out_ptr + i * HW_SIZE * c
+            tl.store(add_y_ptr + offsets, a, mask=mask)
 
     # Set axis=None (or leave it unspecified) to reduce all axes.
     # TODO: In older Triton we have to reduce an axis at a time, but in our case
@@ -57,12 +84,16 @@ def group_norm_kernel(
     gamma = tl.load(gamma_ptr + cols, mask=cols < c_per_group).to(tl.float32)
     beta = tl.load(beta_ptr + cols, mask=cols < c_per_group).to(tl.float32)
     for i in range(tl.cdiv(img_size, HW_SIZE)):
-        x_ptr = input_ptr + i * HW_SIZE * c
         y_ptr = output_ptr + i * HW_SIZE * c
-        x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+        if has_skip:
+            add_y_ptr = add_out_ptr + i * HW_SIZE * c
+            x = tl.load(add_y_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+        else:
+            x_ptr = input_ptr + i * HW_SIZE * c
+            x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
         x_hat = (x - group_mean) * rstd
         y = x_hat * gamma + beta
-        if ACTIVATION_SWISH:
+        if ACTIVATION_SILU:
             y *= tl.sigmoid(y)
         tl.store(y_ptr + offsets, y, mask=mask)
 
@@ -71,27 +102,27 @@ def group_norm_kernel(
 # blocks = [16, 32, 64, 128, 256, 512]
 # hw_sizes = [8, 16, 32, 64, 128, 256, 512]
 # but this will result in too many functions and slow down the compilation.
-with_swish = [True, False]
+with_silu = [True, False]
 dtypes = ["fp32", "fp16"]
 blocks = [16, 32, 64, 128]
 hw_sizes = [8, 16, 32, 64, 128, 256]
 warps = [1, 2, 4, 8, 16]
 name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}"
-sig_pattern = "*{},*{},*fp32,*fp32,i32,i32,i32,fp32"
+sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1"
 group_pattern = "GroupNormTriton_{}_{}"
 
 
 def get_function_table():
     func_table = []
 
-    for swish, dtype, hw_size, warp, b in product(with_swish, dtypes, hw_sizes, warps, blocks):
-        swish_suffix = "Swish" if swish else "Pass"
-        name = name_pattern.format(swish_suffix, dtype, b, hw_size, warp)
-        group = group_pattern.format(swish_suffix, dtype)
-        sig = sig_pattern.format(dtype, dtype)
+    for silu, dtype, hw_size, warp, b in product(with_silu, dtypes, hw_sizes, warps, blocks):
+        silu_suffix = "Silu" if silu else "Pass"
+        name = name_pattern.format(silu_suffix, dtype, b, hw_size, warp)
+        group = group_pattern.format(silu_suffix, dtype)
+        sig = sig_pattern.format(dtype, dtype, dtype, dtype, dtype)
         kwargs = {
             "num_warps": warp,
-            "constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SWISH": int(swish)},
+            "constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SILU": int(silu)},
         }
         func_desc = {"name": name, "group": group, "func": group_norm_kernel, "sig": sig, "kwargs": kwargs}
         func_table.append(func_desc)
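
For reference, a rough NumPy sketch (not part of this patch) of what the updated Triton kernel computes on NHWC input: add the optional skip (element-wise, or broadcast over H and W when broadcast_skip is set) and the per-channel bias, keep that sum as add_out, normalize per group, apply gamma/beta, and optionally apply SiLU. The function and variable names here are assumptions for illustration:

import numpy as np

def group_norm_nhwc_ref(x, gamma, beta, groups, eps, skip=None, bias=None, use_silu=False):
    # x: (N, H, W, C); gamma/beta: (C,); skip: (N, H, W, C) or (N, C); bias: (C,)
    n, h, w, c = x.shape
    y_in = x.astype(np.float32).copy()
    if skip is not None:
        s = skip.astype(np.float32)
        # broadcast_skip: skip of shape (N, C) is broadcast over H and W
        y_in += s[:, None, None, :] if s.ndim == 2 else s
    if bias is not None:
        y_in += bias.astype(np.float32)          # per-channel bias
    add_out = y_in                               # value written to add_out when skip is present
    g = y_in.reshape(n, h * w, groups, c // groups)
    mean = g.mean(axis=(1, 3), keepdims=True)    # statistics per (batch, group)
    var = g.var(axis=(1, 3), keepdims=True)
    y = ((g - mean) / np.sqrt(var + eps)).reshape(n, h, w, c)
    y = y * gamma + beta
    if use_silu:
        y = y / (1.0 + np.exp(-y))               # y * sigmoid(y)
    return y, add_out

x = np.random.randn(2, 8, 8, 32).astype(np.float32)
gamma = np.ones(32, dtype=np.float32)
beta = np.zeros(32, dtype=np.float32)
y, add_out = group_norm_nhwc_ref(x, gamma, beta, groups=8, eps=1e-5, use_silu=True)
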
diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h
index 25d820f7ed326..e6831f764b418 100644
--- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h
+++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h
@@ -20,115 +20,117 @@ namespace rocm {
 using onnxruntime::rocm::GPU_WARP_SIZE;
 
 template <typename T>
-void groupNormNHWCSum(const GroupNormNHWCParams<T>* params) {
-  // Make sure the values are as we expect.
-  ORT_ENFORCE(params->c % params->cPerBlock == 0 && params->hw % params->hwPerBlock == 0);
-  // Make sure a group does not span multiple blocks.
-  ORT_ENFORCE(params->cPerBlock % params->cPerGroup == 0);
-
+void GroupNormNHWCSum(const GroupNormNHWCTunableParams<T>* params) {
   dim3 grid;
 
   // The number of blocks to compute all the channels.
-  grid.x = params->c / params->cPerBlock;
+  grid.x = DivUp(params->c, params->channels_per_block);
   // The number of blocks to compute all the activations in a given instance.
-  grid.y = CeilDiv(params->hw, params->hwPerBlock);
+  grid.y = DivUp(params->hw, params->hw_per_block);
   // The number of instances.
   grid.z = params->n;
 
-#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize)                \
-  groupNormNHWCSumKernel<T, ThreadsPerBlock, VecSize>                 \
-      <<<grid, ThreadsPerBlock, 0, params->StreamHandle()>>>(         \
-          params->src, params->redBuffer, params->cPerBlock,          \
-          params->hwPerBlock, params->hw, params->hwc, params->c,     \
-          params->cPerGroup, params->groups, params->groupsPerBlock); \
+#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize)                                                   \
+  GroupNormNHWCSumKernel<T, ThreadsPerBlock, VecSize>                                                    \
+      <<<grid, ThreadsPerBlock, 0, params->StreamHandle()>>>(                                            \
+          params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias,     \
+          params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c,          \
+          params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); \
   break;
 
-  switch (params->cPerBlock) {
-    case 320:
-      LAUNCH_GROUPNORM_SUM(256, 2)
-    case 480:
-      LAUNCH_GROUPNORM_SUM(256, 2)
+  // Threads_per_block is half of the values in kSizes since CHANNELS_PER_THREAD = 2.
+  switch (params->threads_per_block) {
     case 256:
-      LAUNCH_GROUPNORM_SUM(128, 2)
+      LAUNCH_GROUPNORM_SUM(256, CHANNELS_PER_THREAD)
+    case 192:
+      LAUNCH_GROUPNORM_SUM(192, CHANNELS_PER_THREAD)
+    case 160:
+      LAUNCH_GROUPNORM_SUM(160, CHANNELS_PER_THREAD)
     case 128:
-      LAUNCH_GROUPNORM_SUM(64, 2)
+      LAUNCH_GROUPNORM_SUM(128, CHANNELS_PER_THREAD)
+    case 64:
+      LAUNCH_GROUPNORM_SUM(64, CHANNELS_PER_THREAD)
     default:
       ORT_NOT_IMPLEMENTED("Not implemented");
   }
 }
 
 template <typename T, int ThreadsPerBlock, int VecSize>
-Status GroupNormNHWCSumOp(const GroupNormNHWCParams<T>* params) {
+Status GroupNormNHWCSumOp(const GroupNormNHWCTunableParams<T>* params) {
   dim3 grid;
-  grid.x = params->c / params->cPerBlock;
-  grid.y = CeilDiv(params->hw, params->hwPerBlock);
+  grid.x = DivUp(params->c, params->channels_per_block);
+  grid.y = DivUp(params->hw, params->hw_per_block);
   grid.z = params->n;
 
-  groupNormNHWCSumKernel<T, ThreadsPerBlock, VecSize>
+  GroupNormNHWCSumKernel<T, ThreadsPerBlock, VecSize>
       <<<grid, ThreadsPerBlock, 0, params->StreamHandle()>>>(
-          params->src, params->redBuffer, params->cPerBlock, params->hwPerBlock,
-          params->hw, params->hwc, params->c, params->cPerGroup, params->groups, params->groupsPerBlock);
+          params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias,
+          params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c,
+          params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip);
   return HIP_CALL(hipGetLastError());
 }
 
 template <typename T>
-void groupNormNHWCScale(const GroupNormNHWCParams<T>* params) {
-  // Make sure the dimensions are aligned with what we expect.
-  ORT_ENFORCE(params->c % params->cPerBlock == 0);
-  // Make sure a group does not span multiple blocks.
-  ORT_ENFORCE(params->cPerBlock % params->cPerGroup == 0);
-
+void GroupNormNHWCScale(const GroupNormNHWCTunableParams<T>* params) {
   dim3 grid;
 
   // The number of blocks to compute all the channels.
-  grid.x = params->c / params->cPerBlock;
+  grid.x = DivUp(params->c, params->channels_per_block);
   // The number of blocks to compute all the activations in a given instance.
-  grid.y = CeilDiv(params->hw, params->hwPerBlock);
+  grid.y = DivUp(params->hw, params->hw_per_block);
   // The number of instances.
   grid.z = params->n;
 
-#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize)                    \
-  groupNormNHWCScaleKernel<T, ThreadsPerBlock, VecSize>                     \
-      <<<grid, ThreadsPerBlock, 0, params->StreamHandle()>>>(               \
-          params->dst, params->src, params->gamma, params->beta,            \
-          params->redBuffer, params->epsilon, params->c, params->cPerBlock, \
-          params->cPerGroup, params->groups, params->hwc, params->invHWC,   \
-          params->hw, params->hwPerBlock, params->withSwish);               \
+#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize)                                               \
+  GroupNormNHWCScaleKernel<T, VecSize>                                                                 \
+      <<<grid, ThreadsPerBlock, 0, params->StreamHandle()>>>(                                          \
+          params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, \
+          params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block,            \
+          params->channels_per_group, params->groups, params->hwc, params->inv_hw_channels_per_group,  \
+          params->hw, params->hw_per_block, params->use_silu);                                         \
   break;
 
-  switch (params->cPerBlock) {
-    case 320:
-      LAUNCH_GROUPNORM_SCALE(256, 2)
-    case 480:
-      LAUNCH_GROUPNORM_SCALE(256, 2)
+  // threads_per_block is half of the values in kSizes since CHANNELS_PER_THREAD = 2.
+  switch (params->threads_per_block) {
     case 256:
-      LAUNCH_GROUPNORM_SCALE(128, 2)
+      LAUNCH_GROUPNORM_SCALE(256, CHANNELS_PER_THREAD)
+    case 192:
+      LAUNCH_GROUPNORM_SCALE(192, CHANNELS_PER_THREAD)
+    case 160:
+      LAUNCH_GROUPNORM_SCALE(160, CHANNELS_PER_THREAD)
     case 128:
-      LAUNCH_GROUPNORM_SCALE(64, 2)
+      LAUNCH_GROUPNORM_SCALE(128, CHANNELS_PER_THREAD)
+    case 64:
+      LAUNCH_GROUPNORM_SCALE(64, CHANNELS_PER_THREAD)
     default:
       ORT_NOT_IMPLEMENTED("Not implemented");
   }
 }
 
 template <typename T, int ThreadsPerBlock, int VecSize>
-Status GroupNormNHWCScaleOp(const GroupNormNHWCParams<T>* params) {
+Status GroupNormNHWCScaleOp(const GroupNormNHWCTunableParams<T>* params) {
   dim3 grid;
-  grid.x = params->c / params->cPerBlock;
-  grid.y = CeilDiv(params->hw, params->hwPerBlock);
+  grid.x = DivUp(params->c, params->channels_per_block);
+  grid.y = DivUp(params->hw, params->hw_per_block);
   grid.z = params->n;
 
-  groupNormNHWCScaleKernel<T, ThreadsPerBlock, VecSize>
+  GroupNormNHWCScaleKernel<T, VecSize>
       <<<grid, ThreadsPerBlock, 0, params->StreamHandle()>>>(
-          params->dst, params->src, params->gamma, params->beta, params->redBuffer, params->epsilon, params->c, params->cPerBlock,
-          params->cPerGroup, params->groups, params->hwc, params->invHWC, params->hw, params->hwPerBlock, params->withSwish);
+          params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace,
+          params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, params->channels_per_group,
+          params->groups, params->hwc, params->inv_hw_channels_per_group, params->hw, params->hw_per_block,
+          params->use_silu);
   return HIP_CALL(hipGetLastError());
 }
 
 template <typename T, int ThreadsPerBlock, int VecSize>
 class GroupNormNHWCOp {
  public:
-  Status operator()(const GroupNormNHWCParams<T>* params) {
-    HIP_RETURN_IF_ERROR(hipMemsetAsync(params->redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), params->StreamHandle()));
+  Status operator()(const GroupNormNHWCTunableParams<T>* params) {
+    HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer,
+                                       0,
+                                       GetGroupNormWorkspaceSizeInBytes(params->n, params->groups),
+                                       params->StreamHandle()));
     auto status = GroupNormNHWCSumOp<T, ThreadsPerBlock, VecSize>(params);
     ORT_RETURN_IF_ERROR(status);
     HIP_RETURN_IF_ERROR(hipGetLastError());
@@ -138,29 +140,30 @@ class GroupNormNHWCOp {
     return Status::OK();
   }
 
-  Status IsSupported(const GroupNormNHWCParams<T>* params) {
+  Status IsSupported(const GroupNormNHWCTunableParams<T>* params) {
     TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
-        !(params->c % VecSize == 0 && params->cPerGroup % VecSize == 0),
-        "The number of channels (", params->c, ") or the number of channels per group (", params->cPerGroup,
+        !(params->c % VecSize == 0 && params->channels_per_group % VecSize == 0),
+        "The number of channels (", params->c, ") or the number of channels per group (", params->channels_per_group,
         ") isn't divisible by the number of vector size: ", VecSize);
-    TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->cPerBlock % params->cPerGroup == 0 &&
-                                                params->c % params->cPerBlock == 0 && params->hw % params->hwPerBlock == 0),
-                                              "The value of attributes don't meet the requirements.");
-    TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->cPerBlock <= ThreadsPerBlock * VecSize &&
-                                                params->cPerBlock > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize),
+    TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->channels_per_block <= ThreadsPerBlock * VecSize &&
+                                                params->channels_per_block > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize),
                                               "Configuration: Threads (", ThreadsPerBlock, "), vector size (",
-                                              VecSize, ") is redundant for the number of channels per group: ", params->cPerBlock);
+                                              VecSize, ") is redundant for the number of channels per block: ",
+                                              params->channels_per_block);
 
     return Status::OK();
   }
 };
 
 template <typename T>
-Status GroupNormNHWCStaticSelection(const GroupNormNHWCParams<T>* params) {
-  HIP_RETURN_IF_ERROR(hipMemsetAsync(params->redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), params->StreamHandle()));
-  groupNormNHWCSum<T>(params);
+Status GroupNormNHWCStaticSelection(const GroupNormNHWCTunableParams<T>* params) {
+  HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer,
+                                     0,
+                                     GetGroupNormWorkspaceSizeInBytes(params->n, params->groups),
+                                     params->StreamHandle()));
+  GroupNormNHWCSum<T>(params);
   HIP_RETURN_IF_ERROR(hipGetLastError());
-  groupNormNHWCScale<T>(params);
+  GroupNormNHWCScale<T>(params);
   HIP_RETURN_IF_ERROR(hipGetLastError());
   return Status::OK();
 }
@@ -178,30 +181,30 @@ Status GroupNormNHWCStaticSelection(const GroupNormNHWCParams<T>* params) {
   ADD_OP_FOR_ALL_VEC_SIZE(name, 320)
 
 template <typename T>
-class GroupNormNHWCTunableOp : public TunableOp<GroupNormNHWCParams<T>> {
+class GroupNormNHWCTunableOp : public TunableOp<GroupNormNHWCTunableParams<T>> {
  public:
   GroupNormNHWCTunableOp() {
     this->RegisterOp(GroupNormNHWCStaticSelection<T>);
     ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWCOp)
 
 #ifdef USE_COMPOSABLE_KERNEL
-    for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps<T, /*AccT=*/float, /*WithSwish=*/false>()) {
+    for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps<T, /*AccT=*/float, /*WithSilu=*/false>()) {
       ORT_UNUSED_PARAMETER(_);
       this->RegisterOp(std::move(op));
     }
 
-    for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps<T, /*AccT=*/float, /*WithSwish=*/true>()) {
+    for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps<T, /*AccT=*/float, /*WithSilu=*/true>()) {
       ORT_UNUSED_PARAMETER(_);
       this->RegisterOp(std::move(op));
     }
 #endif  // USE_COMPOSABLE_KERNEL
 
 #ifdef USE_TRITON_KERNEL
-    for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps<T, /*WithSwish=*/false>()) {
+    for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps<T, /*WithSilu=*/false>()) {
       ORT_UNUSED_PARAMETER(_);
       this->RegisterOp(std::move(op));
     }
-    for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps<T, /*WithSwish=*/true>()) {
+    for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps<T, /*WithSilu=*/true>()) {
       ORT_UNUSED_PARAMETER(_);
       this->RegisterOp(std::move(op));
     }
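
For readers skimming this hunk: the refactored launchers now size the grid with a ceiling division (DivUp) over channels_per_block and hw_per_block instead of requiring exact divisibility (the old ORT_ENFORCE checks are gone), and each LAUNCH_GROUPNORM_* case ends in the break baked into the macro. A minimal standalone sketch of the grid-sizing arithmetic, assuming DivUp is the usual ceiling division and using made-up dimensions:

```cpp
// Minimal sketch of the grid-sizing arithmetic used by the GroupNormNHWCSum/
// GroupNormNHWCScale launchers above. DivUp is assumed to be the usual
// ceiling division; the dimension values below are made-up examples.
#include <cstdio>

constexpr int DivUp(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int c = 320;                   // total channels
  const int channels_per_block = 256;  // channels handled by one block
  const int hw = 64 * 64;              // spatial elements per instance
  const int hw_per_block = 512;        // spatial elements handled by one block
  const int n = 2;                     // batch size

  const int grid_x = DivUp(c, channels_per_block);  // blocks across channels
  const int grid_y = DivUp(hw, hw_per_block);       // blocks across activations
  const int grid_z = n;                             // one z-slice per instance

  std::printf("grid = (%d, %d, %d)\n", grid_x, grid_y, grid_z);
  return 0;
}
```

With these example numbers it prints grid = (2, 8, 2); partially filled blocks are now tolerated rather than rejected up front.
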
diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
index 55cd6a1d112f5..e19a976f3141c 100644
--- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
@@ -93,6 +93,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Samp
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ScaledTanh);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ScaledTanh);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SkipGroupNorm);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization);
@@ -150,7 +151,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kPytorchAtenDomain
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, ShrunkenGather);
 #endif
 
-#if defined(USE_MPI) && defined(ORT_USE_NCCL)
+#ifdef ORT_USE_NCCL
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllToAll);
@@ -246,6 +247,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ScaledTanh)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ScaledTanh)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SkipGroupNorm)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization)>,
@@ -309,7 +311,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
     BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, ShrunkenGather)>,
 #endif
 
-#if defined(USE_MPI) && defined(ORT_USE_NCCL)
+#ifdef ORT_USE_NCCL
     BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllToAll)>,
diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc
index 711fd595e90fd..be881f6bc4bc2 100644
--- a/onnxruntime/core/common/cpuid_info.cc
+++ b/onnxruntime/core/common/cpuid_info.cc
@@ -52,6 +52,13 @@
 
 #if defined(CPUINFO_SUPPORTED)
 #include <cpuinfo.h>
+#if defined(CPUIDINFO_ARCH_ARM)
+namespace onnxruntime {
+// The following function is declared in "core/common/cpuid_uarch.h" but we cannot include the whole header file because
+//  some of its symbols conflict with <cpuinfo.h>.
+void decodeMIDR(uint32_t midr, uint32_t uarch[1]);
+}  // namespace onnxruntime
+#endif
 #else
 #include "core/common/cpuid_uarch.h"
 #endif  // CPUINFO_SUPPORTED
@@ -142,11 +149,6 @@ void CPUIDInfo::ArmLinuxInit() {
   // Pytorch CPUINFO only works on ARM linux or android
   // Assuming no hyper-threading, no NUMA groups
 #ifdef CPUINFO_SUPPORTED
-  pytorch_cpuinfo_init_ = cpuinfo_initialize();
-  if (!pytorch_cpuinfo_init_) {
-    LOGS_DEFAULT(WARNING) << "Failed to init pytorch cpuinfo library, may cause CPU EP performance degradation due to undetected CPU features.";
-    return;
-  }
   is_hybrid_ = cpuinfo_get_uarchs_count() > 1;
   has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot();
   has_fp16_ = cpuinfo_has_arm_neon_fp16_arith();
@@ -239,52 +241,24 @@ void CPUIDInfo::ArmWindowsInit() {
       lastUarch = uarch;
     }
   }
-
-  switch (lastUarch) {
-    case cpuinfo_uarch_cortex_a55:
-    case cpuinfo_uarch_cortex_a55r0:
-    case cpuinfo_uarch_cortex_a76:
-    case cpuinfo_uarch_neoverse_n1:
-    case cpuinfo_uarch_cortex_a77:
-    case cpuinfo_uarch_exynos_m4:
-    case cpuinfo_uarch_exynos_m5:
-      has_fp16_ = true;
-      break;
-    default:
-      break;
-  }
-  if (!has_fp16_) {
-    /*
-     * Detecting fp16 support. Different cores should have the same instruction set.
-     * So we just check the first ID_AA64PFR0_EL1
-     *  Op0(0b11), Op1(0b000), CRn(0b0000), CRm(0b0100), Op2(0b000),
-     */
-    uint64_t ID_AA64PFR0_EL1;
-    unsigned long valsize = sizeof(uint64_t);
-    auto retCode = ::RegGetValueA(
-        HKEY_LOCAL_MACHINE,
-        "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0",
-        "CP 4020", RRF_RT_REG_QWORD, nullptr,
-        &ID_AA64PFR0_EL1, &valsize);
-    if (retCode == ERROR_SUCCESS) {
-      // AdvSIMD, bits [23:20]
-      auto advSimd = ID_AA64PFR0_EL1 >> 20;
-      if ((advSimd & 0xfULL) == 1) {
-        has_fp16_ = true;
-      }
-    }
-  }
 #endif /* Application Family or OneCore Family */
 
   has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0);
 #else
   has_arm_neon_dot_ = false;
 #endif
-  has_fp16_ |= has_arm_neon_dot_;
-  /* TODO: implement them when hw+sw is available for testing these features */
-  has_arm_neon_i8mm_ = false;
-  has_arm_sve_i8mm_ = false;
-  has_arm_neon_bf16_ = false;
+
+  if (pytorch_cpuinfo_init_) {
+    has_fp16_ = cpuinfo_has_arm_neon_fp16_arith();
+    has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm();
+    has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm();
+    has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16();
+  } else {
+    has_fp16_ = false;
+    has_arm_neon_i8mm_ = false;
+    has_arm_sve_i8mm_ = false;
+    has_arm_neon_bf16_ = false;
+  }
 }
 
 #endif /* (arm or arm64) and windows */
@@ -304,5 +278,21 @@ uint32_t CPUIDInfo::GetCurrentCoreIdx() const {
   return 0xFFFFFFFF;  // don't know how to get core index
 #endif
 }
-
+CPUIDInfo::CPUIDInfo() {
+#ifdef CPUIDINFO_ARCH_X86
+  X86Init();
+#elif defined(CPUIDINFO_ARCH_ARM)
+#if CPUINFO_SUPPORTED
+  pytorch_cpuinfo_init_ = cpuinfo_initialize();
+  if (!pytorch_cpuinfo_init_) {
+    LOGS_DEFAULT(WARNING) << "Failed to initialize the PyTorch cpuinfo library; this may cause CPU EP performance degradation due to undetected CPU features.";
+  }
+#endif
+#ifdef __linux__
+  ArmLinuxInit();
+#elif defined(_WIN32)
+  ArmWindowsInit();
+#endif /* (arm or arm64) and windows */
+#endif
+}
 }  // namespace onnxruntime
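
The net effect of this file's changes is that cpuinfo_initialize() is now called exactly once, in the new out-of-line constructor, before either ArmLinuxInit() or ArmWindowsInit() runs, and the Windows path reads its feature flags from cpuinfo when that initialization succeeded, falling back to conservative false defaults otherwise. A small sketch of that "initialize once, then query or fall back" pattern, with stand-in functions in place of the real cpuinfo API:

```cpp
// Sketch of the "initialize the third-party detector once, then query it or
// fall back to conservative defaults" pattern. third_party_init() and the
// query_* functions are stand-ins for cpuinfo_initialize()/cpuinfo_has_arm_*,
// not real APIs.
#include <iostream>

static bool third_party_init() { return true; }  // pretend initialization succeeded
static bool query_fp16() { return true; }
static bool query_i8mm() { return false; }

struct CpuFeatures {
  bool has_fp16 = false;
  bool has_i8mm = false;

  CpuFeatures() {
    if (!third_party_init()) {
      std::cerr << "feature detection unavailable, keeping conservative defaults\n";
      return;  // members keep their false defaults
    }
    has_fp16 = query_fp16();
    has_i8mm = query_i8mm();
  }
};

int main() {
  const CpuFeatures f;
  std::cout << "fp16=" << f.has_fp16 << " i8mm=" << f.has_i8mm << "\n";
  return 0;
}
```
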
diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h
index 2f8041e39f680..a3936b4bd11a6 100644
--- a/onnxruntime/core/common/cpuid_info.h
+++ b/onnxruntime/core/common/cpuid_info.h
@@ -93,17 +93,7 @@ class CPUIDInfo {
   }
 
  private:
-  CPUIDInfo() {
-#ifdef CPUIDINFO_ARCH_X86
-    X86Init();
-#elif defined(CPUIDINFO_ARCH_ARM)
-#ifdef __linux__
-    ArmLinuxInit();
-#elif defined(_WIN32)
-    ArmWindowsInit();
-#endif /* (arm or arm64) and windows */
-#endif
-  }
+  CPUIDInfo();
   bool has_amx_bf16_{false};
   bool has_avx_{false};
   bool has_avx2_{false};
@@ -131,11 +121,13 @@ class CPUIDInfo {
 #ifdef CPUIDINFO_ARCH_X86
 
   void X86Init();
-
 #elif defined(CPUIDINFO_ARCH_ARM)
+  // Currently the following var is only used in the ARM build, but later on we may expand its usage.
+  bool pytorch_cpuinfo_init_{false};
+#endif
+
 #ifdef __linux__
 
-  bool pytorch_cpuinfo_init_{false};
   void ArmLinuxInit();
 
 #elif defined(_WIN32)
@@ -143,7 +135,6 @@ class CPUIDInfo {
   void ArmWindowsInit();
 
 #endif /* (arm or arm64) and windows */
-#endif
 };
 
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/common/flatbuffers.h b/onnxruntime/core/common/flatbuffers.h
new file mode 100644
index 0000000000000..0d61e1038a82c
--- /dev/null
+++ b/onnxruntime/core/common/flatbuffers.h
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#if defined(__GNUC__)
+#include "onnxruntime_config.h"
+#pragma GCC diagnostic push
+
+#ifdef HAS_SHORTEN_64_TO_32
+#pragma GCC diagnostic ignored "-Wshorten-64-to-32"
+#endif
+#endif
+
+#include "flatbuffers/flatbuffers.h"
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic pop
+#endif
\ No newline at end of file
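
This new wrapper exists so that every ORT translation unit pulls in FlatBuffers through one place where the -Wshorten-64-to-32 suppression lives; the include swaps later in this diff switch callers over to it. A self-contained sketch of the same wrapper pattern, where a deliberately "noisy" local function stands in for the third-party code and the suppressed warning is illustrative only:

```cpp
// Sketch of the warning-suppression wrapper pattern shown above. A "noisy"
// function stands in for the third-party code whose diagnostics we want
// silenced in exactly one place; the warning chosen here (-Wunused-parameter,
// emitted under -Wall -Wextra) is illustrative, ORT suppresses -Wshorten-64-to-32.
#if defined(__GNUC__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#endif

static int noisy_third_party_function(int used, int unused) { return used; }

#if defined(__GNUC__)
#pragma GCC diagnostic pop
#endif

int main() { return noisy_third_party_function(0, 1); }
```
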
diff --git a/onnxruntime/core/common/string_utils.h b/onnxruntime/core/common/string_utils.h
index eca1221e84cb8..716eed1afec51 100644
--- a/onnxruntime/core/common/string_utils.h
+++ b/onnxruntime/core/common/string_utils.h
@@ -65,5 +65,24 @@ inline std::string TrimString(std::string s) {
   return s;
 }
 
+/**
+ * @brief A consistent way to construct the fully qualified op name.
+ */
+inline std::string GetFullQualifiedOpName(const std::string& op_type, const std::string& domain) {
+  return MakeString(domain, "::", op_type);
+}
+
+/**
+ * Use this simple hash to generate an integer hash value from a given string input.
+ */
+inline uint32_t GetHashFromString(const std::string& str_value) {
+  uint32_t hash = 0;
+  for (char const& c : str_value) {
+    hash = hash * 101 + c;
+  }
+
+  return hash;
+}
+
 }  // namespace utils
 }  // namespace onnxruntime
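
A small usage sketch for the two helpers added above; their bodies are copied locally so the sketch compiles without ORT headers (the real ones live in onnxruntime::utils and use MakeString), and the op type and domain strings are illustrative only:

```cpp
// Usage sketch for GetFullQualifiedOpName / GetHashFromString; bodies copied
// locally so this builds standalone.
#include <cstdint>
#include <iostream>
#include <string>

inline std::string GetFullQualifiedOpName(const std::string& op_type, const std::string& domain) {
  return domain + "::" + op_type;
}

inline uint32_t GetHashFromString(const std::string& str_value) {
  uint32_t hash = 0;
  for (char const& c : str_value) {
    hash = hash * 101 + c;  // simple multiplicative hash, deterministic per string
  }
  return hash;
}

int main() {
  const std::string name = GetFullQualifiedOpName("Conv", "com.microsoft");
  std::cout << name << " -> " << GetHashFromString(name) << "\n";
  return 0;
}
```

The hash is deterministic but not collision-free, so it is best treated as a bucketing key rather than a unique identifier.
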
diff --git a/onnxruntime/core/flatbuffers/flatbuffers_utils.h b/onnxruntime/core/flatbuffers/flatbuffers_utils.h
index 55bde0b2df806..76860d6ab1db8 100644
--- a/onnxruntime/core/flatbuffers/flatbuffers_utils.h
+++ b/onnxruntime/core/flatbuffers/flatbuffers_utils.h
@@ -5,7 +5,7 @@
 
 #include <unordered_map>
 
-#include "flatbuffers/flatbuffers.h"
+#include "core/common/flatbuffers.h"
 
 #include "core/common/common.h"
 #include "core/common/path_string.h"
diff --git a/onnxruntime/core/flatbuffers/schema/README.md b/onnxruntime/core/flatbuffers/schema/README.md
index 932478111ee68..96a2936c196ae 100644
--- a/onnxruntime/core/flatbuffers/schema/README.md
+++ b/onnxruntime/core/flatbuffers/schema/README.md
@@ -21,7 +21,7 @@ e.g.
     - /build/Linux/Debug/_deps/flatbuffers-build/flatc
 
 It is possible to use another flatc as well, e.g., from a separate installation. Note that ONNX Runtime uses
-FlatBuffers 1.12.
+FlatBuffers 23.5.26.
 
 To update the flatbuffers schemas and generated files:
 1. Modify [the ORT file format schema](ort.fbs) or [training checkpoint schema](ort_training_checkpoint.fbs).
diff --git a/onnxruntime/core/flatbuffers/schema/ort.fbs.h b/onnxruntime/core/flatbuffers/schema/ort.fbs.h
index e0f5342c29621..dc8a471f2d81f 100644
--- a/onnxruntime/core/flatbuffers/schema/ort.fbs.h
+++ b/onnxruntime/core/flatbuffers/schema/ort.fbs.h
@@ -4,7 +4,7 @@
 #ifndef FLATBUFFERS_GENERATED_ORT_ONNXRUNTIME_FBS_H_
 #define FLATBUFFERS_GENERATED_ORT_ONNXRUNTIME_FBS_H_
 
-#include "flatbuffers/flatbuffers.h"
+#include "core/common/flatbuffers.h"
 
 namespace onnxruntime {
 namespace fbs {
@@ -562,8 +562,8 @@ struct DimensionValue FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
-           VerifyField<int8_t>(verifier, VT_DIM_TYPE) &&
-           VerifyField<int64_t>(verifier, VT_DIM_VALUE) &&
+           VerifyField<int8_t>(verifier, VT_DIM_TYPE, 1) &&
+           VerifyField<int64_t>(verifier, VT_DIM_VALUE, 8) &&
            VerifyOffset(verifier, VT_DIM_PARAM) &&
            verifier.VerifyString(dim_param()) &&
            verifier.EndTable();
@@ -634,7 +634,7 @@ struct TensorTypeAndShape FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
-           VerifyField<int32_t>(verifier, VT_ELEM_TYPE) &&
+           VerifyField<int32_t>(verifier, VT_ELEM_TYPE, 4) &&
            VerifyOffset(verifier, VT_SHAPE) &&
            verifier.VerifyTable(shape()) &&
            verifier.EndTable();
@@ -687,7 +687,7 @@ struct MapType FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
-           VerifyField<int32_t>(verifier, VT_KEY_TYPE) &&
+           VerifyField<int32_t>(verifier, VT_KEY_TYPE, 4) &&
            VerifyOffset(verifier, VT_VALUE_TYPE) &&
            verifier.VerifyTable(value_type()) &&
            verifier.EndTable();
@@ -787,7 +787,7 @@ struct NodeEdge FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
-           VerifyField<uint32_t>(verifier, VT_NODE_INDEX) &&
+           VerifyField<uint32_t>(verifier, VT_NODE_INDEX, 4) &&
            VerifyOffset(verifier, VT_INPUT_EDGES) &&
            verifier.VerifyVector(input_edges()) &&
            VerifyOffset(verifier, VT_OUTPUT_EDGES) &&
@@ -911,11 +911,11 @@ struct Node FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
            verifier.VerifyString(doc_string()) &&
            VerifyOffset(verifier, VT_DOMAIN) &&
            verifier.VerifyString(domain()) &&
-           VerifyField<int32_t>(verifier, VT_SINCE_VERSION) &&
-           VerifyField<uint32_t>(verifier, VT_INDEX) &&
+           VerifyField<int32_t>(verifier, VT_SINCE_VERSION, 4) &&
+           VerifyField<uint32_t>(verifier, VT_INDEX, 4) &&
            VerifyOffset(verifier, VT_OP_TYPE) &&
            verifier.VerifyString(op_type()) &&
-           VerifyField<int32_t>(verifier, VT_TYPE) &&
+           VerifyField<int32_t>(verifier, VT_TYPE, 4) &&
            VerifyOffset(verifier, VT_EXECUTION_PROVIDER_TYPE) &&
            verifier.VerifyString(execution_provider_type()) &&
            VerifyOffset(verifier, VT_INPUTS) &&
@@ -1174,7 +1174,7 @@ struct TypeInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
     return VerifyTableStart(verifier) &&
            VerifyOffset(verifier, VT_DENOTATION) &&
            verifier.VerifyString(denotation()) &&
-           VerifyField<uint8_t>(verifier, VT_VALUE_TYPE) &&
+           VerifyField<uint8_t>(verifier, VT_VALUE_TYPE, 1) &&
            VerifyOffset(verifier, VT_VALUE) &&
            VerifyTypeInfoValue(verifier, value(), value_type()) &&
            verifier.EndTable();
@@ -1259,7 +1259,7 @@ struct OperatorSetId FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
     return VerifyTableStart(verifier) &&
            VerifyOffset(verifier, VT_DOMAIN) &&
            verifier.VerifyString(domain()) &&
-           VerifyField<int64_t>(verifier, VT_VERSION) &&
+           VerifyField<int64_t>(verifier, VT_VERSION, 8) &&
            verifier.EndTable();
   }
 };
@@ -1343,7 +1343,7 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
            verifier.VerifyString(doc_string()) &&
            VerifyOffset(verifier, VT_DIMS) &&
            verifier.VerifyVector(dims()) &&
-           VerifyField<int32_t>(verifier, VT_DATA_TYPE) &&
+           VerifyField<int32_t>(verifier, VT_DATA_TYPE, 4) &&
            VerifyOffset(verifier, VT_RAW_DATA) &&
            verifier.VerifyVector(raw_data()) &&
            VerifyOffset(verifier, VT_STRING_DATA) &&
@@ -1568,9 +1568,9 @@ struct Attribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
            verifier.VerifyString(name()) &&
            VerifyOffset(verifier, VT_DOC_STRING) &&
            verifier.VerifyString(doc_string()) &&
-           VerifyField<int32_t>(verifier, VT_TYPE) &&
-           VerifyField<float>(verifier, VT_F) &&
-           VerifyField<int64_t>(verifier, VT_I) &&
+           VerifyField<int32_t>(verifier, VT_TYPE, 4) &&
+           VerifyField<float>(verifier, VT_F, 4) &&
+           VerifyField<int64_t>(verifier, VT_I, 8) &&
            VerifyOffset(verifier, VT_S) &&
            verifier.VerifyString(s()) &&
            VerifyOffset(verifier, VT_T) &&
@@ -1759,12 +1759,12 @@ struct NodesToOptimizeIndices FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab
     return VerifyTableStart(verifier) &&
            VerifyOffset(verifier, VT_NODE_INDICES) &&
            verifier.VerifyVector(node_indices()) &&
-           VerifyField<uint32_t>(verifier, VT_NUM_INPUTS) &&
-           VerifyField<uint32_t>(verifier, VT_NUM_OUTPUTS) &&
-           VerifyField<uint8_t>(verifier, VT_HAS_VARIADIC_INPUT) &&
-           VerifyField<uint8_t>(verifier, VT_HAS_VARIADIC_OUTPUT) &&
-           VerifyField<uint32_t>(verifier, VT_NUM_VARIADIC_INPUTS) &&
-           VerifyField<uint32_t>(verifier, VT_NUM_VARIADIC_OUTPUTS) &&
+           VerifyField<uint32_t>(verifier, VT_NUM_INPUTS, 4) &&
+           VerifyField<uint32_t>(verifier, VT_NUM_OUTPUTS, 4) &&
+           VerifyField<uint8_t>(verifier, VT_HAS_VARIADIC_INPUT, 1) &&
+           VerifyField<uint8_t>(verifier, VT_HAS_VARIADIC_OUTPUT, 1) &&
+           VerifyField<uint32_t>(verifier, VT_NUM_VARIADIC_INPUTS, 4) &&
+           VerifyField<uint32_t>(verifier, VT_NUM_VARIADIC_OUTPUTS, 4) &&
            verifier.EndTable();
   }
 };
@@ -1862,8 +1862,8 @@ struct DeprecatedNodeIndexAndKernelDefHash FLATBUFFERS_FINAL_CLASS : private fla
   }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
-           VerifyField<uint32_t>(verifier, VT_NODE_INDEX) &&
-           VerifyField<uint64_t>(verifier, VT_KERNEL_DEF_HASH) &&
+           VerifyField<uint32_t>(verifier, VT_NODE_INDEX, 4) &&
+           VerifyField<uint64_t>(verifier, VT_KERNEL_DEF_HASH, 8) &&
            verifier.EndTable();
   }
 };
@@ -2161,7 +2161,7 @@ struct Graph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
            VerifyOffset(verifier, VT_NODES) &&
            verifier.VerifyVector(nodes()) &&
            verifier.VerifyVectorOfTables(nodes()) &&
-           VerifyField<uint32_t>(verifier, VT_MAX_NODE_INDEX) &&
+           VerifyField<uint32_t>(verifier, VT_MAX_NODE_INDEX, 4) &&
            VerifyOffset(verifier, VT_NODE_EDGES) &&
            verifier.VerifyVector(node_edges()) &&
            verifier.VerifyVectorOfTables(node_edges()) &&
@@ -2390,7 +2390,7 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
-           VerifyField<int64_t>(verifier, VT_IR_VERSION) &&
+           VerifyField<int64_t>(verifier, VT_IR_VERSION, 8) &&
            VerifyOffset(verifier, VT_OPSET_IMPORT) &&
            verifier.VerifyVector(opset_import()) &&
            verifier.VerifyVectorOfTables(opset_import()) &&
@@ -2400,7 +2400,7 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
            verifier.VerifyString(producer_version()) &&
            VerifyOffset(verifier, VT_DOMAIN) &&
            verifier.VerifyString(domain()) &&
-           VerifyField<int64_t>(verifier, VT_MODEL_VERSION) &&
+           VerifyField<int64_t>(verifier, VT_MODEL_VERSION, 8) &&
            VerifyOffset(verifier, VT_DOC_STRING) &&
            verifier.VerifyString(doc_string()) &&
            VerifyOffset(verifier, VT_GRAPH) &&
@@ -2740,8 +2740,8 @@ struct ArgTypeAndIndex FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
-           VerifyField<int8_t>(verifier, VT_ARG_TYPE) &&
-           VerifyField<uint32_t>(verifier, VT_INDEX) &&
+           VerifyField<int8_t>(verifier, VT_ARG_TYPE, 1) &&
+           VerifyField<uint32_t>(verifier, VT_INDEX, 4) &&
            verifier.EndTable();
   }
 };
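
The extra integer threaded through these regenerated VerifyField<T>() calls is, by my reading of newer FlatBuffers generated code, the field's alignment in bytes, which for the scalar fields here is simply sizeof(T) (1 for int8_t/uint8_t, 4 for int32_t/uint32_t/float, 8 for int64_t/uint64_t). A trivial compile-time sanity check of that correspondence, independent of FlatBuffers:

```cpp
// Compile-time sanity check of the sizeof(T) == alignment correspondence used
// in the regenerated verifiers above; no FlatBuffers headers are needed.
#include <cstdint>

static_assert(sizeof(int8_t) == 1, "int8_t fields are verified with alignment 1");
static_assert(sizeof(int32_t) == 4, "int32_t fields are verified with alignment 4");
static_assert(sizeof(uint32_t) == 4, "uint32_t fields are verified with alignment 4");
static_assert(sizeof(float) == 4, "float fields are verified with alignment 4");
static_assert(sizeof(int64_t) == 8, "int64_t fields are verified with alignment 8");

int main() { return 0; }
```
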
diff --git a/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs.h b/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs.h
index d205c5eb8f409..62e6cf74394e5 100644
--- a/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs.h
+++ b/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs.h
@@ -4,7 +4,7 @@
 #ifndef FLATBUFFERS_GENERATED_ORTTRAININGCHECKPOINT_ONNXRUNTIME_FBS_H_
 #define FLATBUFFERS_GENERATED_ORTTRAININGCHECKPOINT_ONNXRUNTIME_FBS_H_
 
-#include "flatbuffers/flatbuffers.h"
+#include "core/common/flatbuffers.h"
 
 #include "ort.fbs.h"
 
@@ -59,7 +59,7 @@ struct ModuleState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
            VerifyOffset(verifier, VT_FROZEN_PARAMS) &&
            verifier.VerifyVector(frozen_params()) &&
            verifier.VerifyVectorOfTables(frozen_params()) &&
-           VerifyField<uint8_t>(verifier, VT_IS_NOMINAL_STATE) &&
+           VerifyField<uint8_t>(verifier, VT_IS_NOMINAL_STATE, 1) &&
            verifier.EndTable();
   }
 };
@@ -206,8 +206,8 @@ struct OptimizerGroup FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
     return VerifyTableStart(verifier) &&
            VerifyOffset(verifier, VT_GROUP_NAME) &&
            verifier.VerifyString(group_name()) &&
-           VerifyField<int64_t>(verifier, VT_STEP) &&
-           VerifyField<float>(verifier, VT_INITIAL_LEARNING_RATE) &&
+           VerifyField<int64_t>(verifier, VT_STEP, 8) &&
+           VerifyField<float>(verifier, VT_INITIAL_LEARNING_RATE, 4) &&
            VerifyOffset(verifier, VT_OPTIMIZER_STATES) &&
            verifier.VerifyVector(optimizer_states()) &&
            verifier.VerifyVectorOfTables(optimizer_states()) &&
@@ -289,7 +289,7 @@ struct IntProperty FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
     return VerifyTableStart(verifier) &&
            VerifyOffset(verifier, VT_NAME) &&
            verifier.VerifyString(name()) &&
-           VerifyField<int64_t>(verifier, VT_VALUE) &&
+           VerifyField<int64_t>(verifier, VT_VALUE, 8) &&
            verifier.EndTable();
   }
 };
@@ -353,7 +353,7 @@ struct FloatProperty FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
     return VerifyTableStart(verifier) &&
            VerifyOffset(verifier, VT_NAME) &&
            verifier.VerifyString(name()) &&
-           VerifyField<float>(verifier, VT_VALUE) &&
+           VerifyField<float>(verifier, VT_VALUE, 4) &&
            verifier.EndTable();
   }
 };
@@ -572,7 +572,7 @@ struct Checkpoint FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
-           VerifyField<int32_t>(verifier, VT_VERSION) &&
+           VerifyField<int32_t>(verifier, VT_VERSION, 4) &&
            VerifyOffset(verifier, VT_MODULE_STATE) &&
            verifier.VerifyTable(module_state()) &&
            VerifyOffset(verifier, VT_OPTIMIZER_GROUPS) &&
diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc
index ea7a6432a7507..fa6233476fe62 100644
--- a/onnxruntime/core/framework/allocation_planner.cc
+++ b/onnxruntime/core/framework/allocation_planner.cc
@@ -182,7 +182,6 @@ class PlannerImpl {
   // upstream_node_0 and upstream_node_1 are the immediate upstream nodes of downstream_node
   // upstream_node_2 is the immediate node ahead of downstream_node in the same logic stream
   InlinedHashMap<onnxruntime::NodeIndex, InlinedHashSet<onnxruntime::NodeIndex>> dependence_graph_;
-  InlinedHashMap<onnxruntime::OrtValueIndex, InlinedHashSet<onnxruntime::NodeIndex>> value_consumer_map_;
   InlinedHashMap<onnxruntime::OrtValueIndex, onnxruntime::NodeIndex> value_node_map_;
 
   // OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation:
@@ -295,7 +294,7 @@ class PlannerImpl {
   }
 #endif
 
-  // Find if there exists some input tensor that we can use in-place for output_arg_num-th input in the node.
+  // Find if there exists some input tensor that we can use in-place for output_arg_num-th output in the node.
   bool FindReusableInput(const onnxruntime::Node& node, int output_arg_num, OrtValueIndex* reusable_input,
                          bool* is_strided_tensor) {
     *is_strided_tensor = false;
@@ -530,6 +529,7 @@ class PlannerImpl {
 
     // Initialize allocation plan:
     plan_.allocation_plan.resize(num_ml_values);
+    for (int i = 0; static_cast<size_t>(i) < num_ml_values; i++) AllocPlan(i).reused_buffer = i;
   }
 
   bool HasExternalOutputs(const Node& node) const {
@@ -1065,7 +1065,8 @@ class PlannerImpl {
 
     // build the consumer list for each value
     int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1;
-    value_consumer_map_.reserve(num_ml_values);
+    InlinedHashMap<onnxruntime::OrtValueIndex, InlinedHashSet<onnxruntime::NodeIndex>> value_consumer_map;
+    value_consumer_map.reserve(num_ml_values);
 
     // iterate each stream from back, so the first element is the last consumer in single stream case
     for (auto& stream : stream_nodes_) {
@@ -1078,10 +1079,10 @@ class PlannerImpl {
             const auto& name = input.Name();
             int value_idx;
             ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx));
-            auto origin = Buffer(value_idx);
-            if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) {
+            auto origin = AllocPlan(value_idx).reused_buffer;
+            if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) {
               // add current node as consumer for origin buffer
-              value_consumer_map_[origin].insert(node_index);
+              value_consumer_map[origin].insert(node_index);
             }
           }
           return Status::OK();
@@ -1138,8 +1139,8 @@ class PlannerImpl {
                   std::cout << p_input_arg->Name() << " reused by " << p_output_arg->Name() << " as input" << std::endl;
                   allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse;
                   allocation_plan[output_idx_global].reused_buffer = reusable_input;
-                  value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(),
-                                                             value_consumer_map_[output_idx_global].end());
+                  value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(),
+                                                            value_consumer_map[output_idx_global].end());
                   reused.insert(reusable_input);
                   found_reusable = true;
                   break;
@@ -1168,8 +1169,8 @@ class PlannerImpl {
                   allocation_plan[reusable_input].alloc_kind == AllocKind::kAllocate) {
                 allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse;
                 allocation_plan[output_idx_global].reused_buffer = reusable_input;
-                value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(),
-                                                           value_consumer_map_[output_idx_global].end());
+                value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(),
+                                                          value_consumer_map[output_idx_global].end());
                 reused.insert(reusable_input);
                 continue;
               }  // if
@@ -1187,11 +1188,11 @@ class PlannerImpl {
                 OrtValueIndex input_arg_index{};
                 if (value_map.GetIdx(p_input_arg->Name(), input_arg_index).IsOK() &&
                     allocation_plan[input_arg_index].alloc_kind == AllocKind::kAllocate) {
-                  if (value_consumer_map_[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) {
+                  if (value_consumer_map[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) {
                     allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse;
                     allocation_plan[output_idx_global].reused_buffer = input_arg_index;
-                    value_consumer_map_[input_arg_index].insert(value_consumer_map_[output_idx_global].begin(),
-                                                                value_consumer_map_[output_idx_global].end());
+                    value_consumer_map[input_arg_index].insert(value_consumer_map[output_idx_global].begin(),
+                                                               value_consumer_map[output_idx_global].end());
                     reused.insert(input_arg_index);
                   }
                 }
@@ -1266,7 +1267,7 @@ class PlannerImpl {
             }
 
             bool all_covered = true;
-            for (auto consumer : value_consumer_map_[output_idx_global]) {
+            for (auto consumer : value_consumer_map[output_idx_global]) {
               if (deps->find(consumer) == deps->end()) {
                 all_covered = false;
                 break;
@@ -1277,9 +1278,9 @@ class PlannerImpl {
               allocation_plan[downstream_value].reused_buffer = output_idx_global;
               get_reused = true;
               // add new consumer for the value to be reused
-              value_consumer_map_[output_idx_global].insert(value_node_map_[downstream_value]);
-              value_consumer_map_[output_idx_global].insert(value_consumer_map_[downstream_value].begin(),
-                                                            value_consumer_map_[downstream_value].end());
+              value_consumer_map[output_idx_global].insert(value_node_map_[downstream_value]);
+              value_consumer_map[output_idx_global].insert(value_consumer_map[downstream_value].begin(),
+                                                           value_consumer_map[downstream_value].end());
               node_iter = size_iter->second.erase(node_iter);
               if (size_iter->second.empty()) {
                 local_iter->second.erase(size_iter);
@@ -1342,8 +1343,9 @@ class PlannerImpl {
     ort_value_usecount.reserve(ort_value_info_.size());
 #endif
     for (size_t i = 0; i < stream_nodes_.size(); ++i) {
-      // compute use count first
+      // compute use count first. TODO(leca): calling ComputeReuseCount() once is enough!
       ORT_RETURN_IF_ERROR(ComputeReuseCount());
+      for (int j = 0; static_cast<size_t>(j) < ort_value_info_.size(); j++) Buffer(j) = j;
 #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
       if (i == 0) {
         for (auto ort_value_info : ort_value_info_) {
@@ -1693,8 +1695,8 @@ class PlannerImpl {
             const auto& name = input.Name();
             int value_idx;
             ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx));
-            auto origin = Buffer(value_idx);
-            if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) {
+            auto origin = AllocPlan(value_idx).reused_buffer;
+            if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) {
               // add current node as consumer for origin buffer
               value_consumers[origin].push_back(node_index);
             }
@@ -1773,7 +1775,12 @@ class PlannerImpl {
       execution_plan.emplace_back(std::make_unique<SequentialExecutionPlan::LogicStream>(node_device_mem_location));
       // 2. add steps to the execution plan
       for (auto node_index : stream_nodes_[0]) {
+#if defined(ORT_MINIMAL_BUILD)
         execution_plan[0]->steps_.emplace_back(std::make_unique<LaunchKernelStep>(node_index));
+#else
+        execution_plan[0]->steps_.emplace_back(std::make_unique<LaunchKernelStep>(node_index,
+                                                                                  graph_viewer_.GetNode(node_index)->Name()));
+#endif
       }
     } else {
       // graph with no nodes. e.g. subgraph of If might return the input as-is or a constant value from an initializer
@@ -1889,7 +1896,7 @@ class PlannerImpl {
                   // 2. the consumer is in the same stream (non-CPU device), but it consumes a CPU tensor from a non-shape op.
                   //    for example, a resize cuda kernel consumes a tensor from a MemCpyToHost cuda kernel on the same stream.
                   //    in this case, the FIFO can't guarantee the cpu tensor is ready when the resize kernel is launched
-                  OrtDevice::DeviceType output_arg_device = plan_.allocation_plan[output_arg_idx].location.Type();
+                  OrtDevice::DeviceType output_arg_device = AllocPlan(output_arg_idx).location.Type();
                   WaitNotificationFn wait_handle = stream_handle_registry.GetWaitHandle(stream_device, output_arg_device);
                   if ((node_stream_map_[it->Index()] != i || output_arg_device == OrtDevice::CPU) && wait_handle != nullptr) {
                     if (node_to_notification.find(node_index) == node_to_notification.end()) {
@@ -1978,8 +1985,12 @@ class PlannerImpl {
           // add dependency for model graph
           dependence_graph_[it->Index()].insert(node_index);
         }
-        // push launch kernel command
+// push launch kernel command
+#if defined(ORT_MINIMAL_BUILD)
         execution_plan[i]->steps_.emplace_back(std::make_unique<LaunchKernelStep>(node_index));
+#else
+        execution_plan[i]->steps_.emplace_back(std::make_unique<LaunchKernelStep>(node_index, graph_viewer_.GetNode(node_index)->Name()));
+#endif
         // check if any notification generated by this node, if yes, push a activate
         auto notification_it = node_to_notification.find(node_index);
         if (notification_it != node_to_notification.end()) {
diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc
index 8c08152986cf6..32a5f749af084 100644
--- a/onnxruntime/core/framework/execution_frame.cc
+++ b/onnxruntime/core/framework/execution_frame.cc
@@ -204,6 +204,14 @@ AllocatorPtr IExecutionFrame::GetAllocator(const OrtDevice& info) const {
 
 Status IExecutionFrame::ReleaseMLValue(int ort_value_idx) { return ReleaseMLValueImpl(ort_value_idx); }
 
+#ifdef ENABLE_TRAINING
+void IExecutionFrame::ReleaseAllMLValues() {
+  for (size_t ort_value_idx = 0; ort_value_idx < all_values_.size(); ort_value_idx++) {
+    all_values_[ort_value_idx] = OrtValue();
+  }
+}
+#endif
+
 Status IExecutionFrame::ReleaseMLValueImpl(int ort_value_idx) {
   if (ort_value_idx == NodeIndexInfo::kInvalidEntry || static_cast<size_t>(ort_value_idx) >= all_values_size_) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index ", ort_value_idx);
@@ -831,7 +839,20 @@ AllocatorPtr ExecutionFrame::GetAllocatorImpl(const OrtDevice& info) const {
 // This method is not thread safe!
 // Return S_OK and nullptr if index map to a value that is an unused optional input/output
 Status ExecutionFrame::CreateNodeOutputMLValueImpl(OrtValue& ort_value, int ort_value_idx, const TensorShape* shape) {
+#ifdef ENABLE_TRAINING
+  try {
+    auto status = AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape);
+    return status;
+  } catch (const std::exception& e) {
+    LOGS(session_state_.Logger(), WARNING)
+        << "Exception caught when allocating memory for ort_value with index: " << ort_value_idx
+        << ", so cleaning up all OrtValues";
+    ReleaseAllMLValues();
+    return Status(ONNXRUNTIME, FAIL, e.what());
+  }
+#else
   return AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape);
+#endif
 }
 
 void ExecutionFrame::VerifyOutputSizes(int output_index, const Node& node, const TensorShape& output_shape) {
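
In training builds the allocation path above now catches exceptions (typically out-of-memory), releases every OrtValue held by the frame via ReleaseAllMLValues(), and returns a failed Status so the caller can retry, for example with a smaller batch. A simplified, self-contained sketch of that behavior; Frame and its members are illustrative stand-ins, not ORT types:

```cpp
// Simplified sketch of the training-build behavior added above: if output
// allocation throws (typically std::bad_alloc), release every held value so a
// later, smaller request can succeed.
#include <cstddef>
#include <iostream>
#include <vector>

struct Frame {
  std::vector<std::vector<char>> values;  // stand-in for all_values_

  bool AllocateOutput(std::size_t bytes) {
    try {
      values.emplace_back(bytes);  // may throw std::bad_alloc
      return true;
    } catch (const std::exception& e) {
      std::cerr << "allocation failed (" << e.what() << "), so cleaning up all values\n";
      ReleaseAll();  // mirrors ReleaseAllMLValues()
      return false;
    }
  }

  void ReleaseAll() {
    for (auto& v : values) v = {};  // reset each value, like assigning a fresh OrtValue
  }
};

int main() {
  Frame f;
  if (!f.AllocateOutput(static_cast<std::size_t>(-1) / 2)) {  // deliberately huge request
    f.AllocateOutput(1024);  // retrying with a smaller request succeeds
  }
  return 0;
}
```
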
diff --git a/onnxruntime/core/framework/execution_frame.h b/onnxruntime/core/framework/execution_frame.h
index 1576c16684faa..18d210ffd48f7 100644
--- a/onnxruntime/core/framework/execution_frame.h
+++ b/onnxruntime/core/framework/execution_frame.h
@@ -67,6 +67,8 @@ class IExecutionFrame {
 
                      const std::unordered_map<int, OrtValue>& initializers);
   Status GetOutputs(gsl::span<const int> fetch_mlvalue_idxs, std::vector<OrtValue>& fetches);
+  // If OOM happens, release all values so the session can run the next batch.
+  void ReleaseAllMLValues();
 #endif
 
   // TO DO: make it thread safe
diff --git a/onnxruntime/core/framework/execution_steps.cc b/onnxruntime/core/framework/execution_steps.cc
index df19236d037c0..b647833cfd373 100644
--- a/onnxruntime/core/framework/execution_steps.cc
+++ b/onnxruntime/core/framework/execution_steps.cc
@@ -1,8 +1,11 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
+
 #include "core/framework/execution_steps.h"
 #include "core/framework/sequential_executor.h"
+
 namespace onnxruntime {
+
 BarrierStep::BarrierStep(size_t id, NodeIndex node_index) : SequentialExecutionPlan::ExecutionStep(node_index),
                                                             barrier_id_{id} {}
 
@@ -16,8 +19,8 @@ Status BarrierStep::Execute(StreamExecutionContext& ctx,
 }
 
 std::string BarrierStep::ToString() const {
-  return ::onnxruntime::MakeString("Set a barrier with id: ",
-                                   barrier_id_, ", count: ", 2, ".");
+  // Set a barrier with id: barrier_id_, count: 2.
+  return MakeString("Barrier - BarrierId: ", barrier_id_, ", Count: ", 2);
 }
 
 WaitOnEPStep::WaitOnEPStep(WaitNotificationFn handle,
@@ -42,11 +45,17 @@ Status WaitOnEPStep::Execute(StreamExecutionContext& ctx,
 }
 
 std::string WaitOnEPStep::ToString() const {
-  return ::onnxruntime::MakeString("WaitOnEPStep: wait on notification with id: ",
-                                   notification_idx_, ". ");
+  // Wait on notification with notification_idx_
+  return MakeString("WaitOnEP - NotificationId: ", notification_idx_);
 }
 
-LaunchKernelStep::LaunchKernelStep(NodeIndex index) : SequentialExecutionPlan::ExecutionStep(index) {}
+#if defined(ORT_MINIMAL_BUILD)
+LaunchKernelStep::LaunchKernelStep(NodeIndex index)
+    : SequentialExecutionPlan::ExecutionStep(index) {}
+#else
+LaunchKernelStep::LaunchKernelStep(NodeIndex index, std::string_view node_name)
+    : SequentialExecutionPlan::ExecutionStep(index), node_name_(node_name) {}
+#endif
 
 Status LaunchKernelStep::Execute(StreamExecutionContext& ctx,
                                  size_t stream_idx,
@@ -61,13 +70,17 @@ Status LaunchKernelStep::Execute(StreamExecutionContext& ctx,
     return Status::OK();
   }
 #endif
-  onnxruntime::Status status = ExecuteKernel(ctx, node_index_, stream_idx, terminate_flag, session_scope);
+  Status status = ExecuteKernel(ctx, node_index_, stream_idx, terminate_flag, session_scope);
   continue_flag = status.IsOK();
   return status;
 }
 
 std::string LaunchKernelStep::ToString() const {
-  return ::onnxruntime::MakeString("Launch kernel with node id: ", node_index_, ". ");
+#if defined(ORT_MINIMAL_BUILD)
+  return MakeString("LaunchKernel - ", "NodeIndex: ", node_index_);
+#else
+  return MakeString("LaunchKernel - ", "NodeIndex: ", node_index_, ", Name: ", node_name_);
+#endif
 }
 
 ActivateNotificationStep::ActivateNotificationStep(
@@ -89,12 +102,12 @@ Status ActivateNotificationStep::Execute(StreamExecutionContext& ctx,
 }
 
 std::string ActivateNotificationStep::ToString() const {
-  return ::onnxruntime::MakeString("ActivateNotificationStep: activate notification with id: ",
-                                   notification_idx_, ". ");
+  // Activate notification with id: notification_idx_
+  return MakeString("ActivateNotification - NotificationId: ", notification_idx_);
 }
 
-TriggerDownstreamStep::TriggerDownstreamStep(size_t trigger_point_index, NodeIndex node_index) : SequentialExecutionPlan::ExecutionStep(node_index),
-                                                                                                 trigger_point_index_(trigger_point_index) {}
+TriggerDownstreamStep::TriggerDownstreamStep(size_t trigger_point_index, NodeIndex node_index)
+    : SequentialExecutionPlan::ExecutionStep(node_index), trigger_point_index_(trigger_point_index) {}
 
 Status TriggerDownstreamStep::Execute(StreamExecutionContext& ctx,
                                       size_t /*stream_idx*/,
@@ -107,7 +120,8 @@ Status TriggerDownstreamStep::Execute(StreamExecutionContext& ctx,
 }
 
 std::string TriggerDownstreamStep::ToString() const {
-  return ::onnxruntime::MakeString("TriggerDownstreamStep: trigger downstream of trigger point: ",
-                                   trigger_point_index_, ".");
+  // Trigger downstream of trigger point: trigger_point_index_.
+  return MakeString("TriggerDownstream - TriggerPointIndex: ", trigger_point_index_);
 }
+
 }  // namespace onnxruntime
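
The ToString() bodies above were reworded into a consistent "StepName - Key: value" format, which makes execution-plan dumps easier to grep. A standalone sketch that reproduces the new strings; MakeString is approximated with an ostringstream fold, and the index values are arbitrary examples:

```cpp
// Standalone sketch reproducing the reworded "StepName - Key: value" strings.
#include <iostream>
#include <sstream>
#include <string>

template <typename... Args>
std::string MakeStringLike(const Args&... args) {
  std::ostringstream oss;
  (oss << ... << args);
  return oss.str();
}

int main() {
  std::cout << MakeStringLike("Barrier - BarrierId: ", 7, ", Count: ", 2) << "\n";
  std::cout << MakeStringLike("WaitOnEP - NotificationId: ", 3) << "\n";
  std::cout << MakeStringLike("LaunchKernel - ", "NodeIndex: ", 42, ", Name: ", "conv1") << "\n";
  std::cout << MakeStringLike("TriggerDownstream - TriggerPointIndex: ", 5) << "\n";
  return 0;
}
```
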
diff --git a/onnxruntime/core/framework/execution_steps.h b/onnxruntime/core/framework/execution_steps.h
index b67b583900824..545dabc56b272 100644
--- a/onnxruntime/core/framework/execution_steps.h
+++ b/onnxruntime/core/framework/execution_steps.h
@@ -44,7 +44,11 @@ class WaitOnEPStep : public SequentialExecutionPlan::ExecutionStep {
 
 class LaunchKernelStep : public SequentialExecutionPlan::ExecutionStep {
  public:
+#if defined(ORT_MINIMAL_BUILD)
   LaunchKernelStep(NodeIndex index);
+#else
+  LaunchKernelStep(NodeIndex index, std::string_view node_name);
+#endif
 
   Status Execute(StreamExecutionContext& ctx,
                  size_t stream_idx,
@@ -53,6 +57,11 @@ class LaunchKernelStep : public SequentialExecutionPlan::ExecutionStep {
                  bool& continue_flag) override;
 
   std::string ToString() const override;
+
+#if !defined(ORT_MINIMAL_BUILD)
+ private:
+  std::string node_name_;
+#endif
 };
 
 class ActivateNotificationStep : public SequentialExecutionPlan::ExecutionStep {
diff --git a/onnxruntime/core/framework/kernel_type_str_resolver.h b/onnxruntime/core/framework/kernel_type_str_resolver.h
index 31a806dd52291..fea2a6ef3a439 100644
--- a/onnxruntime/core/framework/kernel_type_str_resolver.h
+++ b/onnxruntime/core/framework/kernel_type_str_resolver.h
@@ -7,7 +7,7 @@
 #include <string_view>
 #include <utility>
 
-#include "flatbuffers/flatbuffers.h"
+#include "core/common/flatbuffers.h"
 
 #if !defined(ORT_MINIMAL_BUILD)
 #include "core/graph/onnx_protobuf.h"
diff --git a/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc b/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc
index 4f5fa9910b5df..473e78c3f5e25 100644
--- a/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc
+++ b/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc
@@ -5,7 +5,7 @@
 
 #include "core/framework/kernel_type_str_resolver_utils.h"
 
-#include "flatbuffers/flatbuffers.h"
+#include "core/common/flatbuffers.h"
 
 #include "core/common/common.h"
 #include "core/flatbuffers/schema/ort.fbs.h"
diff --git a/onnxruntime/core/framework/node_unit.cc b/onnxruntime/core/framework/node_unit.cc
new file mode 100644
index 0000000000000..4dee1c14b3761
--- /dev/null
+++ b/onnxruntime/core/framework/node_unit.cc
@@ -0,0 +1,351 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
+
+#include "node_unit.h"
+#include "core/graph/graph_viewer.h"
+
+namespace onnxruntime {
+
+namespace {
+
+enum class QLinearOpType : uint8_t {
+  Unknown,  // Unknown or not a linear quantized op
+  DequantizeLinear,
+  QuantizeLinear,
+  QLinearConv,
+  QLinearMatMul,
+  QLinearAdd,
+  QLinearSigmoid,
+  QLinearAveragePool,
+  QLinearMul,
+  QLinearReduceMean,
+  QLinearConcat,
+  QLinearGlobalAveragePool,
+  QLinearLeakyRelu,
+};
+
+QLinearOpType GetQLinearOpType(const onnxruntime::Node& node) {
+  const auto& op_type = node.OpType();
+  if (op_type == "DequantizeLinear")
+    return QLinearOpType::DequantizeLinear;
+  else if (op_type == "QuantizeLinear")
+    return QLinearOpType::QuantizeLinear;
+  else if (op_type == "QLinearConv")
+    return QLinearOpType::QLinearConv;
+  else if (op_type == "QLinearMatMul")
+    return QLinearOpType::QLinearMatMul;
+  else if (op_type == "QLinearAdd")
+    return QLinearOpType::QLinearAdd;
+  else if (op_type == "QLinearSigmoid")
+    return QLinearOpType::QLinearSigmoid;
+  else if (op_type == "QLinearAveragePool")
+    return QLinearOpType::QLinearAveragePool;
+  else if (op_type == "QLinearMul")
+    return QLinearOpType::QLinearMul;
+  else if (op_type == "QLinearReduceMean")
+    return QLinearOpType::QLinearReduceMean;
+  else if (op_type == "QLinearConcat")
+    return QLinearOpType::QLinearConcat;
+  else if (op_type == "QLinearGlobalAveragePool")
+    return QLinearOpType::QLinearGlobalAveragePool;
+  else if (op_type == "QLinearLeakyRelu")
+    return QLinearOpType::QLinearLeakyRelu;
+
+  return QLinearOpType::Unknown;
+}
+
+// Ops have 1 input
+bool IsUnaryQLinearOp(QLinearOpType type) {
+  return type == QLinearOpType::QLinearSigmoid ||
+         type == QLinearOpType::QLinearAveragePool ||
+         type == QLinearOpType::QLinearGlobalAveragePool ||
+         type == QLinearOpType::QLinearLeakyRelu ||
+         type == QLinearOpType::QLinearReduceMean;
+}
+
+// Ops have 2 inputs
+bool IsBinaryQLinearOp(QLinearOpType type) {
+  return type == QLinearOpType::QLinearConv ||
+         type == QLinearOpType::QLinearMatMul ||
+         type == QLinearOpType::QLinearAdd ||
+         type == QLinearOpType::QLinearMul;
+}
+
+// Ops have 1 or more inputs
+bool IsVariadicQLinearOp(QLinearOpType type) {
+  return type == QLinearOpType::QLinearConcat;
+}
+
+const std::vector<const Node*> GetQDQIONodes(const GraphViewer& graph_viewer,
+                                             const QDQ::NodeGroup& node_group, bool is_input) {
+  std::vector<const Node*> io_nodes;
+  const auto& src_nodes = is_input ? node_group.dq_nodes : node_group.q_nodes;
+  io_nodes.reserve(src_nodes.size());
+  for (const auto& node_idx : src_nodes) {
+    io_nodes.push_back(graph_viewer.GetNode(node_idx));
+  }
+
+  return io_nodes;
+}
+
+// Get the input or output NodeUnitIODef(s) for the given QDQ NodeGroup
+std::vector<NodeUnitIODef> GetQDQIODefs(const Node& target_node, const QDQ::NodeGroup& node_group, bool is_input) {
+  const auto& dq_or_q_nodes = is_input ? node_group.dq_nodes : node_group.q_nodes;
+  const auto target_node_io_defs = is_input ? target_node.InputDefs() : target_node.OutputDefs();
+  const size_t target_node_io_defs_size = target_node_io_defs.size();
+
+  // Find all the quantized IO defs and indices (for the input/output of the target node)
+  std::unordered_map<size_t, NodeUnitIODef> quantized_io_defs;
+  quantized_io_defs.reserve(target_node_io_defs_size);
+
+  auto cur = is_input ? target_node.InputEdgesBegin() : target_node.OutputEdgesBegin();
+  auto end = is_input ? target_node.InputEdgesEnd() : target_node.OutputEdgesEnd();
+
+  for (; cur != end; ++cur) {
+    const Node& node = cur->GetNode();
+
+    // If we can find the node index in the dq or q nodes this is a quantized input/output
+    if (std::find(dq_or_q_nodes.cbegin(), dq_or_q_nodes.cend(), node.Index()) != dq_or_q_nodes.cend()) {
+      const auto node_inputs = node.InputDefs();
+      // quantization scale and zp are always the input[1, 2]
+      NodeUnitIODef::QuantParam quant_param{*node_inputs[1], node_inputs.size() == 3 ? node_inputs[2] : nullptr};
+
+      if (is_input) {
+        // DQ is input to the target node, use the DstArgIndex
+        auto idx = cur->GetDstArgIndex();
+        // This is a DQ node, we are using x, x_scale, x_zp (input[0, 1, 2])
+        quantized_io_defs.insert({idx, NodeUnitIODef{*node_inputs[0], quant_param}});
+      } else {
+        // Q is output of the target node, use the SrcArgIndex
+        auto idx = cur->GetSrcArgIndex();
+        // This is a Q node, we are using y (output[0]), y_scale, y_zp (input[1, 2])
+        const auto node_outputs = node.OutputDefs();
+        quantized_io_defs.insert({idx, NodeUnitIODef{*node_outputs[0], quant_param}});
+      }
+    }
+  }
+
+  // Construct the IODefs for this QDQ NodeGroup
+  std::vector<NodeUnitIODef> io_defs;
+  io_defs.reserve(target_node_io_defs_size);
+  for (size_t i = 0; i < target_node_io_defs_size; i++) {
+    // If we can find the NodeUnitIODef for this index, this is a quantized input/output
+    if (quantized_io_defs.find(i) != quantized_io_defs.cend()) {
+      io_defs.push_back(std::move(quantized_io_defs.at(i)));
+    } else {
+      // This is a regular input
+      io_defs.push_back({*target_node_io_defs[i], std::nullopt});
+    }
+  }
+
+  return io_defs;
+}
+
+}  // namespace
+
+Status QDQ::NodeGroup::CanCreateNodeGroup(const GraphViewer& graph_viewer,
+                                          const Node& target_node,
+                                          gsl::span<const Node* const> dq_nodes,
+                                          gsl::span<const Node* const> q_nodes) {
+  // Within a QDQ node group, a target node input is the only consumer of each DQ.
+  // This should have been ensured by the EnsureUniqueDQForNodeUnit graph transformer, but other graph modifications
+  // may have happened since. Verify that this is still true.
+  for (const auto* dq_node : dq_nodes) {
+    const bool dq_produces_graph_output = graph_viewer.NodeProducesGraphOutput(*dq_node);
+    ORT_RETURN_IF(dq_produces_graph_output,
+                  "QDQ node group cannot have DQ node that produces a graph output. DQ node: ", dq_node->Name(),
+                  ", target node: ", target_node.Name());
+
+    const bool dq_has_single_output_edge_to_target =
+        dq_node->GetOutputEdgesCount() == 1 &&
+        dq_node->OutputEdgesBegin()->GetNode().Index() == target_node.Index();
+    ORT_RETURN_IF_NOT(dq_has_single_output_edge_to_target,
+                      "QDQ node group cannot have DQ that doesn't have a single output edge to the target node. "
+                      "DQ node: ",
+                      dq_node->Name(), ", target node: ", target_node.Name());
+  }
+
+  // an output from the target node can have either Q consumers or direct consumers. it cannot have both.
+  // this must be checked on a per output basis.
+  // e.g. TopK produces values and indices. The indices output won't be quantized, so even if we replace the TopK QDQ
+  // node group with a quantized TopK, an int64_t indices value will be produced and can provide a graph output.
+  if (!q_nodes.empty()) {
+    auto cur_edge = target_node.OutputEdgesBegin();
+    auto end_edge = target_node.OutputEdgesEnd();
+    std::vector<const Node*> output_consumers(target_node.OutputDefs().size(), nullptr);
+
+    for (; cur_edge != end_edge; ++cur_edge) {
+      auto output_idx = cur_edge->GetSrcArgIndex();
+      const Node& this_consumer = cur_edge->GetNode();
+      const Node* existing_consumer = output_consumers[output_idx];
+
+      if (existing_consumer != nullptr) {
+        // another edge for this output. either both are Q or both are not.
+        bool valid = true;
+        if (existing_consumer->OpType() == "QuantizeLinear") {
+          valid = this_consumer.OpType() == "QuantizeLinear";
+        } else {
+          valid = this_consumer.OpType() != "QuantizeLinear";
+        }
+
+        ORT_RETURN_IF_NOT(valid,
+                          "QDQ node group cannot have an output from the target node being consumed by a Q node and "
+                          "a non-Q node. target node: ",
+                          target_node.Name());
+      } else {
+        output_consumers[output_idx] = &this_consumer;
+      }
+    }
+
+    const auto& graph_outputs = graph_viewer.GetOutputs();
+    for (size_t idx = 0, end = output_consumers.size(); idx < end; ++idx) {
+      // any output with a Q cannot be a graph output as it will disappear if the QDQ node unit is converted to
+      // a quantized op.
+      if (output_consumers[idx] != nullptr && output_consumers[idx]->OpType() == "QuantizeLinear") {
+        const auto& output_name = target_node.OutputDefs()[idx]->Name();
+        bool is_graph_output = std::any_of(graph_outputs.begin(), graph_outputs.end(),
+                                           [&output_name](const NodeArg* node_arg) {
+                                             return node_arg->Name() == output_name;
+                                           });
+        ORT_RETURN_IF(is_graph_output,
+                      "QDQ node group cannot have an output from the target node that is consumed by a Q node and "
+                      "a graph output. target node: ",
+                      target_node.Name(), " output idx:", idx);
+      }
+    }
+  }
+
+  return Status::OK();
+}
+NodeUnit::NodeUnit(const Node& node)
+    : target_node_(node),
+      type_(Type::SingleNode),
+      input_edge_count_(node.GetInputEdgesCount()) {
+  InitForSingleNode();
+}
+
+NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group)
+    : dq_nodes_{GetQDQIONodes(graph_viewer, node_group, true /* is_input */)},
+      target_node_(*graph_viewer.GetNode(node_group.target_node)),
+      q_nodes_{GetQDQIONodes(graph_viewer, node_group, false /* is_input */)},
+      type_(Type::QDQGroup),
+      inputs_{GetQDQIODefs(target_node_, node_group, true /* is_input */)},
+      outputs_{GetQDQIODefs(target_node_, node_group, false /* is_input */)} {
+  ORT_THROW_IF_ERROR(QDQ::NodeGroup::CanCreateNodeGroup(graph_viewer, target_node_, dq_nodes_, q_nodes_));
+
+  input_edge_count_ = std::accumulate(dq_nodes_.cbegin(), dq_nodes_.cend(), size_t(0),
+                                      [](size_t acc, const Node* node) { return acc + node->GetInputEdgesCount(); });
+
+  // add edges for inputs that are not from DQ nodes. there is one edge to each DQ node.
+  // other inputs could come from initializers or graph inputs (no edges) or other nodes (edge).
+  input_edge_count_ += target_node_.GetInputEdgesCount() - dq_nodes_.size();
+
+  // create output edges. each target node output either goes to Q node/s or non-Q node/s.
+  // CanCreateNodeGroup ensures this.
+  auto cur_edge = target_node_.OutputEdgesBegin();
+  auto end_edge = target_node_.OutputEdgesEnd();
+  for (; cur_edge != end_edge; ++cur_edge) {
+    const Node& node = cur_edge->GetNode();
+
+    // if node is in q_nodes we hide the Q node.
+    if (std::find(q_nodes_.cbegin(), q_nodes_.cend(), &node) != q_nodes_.cend()) {
+      auto src_idx = cur_edge->GetSrcArgIndex();
+      auto q_cur_edge = node.OutputEdgesBegin();
+      auto q_end_edge = node.OutputEdgesEnd();
+      for (; q_cur_edge != q_end_edge; ++q_cur_edge) {
+        output_edges_.insert(Node::EdgeEnd{q_cur_edge->GetNode(), src_idx, q_cur_edge->GetDstArgIndex()});
+      }
+    } else {
+      // non-Q node, or Q node that isn't in the QDQ node group (unexpected but may be possible). add as-is.
+      output_edges_.insert(*cur_edge);
+    }
+  }
+}
+
+const std::string& NodeUnit::Domain() const noexcept { return target_node_.Domain(); }
+const std::string& NodeUnit::OpType() const noexcept { return target_node_.OpType(); }
+const std::string& NodeUnit::Name() const noexcept { return target_node_.Name(); }
+int NodeUnit::SinceVersion() const noexcept { return target_node_.SinceVersion(); }
+NodeIndex NodeUnit::Index() const noexcept { return target_node_.Index(); }
+const Path& NodeUnit::ModelPath() const noexcept { return target_node_.ModelPath(); }
+ProviderType NodeUnit::GetExecutionProviderType() const noexcept { return target_node_.GetExecutionProviderType(); }
+
+void NodeUnit::InitForSingleNode() {
+  const auto& input_defs = target_node_.InputDefs();
+  const auto& output_defs = target_node_.OutputDefs();
+  auto qlinear_type = GetQLinearOpType(target_node_);
+  if (qlinear_type == QLinearOpType::Unknown || IsVariadicQLinearOp(qlinear_type)) {  // TODO: add variadic support
+    // Not a QLinear op, add all inputs / outputs
+    auto add_all_io = [](std::vector<NodeUnitIODef>& defs,
+                         const ConstPointerContainer<std::vector<NodeArg*>>& node_defs) {
+      defs.reserve(node_defs.size());
+
+      for (const auto def : node_defs) {
+        defs.push_back(NodeUnitIODef{*def, std::nullopt});
+      }
+    };
+
+    add_all_io(inputs_, input_defs);
+    add_all_io(outputs_, output_defs);
+  } else if (IsUnaryQLinearOp(qlinear_type)) {
+    // Unary QLinear Op has 5 inputs
+    // x, x_scale, x_zp, y_scale, y_zp (optional)
+    inputs_.push_back(NodeUnitIODef{*input_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}});
+    outputs_.push_back(NodeUnitIODef{*output_defs[0],
+                                     NodeUnitIODef::QuantParam{*input_defs[3],
+                                                               input_defs.size() > 4 ? input_defs[4] : nullptr}});
+
+  } else if (IsBinaryQLinearOp(qlinear_type)) {
+    // Binary QLinear Op has up to 9 inputs
+    // x1, x1_scale, x1_zp, x2/w, x2_scale, x2_zp, y_scale, y_zp, B (optional)
+    inputs_.push_back(NodeUnitIODef{*input_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}});
+    inputs_.push_back(NodeUnitIODef{*input_defs[3], NodeUnitIODef::QuantParam{*input_defs[4], input_defs[5]}});
+
+    if (input_defs.size() == 9) {                                      // has Bias
+      inputs_.push_back(NodeUnitIODef{*input_defs[8], std::nullopt});  // for Bias the scale and zp are optional
+    }
+
+    outputs_.push_back(NodeUnitIODef{*output_defs[0], NodeUnitIODef::QuantParam{*input_defs[6], input_defs[7]}});
+
+  } else if (qlinear_type == QLinearOpType::DequantizeLinear) {
+    // DequantizeLinear has 2 or 3 inputs
+    // x, x_scale, x_zp (optional)
+    // output is not quantized
+    inputs_.push_back(NodeUnitIODef{*input_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs.size() == 3
+                                                                                                  ? input_defs[2]
+                                                                                                  : nullptr}});
+    outputs_.push_back(NodeUnitIODef{*output_defs[0], std::nullopt});
+
+  } else if (qlinear_type == QLinearOpType::QuantizeLinear) {
+    // QuantizeLinear has 2 or 3 inputs
+    // x, y_scale, y_zp (optional)
+    // The input is not quantized; the output is quantized
+    inputs_.push_back(NodeUnitIODef{*input_defs[0], std::nullopt});
+    outputs_.push_back(NodeUnitIODef{*output_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs.size() == 3
+                                                                                                    ? input_defs[2]
+                                                                                                    : nullptr}});
+  } else {
+    ORT_THROW("The QLinear op [", static_cast<uint8_t>(qlinear_type), "] is not supported");
+  }
+}
+
+Node::EdgeConstIterator NodeUnit::OutputEdgesBegin() const {
+  return (type_ == Type::SingleNode) ? target_node_.OutputEdgesBegin() : output_edges_.begin();
+}
+
+Node::EdgeConstIterator NodeUnit::OutputEdgesEnd() const {
+  return (type_ == Type::SingleNode) ? target_node_.OutputEdgesEnd() : output_edges_.end();
+}
+
+std::vector<const Node*> NodeUnit::GetAllNodesInGroup() const noexcept {
+  std::vector<const Node*> all_nodes = dq_nodes_;
+  all_nodes.push_back(&target_node_);
+  all_nodes.insert(all_nodes.end(), q_nodes_.begin(), q_nodes_.end());
+  return all_nodes;
+}
+
+}  // namespace onnxruntime
+
+#endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
diff --git a/onnxruntime/core/providers/shared/node_unit/node_unit.h b/onnxruntime/core/framework/node_unit.h
similarity index 54%
rename from onnxruntime/core/providers/shared/node_unit/node_unit.h
rename to onnxruntime/core/framework/node_unit.h
index b47204ca3c42d..66afaec8ee1e2 100644
--- a/onnxruntime/core/providers/shared/node_unit/node_unit.h
+++ b/onnxruntime/core/framework/node_unit.h
@@ -3,6 +3,9 @@
 
 #pragma once
 
+// QDQ models require graph modification at runtime, so we know this infrastructure is not used in a minimal build
+#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
+
 #include <string>
 #include <optional>
 #include <vector>
@@ -18,8 +21,21 @@ class NodeArg;
 class Path;
 
 namespace QDQ {
-struct NodeGroup;
-}
+// Struct to represent a DequantizeLinear -> Op -> QuantizeLinear node group
+struct NodeGroup {
+  std::vector<NodeIndex> dq_nodes;
+  std::vector<NodeIndex> q_nodes;
+  NodeIndex target_node;
+
+  // Validator to check if the set of nodes can form a valid QDQ NodeGroup.
+  // Checks that the target node is the only consumer of each DQ, and that the outputs remain valid if the QDQ node
+  // group were to be converted into a single node with a quantized operator.
+  static Status CanCreateNodeGroup(const GraphViewer& graph_viewer,
+                                   const Node& target_node,
+                                   gsl::span<const Node* const> dq_nodes,
+                                   gsl::span<const Node* const> q_nodes);
+};
+}  // namespace QDQ
 
 // Definition of one input or output
 // If the optional quant_param is present, then this is a quantized input,
@@ -69,26 +85,33 @@ class NodeUnit {
   const std::vector<const Node*>& GetQNodes() const noexcept { return q_nodes_; }
   std::vector<const Node*> GetAllNodesInGroup() const noexcept;
 
-  Node::EdgeConstIterator OutputEdgesBegin(size_t index) const;
-  Node::EdgeConstIterator OutputEdgesEnd(size_t index) const;
+  /// Number of input edges to the logical node. For a QDQ node this is the count of input edges to the DQ nodes
+  /// plus any other edges to the target node for inputs that are not via a DQ node.
+  size_t InputEdgeCount() const { return input_edge_count_; }
+
+  // Output edges. The src index refers to an output of the target node; the dst node and index identify the consumer
+  // of the node unit output. Any Q nodes are hidden.
+  Node::EdgeConstIterator OutputEdgesBegin() const;
+  Node::EdgeConstIterator OutputEdgesEnd() const;
 
  private:
-  const std::vector<const Node*> q_nodes_;   // q-nodes for this NodeUnit
-  const std::vector<const Node*> dq_nodes_;  // dq nodes for this NodeUnit, not all inputs
+  // Initialization for a NodeUnit that contains a single node
+  void InitForSingleNode();
+
+  const std::vector<const Node*> dq_nodes_;  // dq nodes for this NodeUnit, not necessarily all inputs
   const Node& target_node_;
+  const std::vector<const Node*> q_nodes_;  // q-nodes for this NodeUnit. not necessarily all outputs
   const Type type_;
 
   std::vector<NodeUnitIODef> inputs_;
   std::vector<NodeUnitIODef> outputs_;
 
-  // Initializing for a single Node
-  void InitForSingleNode();
-};
+  size_t input_edge_count_;  // total number of input edges
 
-// Get all the nodes in the given graph_viewer as NodeUnits (SingleNode or QDQGroup)
-// And return a map to quick query the NodeUnit which contains the given Node,
-// Note, the value of the map is owned by the vector of std::unique_ptr<NodeUnit>
-std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
-GetAllNodeUnits(const GraphViewer& graph_viewer);
+  // Output edges, hiding any Q nodes involved. src_idx is the output index of the target node. Only used for a QDQ node group.
+  Node::EdgeSet output_edges_;
+};
 
 }  // namespace onnxruntime
+
+#endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
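
A rough usage sketch for the NodeUnit API declared above (illustration only; it assumes an onnxruntime build, and the helper name DumpNodeUnit is hypothetical and not part of this change):

// Sketch only: walks the logical edges of a NodeUnit. Q/DQ nodes are hidden, so the
// consumers seen here are consumers of the whole node unit.
#include <cstdio>

#include "core/framework/node_unit.h"
#include "core/graph/graph_viewer.h"

namespace onnxruntime {

void DumpNodeUnit(const NodeUnit& node_unit) {
  // Logical input edge count: edges into the DQ nodes plus edges into the target node
  // for inputs that do not come via a DQ node.
  const size_t input_edges = node_unit.InputEdgeCount();

  // Output edges: the src index is the target node's output index, the dst node/index
  // is the consumer of the node unit output (any Q nodes in the group are skipped over).
  for (auto it = node_unit.OutputEdgesBegin(); it != node_unit.OutputEdgesEnd(); ++it) {
    const Node& consumer = it->GetNode();
    std::printf("%s output %d -> %s input %d\n",
                node_unit.Name().c_str(), it->GetSrcArgIndex(),
                consumer.Name().c_str(), it->GetDstArgIndex());
  }

  std::printf("%s (%s): %zu input edges, %zu nodes in group\n",
              node_unit.Name().c_str(), node_unit.OpType().c_str(),
              input_edges, node_unit.GetAllNodesInGroup().size());
}

}  // namespace onnxruntime
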
diff --git a/onnxruntime/core/framework/sequential_executor.cc b/onnxruntime/core/framework/sequential_executor.cc
index ea7f1397c961b..0cc7294a46495 100644
--- a/onnxruntime/core/framework/sequential_executor.cc
+++ b/onnxruntime/core/framework/sequential_executor.cc
@@ -306,18 +306,20 @@ class KernelScope {
 #endif
 
 #ifdef ENABLE_NVTX_PROFILE
-    auto& node = kernel_.Node();
-    profile::NvtxRangeCreator& forward_range = session_scope_.forward_range_;
-    profile::NvtxRangeCreator& backward_range = session_scope_.backward_range_;
-    if (node.Description() != "Backward pass" && !forward_range.IsBeginCalled()) {
-      // Start timing forward pass when encountering the first forward node.
-      forward_range.Begin();
-    } else if (node.Description() == "Backward pass" && !backward_range.IsBeginCalled() &&
-               forward_range.IsBeginCalled()) {
-      // Start timing backward pass when encountering the first backward node.
-      // In the meanwhile, forward range ends.
-      forward_range.End();
-      backward_range.Begin();
+    {
+      auto& node = kernel_.Node();
+      profile::NvtxRangeCreator& forward_range = session_scope_.forward_range_;
+      profile::NvtxRangeCreator& backward_range = session_scope_.backward_range_;
+      if (node.Description() != "Backward pass" && !forward_range.IsBeginCalled()) {
+        // Start timing forward pass when encountering the first forward node.
+        forward_range.Begin();
+      } else if (node.Description() == "Backward pass" && !backward_range.IsBeginCalled() &&
+                 forward_range.IsBeginCalled()) {
+        // Start timing backward pass when encountering the first backward node.
+        // In the meanwhile, forward range ends.
+        forward_range.End();
+        backward_range.Begin();
+      }
     }
 #endif
 
diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h
index 51bb02918d82f..a2ee1601d386b 100644
--- a/onnxruntime/core/framework/session_state.h
+++ b/onnxruntime/core/framework/session_state.h
@@ -8,7 +8,7 @@
 #include <unordered_map>
 #include <vector>
 
-#include "flatbuffers/flatbuffers.h"
+#include "core/common/flatbuffers.h"
 
 #include "core/common/gsl.h"
 
diff --git a/onnxruntime/core/framework/stream_execution_context.cc b/onnxruntime/core/framework/stream_execution_context.cc
index 875e7f395bfa8..dd7f4d35b34bd 100644
--- a/onnxruntime/core/framework/stream_execution_context.cc
+++ b/onnxruntime/core/framework/stream_execution_context.cc
@@ -181,11 +181,13 @@ void RunSince(size_t stream_idx, StreamExecutionContext& ctx, SessionScope& sess
   }
 
 #ifdef USE_CANN
+  // Leave it to the CANN EP to populate run_options if/when it needs them
+  static onnxruntime::RunOptions run_options;
   // For CANN EP, it is necessary to explicitly create a corresponding Context for each thread in the thread pool,
   // which is different from CUDA Runtime API, but similar to CUDA Driver API.
   auto& execution_providers = ctx.GetSessionState().GetExecutionProviders();
   for (auto& xp : execution_providers) {
-    auto status = xp->OnRunStart();
+    auto status = xp->OnRunStart(run_options);
     if (!status.IsOK()) {
       ctx.SetStatus(status);
       return;
diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc
index 23fe5e1cd3d96..b737d735b977b 100644
--- a/onnxruntime/core/framework/utils.cc
+++ b/onnxruntime/core/framework/utils.cc
@@ -1015,9 +1015,19 @@ bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index)
   }
 
 #ifdef ENABLE_ATEN
+  // For an ATen node, we assume that all tensor inputs are on device and all non-tensor inputs are on CPU,
+  // except those specified in the cpu_input_args attribute.
   if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" &&
       node.Domain() == kPytorchAtenDomain) {
     const auto& attrs = node.GetAttributes();
+    if (auto entry = attrs.find("cpu_input_args"); entry != attrs.end()) {
+      const auto& attr = entry->second;
+      if (utils::HasInts(attr) && std::any_of(attr.ints().cbegin(), attr.ints().cend(),
+                                              [index](int64_t arg) { return static_cast<int64_t>(index) == arg; })) {
+        return true;
+      }
+    }
+
     ORT_ENFORCE(utils::HasString(attrs.at("operator")));
     std::string op_name = attrs.at("operator").s();
     std::string overload_name = "";
@@ -1025,7 +1035,7 @@ bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index)
       overload_name = attrs.at("overload_name").s();
     }
 
-    return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, true);
+    return !contrib::aten_ops::ATenOperatorExecutor::Instance().IsTensorArgument(op_name, overload_name, index, true);
   }
 #else
   ORT_UNUSED_PARAMETER(node);
@@ -1040,9 +1050,19 @@ bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index
   }
 
 #ifdef ENABLE_ATEN
+  // For an ATen node, we assume that all tensor outputs are on device and all non-tensor outputs are on CPU,
+  // except those specified in the cpu_output_args attribute.
   if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" &&
       node.Domain() == kPytorchAtenDomain) {
     const auto& attrs = node.GetAttributes();
+    if (auto entry = attrs.find("cpu_output_args"); entry != attrs.end()) {
+      const auto& attr = entry->second;
+      if (utils::HasInts(attr) && std::any_of(attr.ints().cbegin(), attr.ints().cend(),
+                                              [index](int64_t arg) { return static_cast<int64_t>(index) == arg; })) {
+        return true;
+      }
+    }
+
     ORT_ENFORCE(utils::HasString(attrs.at("operator")));
     std::string op_name = attrs.at("operator").s();
     std::string overload_name = "";
@@ -1050,7 +1070,7 @@ bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index
       overload_name = attrs.at("overload_name").s();
     }
 
-    return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, false);
+    return !contrib::aten_ops::ATenOperatorExecutor::Instance().IsTensorArgument(op_name, overload_name, index, false);
   }
 #else
   ORT_UNUSED_PARAMETER(node);
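
For context on the cpu_input_args / cpu_output_args checks above, a minimal sketch of how such an attribute could be built with the ONNX protobuf API (illustration only; the argument indices are invented):

// Sketch only: builds a "cpu_input_args" attribute marking argument indices 1 and 3
// of an ATen node as CPU inputs. The indices are made up for illustration.
#include "onnx/onnx_pb.h"

ONNX_NAMESPACE::AttributeProto MakeCpuInputArgsAttr() {
  ONNX_NAMESPACE::AttributeProto attr;
  attr.set_name("cpu_input_args");
  attr.set_type(ONNX_NAMESPACE::AttributeProto::INTS);
  attr.add_ints(1);  // argument index 1 stays on CPU
  attr.add_ints(3);  // argument index 3 stays on CPU
  return attr;
}
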
diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc
index 4aa43f5de1cd5..a0ca2e45f153a 100644
--- a/onnxruntime/core/graph/contrib_ops/collective_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc
@@ -91,10 +91,18 @@ void RegisterCollectiveOps() {
             "Number of top experts to select from expert pool",
             AttributeProto::INT,
             static_cast<int64_t>(1))
+      .Attr("normalize_routing_weights",
+            "Whether to normalize routing weights",
+            AttributeProto::INT,
+            static_cast<int64_t>(0))
       .Attr("local_experts_start_index",
             "The start index of local experts",
             AttributeProto::INT,
-            static_cast<int64_t>(-1))
+            static_cast<int64_t>(0))
+      .Attr("tensor_shards",
+            "Tensor parallelism config. The number of shards for each expert weight and bias",
+            AttributeProto::INT,
+            static_cast<int64_t>(1))
       .Input(0,
              "input",
              "2D input tensor with shape (num_rows, hidden_size) or "
@@ -106,22 +114,32 @@ void RegisterCollectiveOps() {
              "T")
       .Input(2,
              "fc1_experts_weights",
-             "3D input tensor with shape (local_num_experts, hidden_size, inter_size)",
+             "3D input tensor with shape (local_num_experts, hidden_size, local_inter_size)",
              "T")
       .Input(3,
-             "fc2_experts_weights",
-             "3D input tensor with shape (local_num_experts, inter_size, hidden_size)",
-             "T")
-      .Input(4,
              "fc1_experts_bias",
-             "2D optional input tensor with shape (local_num_experts, inter_size)",
+             "2D optional input tensor with shape (local_num_experts, local_inter_size)",
              "T",
              OpSchema::Optional)
+      .Input(4,
+             "fc2_experts_weights",
+             "3D input tensor with shape (local_num_experts, local_inter_size, hidden_size)",
+             "T")
       .Input(5,
              "fc2_experts_bias",
              "2D optional input tensor with shape (num_experts, hidden_size)",
              "T",
              OpSchema::Optional)
+      .Input(6,
+             "fc3_experts_weights",
+             "3D optional input tensor with shape (local_num_experts, hidden_size, local_inter_size)",
+             "T",
+             OpSchema::Optional)
+      .Input(7,
+             "fc3_experts_bias",
+             "2D optional input tensor with shape (local_num_experts, local_inter_size)",
+             "T",
+             OpSchema::Optional)
       .Output(0,
               "output",
               "2D input tensor with shape (num_rows, hidden_size) or "
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 27c968a59eb91..82cc16acad582 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1163,7 +1163,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1,
                                        "Shape is (1,)",
                                        "T", OpSchema::Optional)
                                 .Input(6, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional)
-                                .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional)
+                                .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "M", OpSchema::Optional)
                                 .Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional)
                                 .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional)
                                 .Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional)
@@ -1188,7 +1188,15 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1,
                                 .SetDoc("Beam Search for whisper model, especiall with cross_qk features etc.")
                                 .Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT)
                                 .Attr("pad_token_id", "The id of the padding token", AttributeProto::INT)
-                                .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts.", AttributeProto::INT, static_cast<int64_t>(-1))
+                                .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts (i.e. the start of transcription token id)", AttributeProto::INT, static_cast<int64_t>(-1))
+                                .Attr("translate_token_id", "The id of the translate task", AttributeProto::INT, OPTIONAL_VALUE)
+                                .Attr("transcribe_token_id", "The id of the transcribe task", AttributeProto::INT, OPTIONAL_VALUE)
+                                .Attr("start_of_lm_token_id", "The id of the token that indicates LM starts", AttributeProto::INT, OPTIONAL_VALUE)
+                                .Attr("no_speech_token_id",
+                                      "The id of the token in the whisper model that marks the sequence as containing no speech. When provided, whisper can also output no_speech_prob. Default -1.",
+                                      AttributeProto::INT, OPTIONAL_VALUE)
+                                .Attr("no_timestamps_token_id", "The id of the token that indicates no timestamps", AttributeProto::INT, OPTIONAL_VALUE)
+                                .Attr("beginning_timestamp_token_id", "The id of the first timestamp", AttributeProto::INT, OPTIONAL_VALUE)
                                 .Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast<int64_t>(0))
                                 .Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast<int64_t>(0))
                                 .Attr("model_type", "Must be 2 for whisper", AttributeProto::INT, static_cast<int64_t>(2))
@@ -1203,27 +1211,24 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1,
                                       "If not provided, it will be inferred from the decoder subgraph's output shape",
                                       AttributeProto::INT, static_cast<int64_t>(-1))
                                 .Attr("decoder_output_cross_qk", "If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.", AttributeProto::INT, OPTIONAL_VALUE)
-                                .Attr("no_speech_token",
-                                      "The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.",
-                                      AttributeProto::INT, OPTIONAL_VALUE)
                                 .Input(0, "input_ids", "The sequence used as a prompt for the generation in the encoder subgraph. Shape is (batch_size, sequence_length)", "F")
                                 .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I")
                                 .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional)
                                 .Input(3, "num_beams", "Number of beams for beam search. 1 means no beam search. Shape is (1)", "I")
                                 .Input(4, "num_return_sequences", "The number of returned sequences in the batch. Shape is (1)", "I")
                                 .Input(5, "length_penalty",
-                                       "Exponential penalty to the length. Default value 1.0 means no penalty."
-                                       "Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences."
+                                       "Exponential penalty to the length. Default value 1.0 means no penalty. "
+                                       "Value > 1.0 encourages longer sequences, while values < 1.0 produce shorter sequences. "
                                        "Shape is (1,)",
                                        "T", OpSchema::Optional)
                                 .Input(6, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional)
-                                .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional)
+                                .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "M", OpSchema::Optional)
                                 .Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional)
                                 .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional)
                                 .Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional)
                                 .Input(11, "logits_processor", "Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)", "I", OpSchema::Optional)
                                 .Input(12, "cross_qk_layer_head",
-                                       "Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all"
+                                       "Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all "
                                        "its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]",
                                        "I", OpSchema::Optional)
                                 .Input(13, "extra_decoding_ids",
@@ -1235,20 +1240,19 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1,
                                 .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I")
                                 .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional)
                                 .Output(2, "scores",
-                                        "Processed beam scores for each vocabulary token at each generation step."
-                                        "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam."
+                                        "Processed beam scores for each vocabulary token at each generation step. "
+                                        "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam. "
                                         "Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)",
                                         "T", OpSchema::Optional)
                                 .Output(3, "cross_qk",
                                         "Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, "
-                                        "F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers,"
-                                        "B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]."
+                                        "F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers, "
+                                        "B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]. "
                                         "If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]",
                                         "V", OpSchema::Optional)
                                 .Output(4, "non_speech_probs",
-                                        "For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token."
-                                        "Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph."
-                                        "The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]",
+                                        "For the whisper model, output the probabilities from the logits after encoder and context decoding for the no_speech_token_id. "
+                                        "The shape of non_speech_probs is [B]",
                                         "T", OpSchema::Optional)
                                 .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain to float tensors.")
                                 .TypeConstraint("F", {"tensor(float)", "tensor(int32)", "tensor(float16)"}, "Constrain input type to float or int tensors.")
@@ -1322,7 +1326,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GreedySearch, 1,
                                 .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I")
                                 .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional)
                                 .Input(3, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional)
-                                .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "I", OpSchema::Optional)
+                                .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "I", OpSchema::Optional)
                                 .Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional)
                                 .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional)
                                 .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)", "I")
@@ -1363,7 +1367,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1,
                                 .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I")
                                 .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional)
                                 .Input(3, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional)
-                                .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "I", OpSchema::Optional)
+                                .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "I", OpSchema::Optional)
                                 .Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional)
                                 .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional)
                                 .Input(7, "presence_mask", "Presence penalty mask. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional)
@@ -1378,8 +1382,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1,
 
 constexpr const char* MoE_ver1_doc = R"DOC(
       Mixture of experts. Examples: Switch transformer(https://arxiv.org/pdf/2101.03961.pdf) use top 1,
-      GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, and Vision MOE(https://arxiv.org/pdf/2106.05974.pdf)
-      usually uses top 32 experts.
+      GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, Vision MOE(https://arxiv.org/pdf/2106.05974.pdf)
+      usually uses top 32 experts, and Mixtral(https://huggingface.co/blog/mixtral) activates top 2 experts.
       )DOC";
 
 ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1,
@@ -1387,12 +1391,15 @@ ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1,
                                 .SetDoc(MoE_ver1_doc)
                                 .Attr("activation_type", "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu", AttributeProto::STRING, std::string("relu"))
                                 .Attr("k", "Number of top experts to select from expert pool", AttributeProto::INT, static_cast<int64_t>(1))
+                                .Attr("normalize_routing_weights", "Whether to normalize routing weights", AttributeProto::INT, static_cast<int64_t>(0))
                                 .Input(0, "input", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T")
                                 .Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T")
                                 .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size)", "T")
-                                .Input(3, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size)", "T")
-                                .Input(4, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional)
+                                .Input(3, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional)
+                                .Input(4, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size)", "T")
                                 .Input(5, "fc2_experts_bias", "2D optional input tensor with shape (num_experts, hidden_size)", "T", OpSchema::Optional)
+                                .Input(6, "fc3_experts_weights", "3D optional input tensor with shape (num_experts, hidden_size, inter_size)", "T", OpSchema::Optional)
+                                .Input(7, "fc3_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional)
                                 .Output(0, "output", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T")
                                 .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or float16 tensors.")
                                 .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
@@ -3339,22 +3346,23 @@ MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7
      And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,..
   3. Input B's scale and zero point are specified by input scales and zero_points.
 
-Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
-- n_blocks_per_col = (K + block_size - 1) / block_size
-- blob_size = block_size / 8 * bits
+  Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
+  - n_blocks_per_col = (K + block_size - 1) / block_size
+  - blob_size = CeilDiv(block_size * bits, 8), where 8 is the number of bits in a uint8_t
+  For all bit widths from 2-8, a row of data is stored tightly packed into uint8_t values.
+    - For 2, 4 and 8 bits, 4x2bit, 2x4bit and 1x8bit values are stored in one uint8_t.
+        4bit example:
+        |.|.|.|.| |.|.|.|.| = uint8_t (2x4bit)
+    - For 3, 5, 6 and 7 bits, 32x3bit, 32x5bit, 16x6bit and 32x7bit values are stored in 12, 20, 12 and 28 uint8_t respectively, with no bits wasted.
+        3bit example:
+        |.|.|.| |.|.|.| |.|.|.| = 9 bits, which span 2 uint8_t values; the highest bit of the second uint8_t is used.
+  The last uint8_t may have some bits unused.
 
-  For a block blob. It is stored in format:
-  struct Blob {
-    uint8 one_bits[(bits & 0x1) * 1 * block_size / 8];  // highest 1 bit for 3, 5, 7 bits quantization
-    uint8 two_bits[(bits & 0x2) * 2 * block_size / 8];  // high 2 bits for 2, 6, 7 bits quantization
-    uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization
-  }
 
 Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col]
-Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is:
-  - [(N * n_blocks_per_col + 1) / 2] if bits <=4
-  - [N * n_blocks_per_col] if bits > 4
-
+Input zero_points is stored as uint8_t or as the same type as input A. When stored as uint8_t it uses the same packing method as input B, with shape:
+  - [CeilDiv((N * n_blocks_per_col + 1) * bits, 8)]
+  If zero_points has the same type as A, it is not packed and has the same shape as the scales input.
 )DOC";
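
To make the packing arithmetic described in the MatMulNBits documentation above concrete, here is a small standalone sketch for bits = 4 (illustration only; the low-nibble-first ordering is an assumption, not necessarily the exact bit order used by the kernels):

// Sketch only: MatMulNBits packing arithmetic for bits = 4, block_size = 32.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t K = 100, bits = 4, block_size = 32;
  const int64_t n_blocks_per_col = (K + block_size - 1) / block_size;  // = 4
  const int64_t blob_size = (block_size * bits + 7) / 8;               // CeilDiv(32 * 4, 8) = 16
  std::printf("n_blocks_per_col=%lld blob_size=%lld\n",
              static_cast<long long>(n_blocks_per_col), static_cast<long long>(blob_size));

  // Pack two 4-bit values per uint8_t (low nibble first; assumed ordering for illustration).
  const std::vector<uint8_t> quantized = {1, 7, 3, 15};  // 4-bit values in [0, 15]
  std::vector<uint8_t> packed((quantized.size() + 1) / 2, 0);
  for (size_t i = 0; i < quantized.size(); ++i) {
    packed[i / 2] |= static_cast<uint8_t>((quantized[i] & 0xF) << ((i % 2) * 4));
  }
  std::printf("packed: 0x%02X 0x%02X\n",
              static_cast<unsigned>(packed[0]), static_cast<unsigned>(packed[1]));  // 0x71 0xF3
  return 0;
}

For bits = 3 the same formula gives blob_size = CeilDiv(32 * 3, 8) = 12, matching the 32x3bit in 12 uint8_t case above.
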
 
   ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBits)
@@ -3373,12 +3381,15 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored
             "type T1.",
             AttributeProto::INT, static_cast<int64_t>(0))
       .Input(0, "A", "The input tensor, not quantized", "T1")
-      .Input(1, "B", "1-dimensional data blob", "T2")
+      .Input(1, "B", "1 or 2 dimensional data blob", "T2")
       .Input(2, "scales", "quantization scale", "T1")
-      .Input(3, "zero_points", "quantization zero points", "T2", OpSchema::Optional)
+      .Input(3, "zero_points", "quantization zero points", "T3", OpSchema::Optional)
+      .Input(4, "g_idx", "group_idx", "T4", OpSchema::Optional)
       .Output(0, "Y", "tensor. The output tensor has the same rank as the input. ", "T1")
       .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.")
-      .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.")
+      .TypeConstraint("T2", {"tensor(uint8)", "tensor(int32)"}, "Constrain quantized weight types to uint8/int32.")
+      .TypeConstraint("T3", {"tensor(uint8)", "tensor(int32)", "tensor(float16)", "tensor(float)"}, "Constrain quantized zero point types to uint8/int32/float16/float.")
+      .TypeConstraint("T4", {"tensor(int32)"}, "the index tensor.")
       .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
         // Type inference
         propagateElemTypeFromInputToOutput(ctx, 0, 0);
@@ -3466,6 +3477,8 @@ MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4
               /*min_arity*/ 1)
       .Attr("operator", "Name of ATen operator.", AttributeProto::STRING)
       .Attr("overload_name", "Overload name of ATen operator.", AttributeProto::STRING, false)
+      .Attr("cpu_input_args", "CPU input argument indices.", AttributeProto::INTS, false)
+      .Attr("cpu_output_args", "CPU output argument indices.", AttributeProto::INTS, false)
       .TypeConstraint("T", OpSchema::all_tensor_types_ir4(),
                       "Allow inputs and outputs to be any kind of tensor.");
 #endif
diff --git a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc
index c8960578f9e3d..6bf19654a3ce9 100644
--- a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc
+++ b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc
@@ -106,6 +106,7 @@ void OpSet_Internal_NHWC_ONNX::ForEachSchema(const std::function<void(ONNX_NAMES
   REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, BatchNormalization, 14);
   REGISTER_NHWC_SCHEMA_WITH_ACTIVATION(fn, BatchNormalization, 15);
 
+  REGISTER_NHWC_SCHEMA(fn, DepthToSpace, 1);
   REGISTER_NHWC_SCHEMA(fn, DepthToSpace, 11);
   REGISTER_NHWC_SCHEMA(fn, DepthToSpace, 13);
 
diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
index 4313fae767fe5..22a79ef652515 100644
--- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
@@ -434,7 +434,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
         .Output(0, "Y", "Matrix multiply results from A * B", "T3")
         .TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)"}, "Constrain input A data type to 8-bit integer tensor.")
         .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)"}, "Constrain input B data type to 8-bit integer tensor.")
-        .TypeConstraint("T3", {"tensor(float)"},
+        .TypeConstraint("T3", {"tensor(float)", "tensor(float16)"},
                         "Constrain input a_scale, b_scale and output Y data type as float tensor.")
         .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
           propagateElemTypeFromInputToOutput(ctx, 2, 0);
diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc
index 902839bee04ba..305122c56b865 100644
--- a/onnxruntime/core/graph/graph.cc
+++ b/onnxruntime/core/graph/graph.cc
@@ -1818,16 +1818,36 @@ void Graph::ReverseDFSFrom(gsl::span<const Node* const> from,
   }
 }
 
+template <typename T>
+struct VisitorPriorityQueue {
+  using ComparatorType = std::function<bool(T, T)>;
+  std::list<T> list_;
+  const ComparatorType comparator_ = nullptr;
+  VisitorPriorityQueue(const ComparatorType& comp) : comparator_(comp) {}
+
+  void push(T node) {
+    list_.insert(
+        std::upper_bound(list_.begin(), list_.end(), node, comparator_),
+        node);
+  }
+  bool empty() { return list_.empty(); }
+  T top() { return list_.back(); }
+  void pop() { list_.pop_back(); }
+};
+
 #if !defined(ORT_MINIMAL_BUILD)
 void Graph::KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
                                  const std::function<bool(const Node*, const Node*)>& comp) const {
-  std::unordered_map<NodeIndex, size_t> in_degree;
-  std::priority_queue<const Node*, std::vector<const Node*>, decltype(comp)> to_visit(comp);
-  std::vector<NodeIndex> topo_order;
+  InlinedVector<size_t> in_degree(MaxNodeIndex(), 0);
+  InlinedVector<NodeIndex> topo_order;
+  VisitorPriorityQueue<const Node*> to_visit(comp);
+
+  auto number_of_nodes = NumberOfNodes();
+  topo_order.reserve(number_of_nodes);
 
   for (auto& node : Nodes()) {
     size_t input_edge_count = node.GetInputEdgesCount();
-    in_degree.insert({node.Index(), input_edge_count});
+    in_degree[node.Index()] = input_edge_count;
     if (input_edge_count == 0) {
       to_visit.push(&node);
     }
@@ -1844,16 +1864,17 @@ void Graph::KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
     }
 
     for (auto node_it = current->OutputNodesBegin(); node_it != current->OutputNodesEnd(); ++node_it) {
-      in_degree[node_it->Index()]--;
+      auto& node_in_degree = in_degree[node_it->Index()];
+      node_in_degree--;
 
-      if (in_degree[node_it->Index()] == 0) {
+      if (node_in_degree == 0) {
         to_visit.push(&*node_it);
       }
     }
     topo_order.push_back(current->Index());
   }
 
-  if (NumberOfNodes() != static_cast<int>(topo_order.size())) {
+  if (number_of_nodes != static_cast<int>(topo_order.size())) {
     ORT_THROW("Some nodes are not included in the topological sort, graph have a cycle.");
   }
 }
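
A minimal standalone sketch of the same approach (Kahn's algorithm with a comparator-ordered ready queue, mirroring VisitorPriorityQueue above); the tiny graph and the index comparator are invented for illustration:

// Sketch only: Kahn's topological sort over a small adjacency list.
#include <algorithm>
#include <cstdio>
#include <list>
#include <vector>

int main() {
  // adj[i] lists the nodes that depend on node i (edges i -> adj[i][k]).
  std::vector<std::vector<int>> adj = {{1, 2}, {3}, {3}, {}};
  std::vector<int> in_degree(adj.size(), 0);
  for (const auto& edges : adj) {
    for (int dst : edges) ++in_degree[dst];
  }

  // priority_queue-style comparator: comp(a, b) == true means b is visited first.
  // With this comparator, larger indices are visited first among the ready nodes.
  auto comp = [](int a, int b) { return a < b; };
  std::list<int> to_visit;  // kept sorted; back() is popped first, as in VisitorPriorityQueue
  auto push = [&](int n) {
    to_visit.insert(std::upper_bound(to_visit.begin(), to_visit.end(), n, comp), n);
  };

  for (size_t i = 0; i < adj.size(); ++i) {
    if (in_degree[i] == 0) push(static_cast<int>(i));
  }

  std::vector<int> order;
  while (!to_visit.empty()) {
    const int cur = to_visit.back();
    to_visit.pop_back();
    order.push_back(cur);
    for (int dst : adj[cur]) {
      if (--in_degree[dst] == 0) push(dst);
    }
  }

  for (int n : order) std::printf("%d ", n);  // prints: 0 2 1 3
  std::printf("\n");
  return 0;
}
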
@@ -2843,7 +2864,7 @@ void Graph::AddInitializedTensor(const TensorProto& tensor) {
 
   const gsl::not_null<TensorProto*> tensor_added{graph_proto_->add_initializer()};
   *(tensor_added) = tensor;
-  name_to_initial_tensor_[tensor.name()] = tensor_added;
+  name_to_initial_tensor_.emplace(tensor.name(), tensor_added);
   SetGraphResolveNeeded();
   if (!is_loaded_from_model_file_ && GetNodeArg(tensor.name()) == nullptr) {
     // make sure there is a NodeArg for the initializer as SetGraphInputsOutputs may add it to the graph inputs.
diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.cc b/onnxruntime/core/graph/graph_flatbuffers_utils.cc
index 6d7ed94b2956d..2314a5228f83c 100644
--- a/onnxruntime/core/graph/graph_flatbuffers_utils.cc
+++ b/onnxruntime/core/graph/graph_flatbuffers_utils.cc
@@ -3,7 +3,7 @@
 
 #include "graph_flatbuffers_utils.h"
 
-#include "flatbuffers/flatbuffers.h"
+#include "core/common/flatbuffers.h"
 
 #include "core/common/narrow.h"
 #include "core/flatbuffers/flatbuffers_utils.h"
diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.h b/onnxruntime/core/graph/graph_flatbuffers_utils.h
index b625cbf3ca492..9c55dad3c41ef 100644
--- a/onnxruntime/core/graph/graph_flatbuffers_utils.h
+++ b/onnxruntime/core/graph/graph_flatbuffers_utils.h
@@ -5,7 +5,7 @@
 
 #include <memory>
 
-#include "flatbuffers/flatbuffers.h"
+#include "core/common/flatbuffers.h"
 
 #include "core/common/status.h"
 #include "core/graph/ort_format_load_options.h"
diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc
index acf7b3a16541f..119d420066a84 100644
--- a/onnxruntime/core/graph/graph_viewer.cc
+++ b/onnxruntime/core/graph/graph_viewer.cc
@@ -14,8 +14,8 @@ bool NodeCompare::operator()(const Node* n1, const Node* n2) const {
 struct PriorityNodeCompare {
   inline bool IsHighPri(const Node* n) const {
     // local statics so we can compare std::strings in the checks
-    static const std::string shape_op("Shape");
-    static const std::string size_op("Size");
+    static constexpr std::string_view shape_op("Shape");
+    static constexpr std::string_view size_op("Size");
 
     const auto& op_type = n->OpType();
     return op_type == shape_op || op_type == size_op;
@@ -26,15 +26,20 @@ struct PriorityNodeCompare {
   // If return true, n2 will be output first
   bool operator()(const Node* n1, const Node* n2) const {
     // nodes in global high priority list will be output first
-    if (IsHighPri(n1) != IsHighPri(n2)) {
-      return IsHighPri(n2);
+    const bool isN1HighPri = IsHighPri(n1);
+    const bool isN2HighPri = IsHighPri(n2);
+    if (isN1HighPri != isN2HighPri) {
+      return isN2HighPri;
     }
 
     // nodes with lower priority value will be output first
-    if (n1->Priority() != n2->Priority()) {
-      return n1->Priority() > n2->Priority();
+    const auto n1_priority = n1->Priority();
+    const auto n2_priority = n2->Priority();
+    if (n1_priority != n2_priority) {
+      return n1_priority > n2_priority;
     }
 
+#ifdef ENABLE_TRAINING
     // nodes of forward pass will be output first
     auto n1_attrs = n1->GetAttributes();
     auto n2_attrs = n2->GetAttributes();
@@ -45,6 +50,7 @@ struct PriorityNodeCompare {
     if (n1_is_forward != n2_is_forward) {
       return n2_is_forward > n1_is_forward;
     }
+#endif
 
     // otherwise, nodes with lower index will be output first
     return n1->Index() > n2->Index();
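
As a rough illustration of the comparator semantics used by PriorityNodeCompare (returning true from comp(n1, n2) means n2 is output first), the standalone sketch below orders a few op-type strings with std::priority_queue; the op names and the lexicographic tie-break are invented for illustration:

// Sketch only: Shape/Size pop first, then a simple tie-break.
#include <cstdio>
#include <queue>
#include <string>
#include <vector>

int main() {
  auto comp = [](const std::string& a, const std::string& b) {
    auto high_pri = [](const std::string& s) { return s == "Shape" || s == "Size"; };
    if (high_pri(a) != high_pri(b)) return high_pri(b);  // Shape/Size pop first
    return a > b;  // otherwise smaller strings pop first (mirrors "lower priority value first")
  };
  std::priority_queue<std::string, std::vector<std::string>, decltype(comp)> to_visit(comp);
  for (const char* op : {"MatMul", "Shape", "Add", "Size"}) to_visit.push(op);
  while (!to_visit.empty()) {
    std::printf("%s ", to_visit.top().c_str());
    to_visit.pop();
  }
  std::printf("\n");  // prints: Shape Size Add MatMul
  return 0;
}
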
diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h
index 4ce6660b794bc..a774d5fe34461 100644
--- a/onnxruntime/core/graph/model.h
+++ b/onnxruntime/core/graph/model.h
@@ -8,7 +8,7 @@
 #include <climits>
 #include <string>
 
-#include "flatbuffers/flatbuffers.h"
+#include "core/common/flatbuffers.h"
 
 #include "core/common/path.h"
 #include "core/graph/graph_viewer.h"
diff --git a/onnxruntime/core/graph/op_identifier_utils.h b/onnxruntime/core/graph/op_identifier_utils.h
index 8a9351a2d0ddc..f7b1198c31972 100644
--- a/onnxruntime/core/graph/op_identifier_utils.h
+++ b/onnxruntime/core/graph/op_identifier_utils.h
@@ -3,7 +3,7 @@
 
 #pragma once
 
-#include "flatbuffers/flatbuffers.h"
+#include "core/common/flatbuffers.h"
 
 #include "core/graph/op_identifier.h"
 
diff --git a/onnxruntime/core/graph/runtime_optimization_record_container.h b/onnxruntime/core/graph/runtime_optimization_record_container.h
index a28b19e786de0..75750c2b96987 100644
--- a/onnxruntime/core/graph/runtime_optimization_record_container.h
+++ b/onnxruntime/core/graph/runtime_optimization_record_container.h
@@ -9,7 +9,7 @@
 #include <unordered_map>
 #include <vector>
 
-#include "flatbuffers/flatbuffers.h"
+#include "core/common/flatbuffers.h"
 
 #include "core/common/common.h"
 #include "core/graph/runtime_optimization_record.h"
diff --git a/onnxruntime/core/mickey/README.md b/onnxruntime/core/mickey/README.md
index 7e8d30cd1805b..735ec4b80daf3 100644
--- a/onnxruntime/core/mickey/README.md
+++ b/onnxruntime/core/mickey/README.md
@@ -4,3 +4,7 @@ Playful name for a template library of high performance cuda code that
 are often shared by various AI operators. The intention is to make this
 header files only, with no binary impact unless it is instantiated
 where it is needed.
+
+Currently, CUDA code is scattered across multiple locations in the
+repo. Hopefully this can become the starting point for consolidating
+all CUDA code.
diff --git a/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h b/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h
new file mode 100644
index 0000000000000..52bff7e40dbe3
--- /dev/null
+++ b/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h
@@ -0,0 +1,208 @@
+/**
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License.
+ *
+ * Module Name:
+ *   blk_q4/f16_gemm_sm80.h
+ *
+ * Abstract:
+ *   Entry point for Q4F16 GEMM kernel for SM80 devices.
+ */
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass_ext/q4gemm/device/quantb_gemm.h"
+
+namespace onnxruntime {
+namespace cuda {
+
+//
+// Implementation of the quantized GEMM kernel: 16b float activations x block-wise quantized 4b weights
+//
+template <
+    typename ElementDequant_,  // <- data type of dequantized elements for gemm, fp16 or bf16
+    typename QuantBlocking_,   // <- weights block per scale, cutlass::MatrixShape<x,y>
+    bool SmallM,               // <- true if M <= 16
+    bool kHasQuantOffset>
+struct BlkQ4F16GemmImpl {
+  //
+  // Type definitions
+  //
+
+  using ElementDequant = ElementDequant_;
+  using QuantBlocking = QuantBlocking_;
+
+  static_assert(sizeof(ElementDequant) == 2, "q4f16gemm kernel only supports 16b operands!");
+
+  // Data types that are fixed for this kernel
+  using ElementAccumulator = float;
+  using ElementComputeEpilogue = ElementAccumulator;
+  using ElementInputA = ElementDequant;
+  using ElementOutput = ElementDequant;
+
+  using ElementW = uint8_t;  // <- Weight is int4, uint8 for two of them
+
+  // We pack 4 weights into one 16b element so that we can leverage cutlass tile iterators
+  // for async shared memory loading and minimize bank conflicts
+  using ElementWPack = ElementDequant;
+
+  using ElementQScale = ElementDequant;  // <- data type of quantization scale
+  using ElementQOffset = uint8_t;
+
+  using LayoutInputA = cutlass::layout::RowMajor;
+  using LayoutInputWPack = cutlass::layout::ColumnMajor;
+  using LayoutOutput = cutlass::layout::RowMajor;
+
+  // Layout of quantization scale and offset, oriented to be loaded using fewer instructions
+  // in a warp tile
+  using LayoutInputQScale =
+      typename std::conditional<QuantBlocking::kRow == 1,
+                                cutlass::layout::ColumnMajor,
+                                cutlass::layout::RowMajor>::type;  // <- layout of quantization scale
+
+  using ShapeMMAThreadBlock =
+      typename std::conditional<SmallM,
+                                cutlass::gemm::GemmShape<16, 64, 64>,
+                                cutlass::gemm::GemmShape<128, 256, 64>>::type;
+
+  static constexpr int MinN = QuantBlocking::kColumn > 32 ? QuantBlocking::kColumn : 32;
+  using ShapeMMAWarp =
+      typename std::conditional<SmallM,
+                                cutlass::gemm::GemmShape<16, MinN, 64>,
+                                cutlass::gemm::GemmShape<64, 64, 64>>::type;
+
+  using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>;
+
+  // This code section describes how threadblocks are scheduled on GPU
+  using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;  // <- default identity swizzle
+
+  // This code section describes the epilogue part of the kernel
+  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
+      ElementOutput,                                     // <- data type of output matrix
+      128 / cutlass::sizeof_bits<ElementOutput>::value,  // <- the number of elements per vectorized
+                                                         // memory access (8 for 16b output elements).
+                                                         // This also becomes the vector width of
+                                                         // math instructions in the epilogue
+      ElementAccumulator,                                // <- data type of accumulator
+      ElementComputeEpilogue>;                           // <- data type for alpha/beta in linear combination function
+
+  // Number of pipelines you want to use
+  static constexpr int NumStages = 3;
+
+  using Gemm = cutlass::gemm::device::QuantBGemm<
+      ElementInputA,
+      LayoutInputA,
+      ElementWPack,
+      LayoutInputWPack,
+      ElementQScale,
+      typename std::conditional<kHasQuantOffset, ElementQOffset, std::monostate>::type,
+      LayoutInputQScale,
+      QuantBlocking,
+      ElementOutput,
+      LayoutOutput,
+      ElementAccumulator,
+      cutlass::arch::OpClassTensorOp,
+      cutlass::arch::Sm80,
+      ShapeMMAThreadBlock,
+      ShapeMMAWarp,
+      ShapeMMAOp,
+      EpilogueOp,
+      SwizzleThreadBlock,
+      NumStages>;
+
+  using Arguments = typename Gemm::Arguments;
+
+  // Invoke gemm kernel (the version with quantization offset)
+  static cutlass::Status run(
+      cudaStream_t stream,
+      const cutlass::gemm::GemmCoord& problem_size_,
+      cutlass::TensorRef<ElementInputA const, LayoutInputA> ref_A_,
+      cutlass::TensorRef<ElementWPack const, LayoutInputWPack> ref_B_,
+      cutlass::TensorRef<ElementQScale const, LayoutInputQScale> ref_Qscale_,
+      cutlass::TensorRef<ElementQOffset const, LayoutInputQScale> ref_Qoffset_,
+      cutlass::TensorRef<ElementOutput const, LayoutOutput> ref_C_,
+      cutlass::TensorRef<ElementOutput, LayoutOutput> ref_D_,
+      typename EpilogueOp::Params epilogue_ = typename EpilogueOp::Params()) {
+    if constexpr (!kHasQuantOffset) {
+      return cutlass::Status::kErrorNotSupported;
+    } else {
+      if constexpr (ShapeMMAThreadBlock::kM == 16) {
+        if (problem_size_.m() > 16) {
+          // For M > 16, the caller should have picked the
+          // kernel with bigger M
+          return cutlass::Status::kErrorNotSupported;
+        }
+      }
+
+      // Construct Gemm arguments
+      Arguments args{
+          problem_size_,
+          ref_A_,
+          ref_B_,
+          ref_Qscale_,
+          ref_Qoffset_,
+          ref_C_,
+          ref_D_,
+          epilogue_};
+
+      Gemm gemm_op;
+
+      // Check if this GEMM can be run or not
+      cutlass::Status status = gemm_op.can_implement(args);
+      if (status != cutlass::Status::kSuccess) {
+        return status;
+      }
+
+      // Launch the CUTLASS GEMM kernel.
+      return gemm_op(args, nullptr, stream);
+    }
+  }
+
+  // Invoke gemm kernel (the version without quantization offset)
+  static cutlass::Status run(
+      cudaStream_t stream,
+      const cutlass::gemm::GemmCoord& problem_size_,
+      cutlass::TensorRef<ElementInputA const, LayoutInputA> ref_A_,
+      cutlass::TensorRef<ElementWPack const, LayoutInputWPack> ref_B_,
+      cutlass::TensorRef<ElementQScale const, LayoutInputQScale> ref_Qscale_,
+      cutlass::TensorRef<ElementOutput const, LayoutOutput> ref_C_,
+      cutlass::TensorRef<ElementOutput, LayoutOutput> ref_D_,
+      typename EpilogueOp::Params epilogue_ = typename EpilogueOp::Params()) {
+    if constexpr (kHasQuantOffset) {
+      return cutlass::Status::kErrorNotSupported;
+    } else {
+      if constexpr (ShapeMMAThreadBlock::kM == 16) {
+        if (problem_size_.m() > 16) {
+          // For M > 16, the caller should have picked the
+          // kernel with bigger M
+          return cutlass::Status::kErrorNotSupported;
+        }
+      }
+
+      // Construct Gemm arguments
+      Arguments args{
+          problem_size_,
+          ref_A_,
+          ref_B_,
+          ref_Qscale_,
+          ref_C_,
+          ref_D_,
+          epilogue_};
+
+      Gemm gemm_op;
+
+      // Check if this GEMM can be run or not
+      cutlass::Status status = gemm_op.can_implement(args);
+      if (status != cutlass::Status::kSuccess) {
+        return status;
+      }
+
+      // Launch the CUTLASS GEMM kernel.
+      return gemm_op(args, nullptr, stream);
+    }
+  }
+};
+
+}  // namespace cuda
+}  // namespace onnxruntime
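
For orientation, here is a minimal host-side sketch of how this entry point might be invoked. It assumes row-major fp16 A/C/D, weights and scales already prepacked into the shapes implied by the kernel's iterator extents ((K/2) x (N/2) column-major packed weights and (K/32) x N row-major scales for a 32x1 quantization block), no quantization offsets, and M > 16 so the large-tile configuration applies. The launcher name, include paths, and leading dimensions are illustrative assumptions, not part of this PR.

#include <cuda_runtime.h>

#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/gemm.h"
#include "blk_q4/f16_gemm_sm80.h"  // illustrative include path

// Hypothetical launcher: fp16 activations, 4b weights with one scale per
// 32x1 block (QuantBlocking = MatrixShape<32, 1>), no offsets, SmallM = false.
cutlass::Status LaunchBlkQ4F16Gemm(
    cudaStream_t stream, int m, int n, int k,
    const cutlass::half_t* d_a,         // M x K activations, row-major
    const cutlass::half_t* d_b_packed,  // prepacked weights, (K/2) x (N/2) viewed as fp16
    const cutlass::half_t* d_scales,    // prepacked scales, (K/32) x N
    cutlass::half_t* d_out) {           // M x N output, row-major
  using Gemm = onnxruntime::cuda::BlkQ4F16GemmImpl<
      cutlass::half_t, cutlass::MatrixShape<32, 1>,
      /*SmallM=*/false, /*kHasQuantOffset=*/false>;

  cutlass::gemm::GemmCoord problem(m, n, k);
  cutlass::TensorRef<cutlass::half_t const, cutlass::layout::RowMajor>
      ref_a(d_a, cutlass::layout::RowMajor(k));
  cutlass::TensorRef<cutlass::half_t const, cutlass::layout::ColumnMajor>
      ref_b(d_b_packed, cutlass::layout::ColumnMajor(k / 2));
  cutlass::TensorRef<cutlass::half_t const, cutlass::layout::RowMajor>
      ref_scales(d_scales, cutlass::layout::RowMajor(n));
  cutlass::TensorRef<cutlass::half_t const, cutlass::layout::RowMajor>
      ref_c(d_out, cutlass::layout::RowMajor(n));
  cutlass::TensorRef<cutlass::half_t, cutlass::layout::RowMajor>
      ref_d(d_out, cutlass::layout::RowMajor(n));

  // Calls the overload without quantization offsets defined above.
  return Gemm::run(stream, problem, ref_a, ref_b, ref_scales, ref_c, ref_d);
}
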
diff --git a/onnxruntime/core/mickey/blk_q4/prepack_sm80.h b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h
similarity index 99%
rename from onnxruntime/core/mickey/blk_q4/prepack_sm80.h
rename to onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h
index e291ab39e8aa3..a08cfb97eed4a 100644
--- a/onnxruntime/core/mickey/blk_q4/prepack_sm80.h
+++ b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h
@@ -3,7 +3,7 @@
  * Licensed under the MIT License.
  *
  * Module Name:
- *    prepack_sm80.h
+ *    blk_q4/f16_prepack_sm80.h
  *
  * Abstract:
  *    Prepack weights and quantization parameters (scales and offsets) for
diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h
new file mode 100644
index 0000000000000..38795291b0328
--- /dev/null
+++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h
@@ -0,0 +1,481 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/**
+ * Modifications Copyright (c) Microsoft.
+ * Licensed under the MIT license.
+ *
+ * @file quantb_gemm.h
+ * @brief Modified from cutlass/gemm/device/gemm.h, boilerplate code passing input pointers to the kernel.
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/numeric_types.h"
+#include "cutlass/arch/arch.h"
+#include "cutlass/device_kernel.h"
+
+#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
+#include "cutlass/gemm/kernel/gemm.h"
+
+#include "cutlass_ext/q4gemm/kernel/default_quantb_gemm.h"
+#include "cutlass/gemm/device/default_gemm_configuration.h"
+
+#include "cutlass/layout/permute.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace device {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/*! A specialized GEMM operator for quantized B GEMM.
+
+  It is modified from cutlass::gemm::device::Gemm. Both this class and the original Gemm class
+  are largely boilerplate that constructs the GEMM kernel class and passes parameters and
+  controls to it. The only difference is that this class has a few more template parameters
+  to support quantization.
+
+  This implementation closely follows the design of cutlass; this class is essentially a thin
+  wrapper around the GEMM kernel class. Consider combining them in future iterations.
+
+*/
+template <
+    /// Element type for A matrix operand
+    typename ElementA_,
+    /// Layout type for A matrix operand
+    typename LayoutA_,
+    /// Element type for B matrix operand
+    typename ElementB_,
+    /// Layout type for B matrix operand
+    typename LayoutB_,
+    /// Element type for quant scales
+    typename ElementQScale_,
+    /// Element type for quant offsets
+    typename ElementQOffset_,
+    /// Layout type for quant scales and offsets
+    typename LayoutQMeta_,
+    /// Blocking dimensions for quantization
+    typename QuantBlocking_,
+    /// Element type for C and D matrix operands
+    typename ElementC_,
+    /// Layout type for C and D matrix operands
+    typename LayoutC_,
+    /// Element type for internal accumulation
+    typename ElementAccumulator_ = ElementC_,
+    /// Operator class tag
+    typename OperatorClass_ = arch::OpClassSimt,
+    /// Tag indicating architecture to tune for
+    typename ArchTag_ = arch::Sm80,
+    /// Threadblock-level tile size (concept: GemmShape)
+    typename ThreadblockShape_ = typename DefaultGemmConfiguration<
+        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
+        ElementAccumulator_>::ThreadblockShape,
+    /// Warp-level tile size (concept: GemmShape)
+    typename WarpShape_ = typename DefaultGemmConfiguration<
+        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
+        ElementAccumulator_>::WarpShape,
+    /// Instruction-level tile size (concept: GemmShape)
+    typename InstructionShape_ = typename DefaultGemmConfiguration<
+        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
+        ElementAccumulator_>::InstructionShape,
+    /// Epilogue output operator
+    typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
+        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
+        ElementAccumulator_>::EpilogueOutputOp,
+    /// Threadblock-level swizzling operator
+    typename ThreadblockSwizzle_ =
+        typename threadblock::GemmIdentityThreadblockSwizzle<>,
+    /// Number of stages used in the pipelined mainloop
+    int Stages =
+        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
+                                 ElementC_, ElementAccumulator_>::kStages,
+    /// Access granularity of A matrix in units of elements
+    int AlignmentA =
+        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
+                                 ElementC_, ElementAccumulator_>::kAlignmentA,
+    /// Access granularity of B matrix in units of elements
+    int AlignmentB =
+        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
+                                 ElementC_, ElementAccumulator_>::kAlignmentB,
+    /// If true, kernel supports split-K with serial reduction
+    bool SplitKSerial = false,
+    /// Operation performed by GEMM
+    typename Operator_ = typename DefaultGemmConfiguration<
+        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
+        ElementAccumulator_>::Operator,
+    /// Gather operand A by using an index array
+    bool GatherA = false,
+    /// Gather operand B by using an index array
+    bool GatherB = false,
+    /// Scatter result D by using an index array
+    bool ScatterD = false,
+    /// Permute result D
+    typename PermuteDLayout = layout::NoPermute>
+class QuantBGemm {
+ public:
+
+  using ElementA = ElementA_;
+  using LayoutA = LayoutA_;
+  using TensorRefA = TensorRef<ElementA const, LayoutA>;
+  using ElementB = ElementB_;
+  using LayoutB = LayoutB_;
+  using TensorRefB = TensorRef<ElementB const, LayoutB>;
+  using ElementC = ElementC_;
+  using LayoutC = LayoutC_;
+  using TensorRefC = TensorRef<ElementC const, LayoutC>;
+  using TensorRefD = TensorRef<ElementC, LayoutC>;
+  using ElementAccumulator = ElementAccumulator_;
+  using OperatorClass = OperatorClass_;
+  using ArchTag = ArchTag_;
+  using ThreadblockShape = ThreadblockShape_;
+  using WarpShape = WarpShape_;
+  using InstructionShape = InstructionShape_;
+  using EpilogueOutputOp = EpilogueOutputOp_;
+  using ThreadblockSwizzle = ThreadblockSwizzle_;
+  using Operator = Operator_;
+  static int const kStages = Stages;
+  static int const kAlignmentA = AlignmentA;
+  static int const kAlignmentB = AlignmentB;
+  static int const kAlignmentC = EpilogueOutputOp::kCount;
+  static bool const kSplitKSerial = SplitKSerial;
+  static ComplexTransform const kTransformA = ComplexTransform::kNone;
+  static ComplexTransform const kTransformB = ComplexTransform::kNone;
+
+  // Quantization Parameters
+  static_assert(std::is_same<LayoutB, layout::ColumnMajor>::value,
+                "LayoutB, i.e. the packed weights, must be ColumnMajor.");
+  static_assert(InstructionShape::kK == 16,
+                "InstructionShape::kK must be 16 (2 tiles), as required by the 4b weight packing layout.");
+  using ElementQScale = ElementQScale_;
+  using ElementQOffset = ElementQOffset_;
+  using LayoutQMeta = LayoutQMeta_;
+  using QuantBlocking = QuantBlocking_;
+  static constexpr bool kHasQOffset = !(std::is_same<ElementQOffset, std::monostate>::value);
+
+  // TODO(chenfucn): consider moving to uint4_t or smaller for QOffset
+  static_assert(!kHasQOffset || std::is_same<ElementQOffset_, uint8_t>::value, "QOffset must be uint8_t");
+
+  /// Define the kernel
+  using GemmKernel = typename kernel::DefaultQuantBGemm<
+    ElementA,
+    LayoutA,
+    kAlignmentA,
+    ElementB,
+    LayoutB,
+    kAlignmentB,
+    ElementQScale,
+    ElementQOffset,
+    LayoutQMeta,
+    QuantBlocking,
+    ElementC,
+    LayoutC,
+    ElementAccumulator,
+    OperatorClass,
+    ArchTag,
+    ThreadblockShape,
+    WarpShape,
+    InstructionShape,
+    EpilogueOutputOp,
+    ThreadblockSwizzle,
+    kStages,
+    kSplitKSerial,
+    Operator,
+    GatherA,
+    GatherB,
+    ScatterD,
+    PermuteDLayout
+  >::GemmKernel;
+
+  /// Argument structure
+  struct Arguments {
+    //
+    // Data members
+    //
+
+    GemmCoord problem_size;
+    TensorRef<ElementA const, LayoutA> ref_A;
+    TensorRef<ElementB const, LayoutB> ref_B;
+    TensorRef<ElementC const, LayoutC> ref_C;
+    TensorRef<ElementC, LayoutC> ref_D;
+    TensorRef<ElementQScale const, LayoutQMeta> ref_Qscale;
+    TensorRef<ElementQOffset const, LayoutQMeta> ref_Qoffset;
+
+    typename EpilogueOutputOp::Params epilogue;
+
+    // split-K parallelism (etc.) are not yet supported, keeping this for future extension
+    int split_k_slices{1};
+    // For gather+scatter operations
+    int const *gather_A_indices{nullptr};
+    int const *gather_B_indices{nullptr};
+    int const *scatter_D_indices{nullptr};
+
+    //
+    // Methods
+    //
+
+    /// Default ctor
+    CUTLASS_HOST_DEVICE
+    Arguments(): problem_size(0, 0, 0) {}
+
+    /// Constructs an Arguments structure
+    CUTLASS_HOST_DEVICE
+    Arguments(
+      GemmCoord problem_size_,
+      TensorRef<ElementA const, LayoutA> ref_A_,
+      TensorRef<ElementB const, LayoutB> ref_B_,
+      TensorRef<ElementQScale const, LayoutQMeta> ref_Qscale_,
+      TensorRef<ElementC const, LayoutC> ref_C_,
+      TensorRef<ElementC, LayoutC> ref_D_,
+      typename EpilogueOutputOp::Params epilogue_ =
+        typename EpilogueOutputOp::Params()):
+      problem_size(problem_size_),
+      ref_A(ref_A_),
+      ref_B(ref_B_),
+      ref_Qscale(ref_Qscale_),
+      ref_C(ref_C_),
+      ref_D(ref_D_),
+      epilogue(epilogue_) {
+        assert(!kHasQOffset);
+    }
+
+    CUTLASS_HOST_DEVICE
+    Arguments(
+      GemmCoord problem_size_,
+      TensorRef<ElementA const, LayoutA> ref_A_,
+      TensorRef<ElementB const, LayoutB> ref_B_,
+      TensorRef<ElementQScale const, LayoutQMeta> ref_Qscale_,
+      TensorRef<ElementQOffset const, LayoutQMeta> ref_Qoffset_,
+      TensorRef<ElementC const, LayoutC> ref_C_,
+      TensorRef<ElementC, LayoutC> ref_D_,
+      typename EpilogueOutputOp::Params epilogue_ =
+        typename EpilogueOutputOp::Params()):
+      problem_size(problem_size_),
+      ref_A(ref_A_),
+      ref_B(ref_B_),
+      ref_Qscale(ref_Qscale_),
+      ref_Qoffset(ref_Qoffset_),
+      ref_C(ref_C_),
+      ref_D(ref_D_),
+      epilogue(epilogue_) {
+        assert(kHasQOffset);
+    }
+  };
+
+ private:
+  /// Kernel parameters object
+  typename GemmKernel::Params params_;
+
+ public:
+  /// Constructs the GEMM.
+  QuantBGemm() { }
+
+  /// Determines whether the GEMM can execute the given problem.
+  static Status can_implement(Arguments const &args) {
+
+    if (!kSplitKSerial && args.split_k_slices > 1) {
+      return Status::kErrorInvalidProblem;
+    }
+
+    Status status = GemmKernel::can_implement(
+      args.problem_size,
+      args.ref_A.non_const_ref(),
+      args.ref_B.non_const_ref(),
+      args.ref_Qscale.non_const_ref(),
+      args.ref_Qoffset.non_const_ref(),
+      args.ref_C.non_const_ref(),
+      args.ref_D
+    );
+
+    if (status != Status::kSuccess) {
+      return status;
+    }
+
+    return Status::kSuccess;
+  }
+
+  /// Gets the workspace size
+  static size_t get_workspace_size(Arguments const &args) {
+
+    size_t bytes = 0;
+
+    // Determine grid shape
+    ThreadblockSwizzle threadblock_swizzle;
+
+    cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
+      args.problem_size,
+      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
+      args.split_k_slices);
+
+    if (kSplitKSerial && args.split_k_slices > 1) {
+
+      bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
+    }
+
+    return bytes;
+  }
+
+  /// Initializes GEMM state from arguments.
+  Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
+
+    // Determine grid shape
+    ThreadblockSwizzle threadblock_swizzle;
+
+    cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
+      args.problem_size,
+      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
+      args.split_k_slices);
+
+    if (kSplitKSerial) {
+      if (args.split_k_slices > 1) {
+        if (!workspace) {
+          return Status::kErrorWorkspaceNull;
+        }
+
+        size_t bytes = get_workspace_size(args);
+
+        cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
+
+        if (result != cudaSuccess) {
+          return Status::kErrorInternal;
+        }
+      }
+    } else {
+
+      if (args.split_k_slices > 1) {
+        return Status::kErrorInvalidProblem;
+      }
+    }
+
+    // Initialize the Params structure
+    params_ = typename GemmKernel::Params{
+      args.problem_size,
+      grid_shape,
+      args.ref_A.non_const_ref(),
+      args.ref_B.non_const_ref(),
+      args.ref_Qscale.non_const_ref(),
+      args.ref_Qoffset.non_const_ref(),
+      args.ref_C.non_const_ref(),
+      args.ref_D,
+      args.epilogue,
+      static_cast<int *>(workspace),
+      args.gather_A_indices,
+      args.gather_B_indices,
+      args.scatter_D_indices
+    };
+
+    return Status::kSuccess;
+  }
+
+  /// Lightweight update given a subset of arguments
+  Status update(Arguments const &args, void *workspace = nullptr) {
+
+    if (kSplitKSerial && args.split_k_slices > 1) {
+      if (!workspace) {
+        return Status::kErrorWorkspaceNull;
+      }
+    }
+
+    params_.ref_A.reset(args.ref_A.non_const_ref().data());
+    params_.ref_B.reset(args.ref_B.non_const_ref().data());
+    params_.ref_Qscale.reset(args.ref_Qscale.non_const_ref().data());
+    params_.ref_Qoffset.reset(args.ref_Qoffset.non_const_ref().data());
+    params_.ref_C.reset(args.ref_C.non_const_ref().data());
+    params_.ref_D.reset(args.ref_D.data());
+    params_.output_op = args.epilogue;
+    params_.semaphore = static_cast<int *>(workspace);
+
+    return Status::kSuccess;
+  }
+
+  /// Runs the kernel using initialized state.
+  Status run(cudaStream_t stream = nullptr) {
+
+    ThreadblockSwizzle threadblock_swizzle;
+
+    dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
+    dim3 block(GemmKernel::kThreadCount, 1, 1);
+
+    cudaError_t result;
+
+    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
+
+    if (smem_size >= (48 << 10)) {
+      result = cudaFuncSetAttribute(Kernel<GemmKernel>,
+                                    cudaFuncAttributeMaxDynamicSharedMemorySize,
+                                    smem_size);
+
+      if (result != cudaSuccess) {
+        std::cerr << "Failed to obtain maximum shared memory size " << smem_size << " for kernel: "
+                  << cudaGetErrorString(result) << "\n";
+        return Status::kErrorInternal;
+      }
+    }
+
+    cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
+
+    result = cudaGetLastError();
+
+    return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
+  }
+
+  /// Runs the kernel using initialized state.
+  Status operator()(cudaStream_t stream = nullptr) {
+    return run(stream);
+  }
+
+  /// Runs the kernel using initialized state.
+  Status operator()(
+    Arguments const &args,
+    void *workspace = nullptr,
+    cudaStream_t stream = nullptr) {
+
+    Status status = initialize(args, workspace, stream);
+
+    if (status == Status::kSuccess) {
+      status = run(stream);
+    }
+
+    return status;
+  }
+};
+
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace device
+} // namespace gemm
+} // namespace cutlass
+
+////////////////////////////////////////////////////////////////////////////////
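
The presence of a quantization offset operand is encoded purely in the type system: instantiating the template with std::monostate as ElementQOffset_ turns kHasQOffset off at compile time, which is also how BlkQ4F16GemmImpl selects between the two paths. A tiny self-contained illustration of that pattern follows (the trait name is illustrative, not part of the PR):

#include <cstdint>
#include <type_traits>
#include <variant>  // for std::monostate

// Mirrors the kHasQOffset logic above: std::monostate means "no quantization
// offset operand"; any real element type (e.g. uint8_t) enables the offset path.
template <typename ElementQOffset>
inline constexpr bool has_q_offset_v =
    !std::is_same<ElementQOffset, std::monostate>::value;

static_assert(!has_q_offset_v<std::monostate>,
              "monostate disables the quantization offset path");
static_assert(has_q_offset_v<std::uint8_t>,
              "uint8_t offsets enable the quantization offset path");
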
diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h
new file mode 100644
index 0000000000000..2f4460bb59e9f
--- /dev/null
+++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h
@@ -0,0 +1,255 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/**
+ * Modifications Copyright (c) Microsoft.
+ * Licensed under the MIT license.
+ *
+ * @file default_quantb_gemm.h
+ * @brief Modified from cutlass/gemm/kernel/default_gemm.h. Templates for combining
+ *        threadblock-scoped matrix multiply-add with the appropriate
+ *        threadblock-scoped epilogue.
+ */
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+
+#include "cutlass/layout/matrix.h"
+#include "cutlass/numeric_types.h"
+#include "cutlass/arch/wmma.h"
+
+#include "cutlass/epilogue/threadblock/epilogue.h"
+#include "cutlass/epilogue/thread/linear_combination.h"
+
+#include "cutlass/gemm/gemm.h"
+#include "cutlass_ext/q4gemm/kernel/quantb_gemm.h"
+#include "cutlass/gemm/kernel/gemm_pipelined.h"
+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
+#include "cutlass_ext/q4gemm/threadblock/default_quantb_mma.h"
+#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
+#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
+
+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
+#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
+
+#include "cutlass/layout/permute.h"
+
+#if defined(CUTLASS_ARCH_WMMA_ENABLED)
+#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h"
+#endif //CUTLASS_ARCH_WMMA_ENABLED
+
+////////////////////////////////////////////////////////////////////////////////
+namespace cutlass {
+namespace gemm {
+namespace kernel {
+
+////////////////////////////////////////////////////////////////////////////////
+
+template <
+    /// Element type for A matrix operand
+    typename ElementA_,
+    /// Layout type for A matrix operand
+    typename LayoutA_,
+    /// Access granularity of A matrix in units of elements
+    int kAlignmentA,
+    /// Element type for B matrix operand
+    typename ElementB_,
+    /// Layout type for B matrix operand
+    typename LayoutB_,
+    /// Access granularity of B matrix in units of elements
+    int kAlignmentB,
+    /// Element type for quant scales
+    typename ElementQScale_,
+    /// Element type for quant offsets
+    typename ElementQOffset_,
+    /// Layout type for quant scales and offsets
+    typename LayoutQMeta_,
+    /// Blocking dimensions for quantization
+    typename QuantBlocking_,
+    /// Element type for C and D matrix operands
+    typename ElementC_,
+    /// Layout type for C and D matrix operands
+    typename LayoutC_,
+    /// Element type for internal accumulation
+    typename ElementAccumulator,
+    /// Operator class tag
+    typename OperatorClass,
+    /// Tag indicating architecture to tune for
+    typename ArchTag,
+    /// Threadblock-level tile size (concept: GemmShape)
+    typename ThreadblockShape,
+    /// Warp-level tile size (concept: GemmShape)
+    typename WarpShape,
+    /// Instruction-level tile size (concept: GemmShape)
+    typename InstructionShape,
+    /// Epilogue output operator
+    typename EpilogueOutputOp,
+    /// Threadblock-level swizzling operator
+    typename ThreadblockSwizzle,
+    /// Number of stages used in the pipelined mainloop
+    int Stages,
+    /// If true, kernel is configured to support serial reduction in the
+    /// epilogue
+    bool SplitKSerial,
+    /// Operation performed by GEMM
+    typename Operator,
+    /// Gather operand A by using an index array
+    bool GatherA = false,
+    /// Gather operand B by using an index array
+    bool GatherB = false,
+    /// Scatter result D by using an index array
+    bool ScatterD = false,
+    /// Permute result D
+    typename PermuteDLayout = layout::NoPermute,
+    /// Permute operand A
+    typename PermuteALayout = layout::NoPermute,
+    /// Permute operand B
+    typename PermuteBLayout = layout::NoPermute,
+    ///
+    typename Enable = void
+>
+struct DefaultQuantBGemm;
+
+////////////////////////////////////////////////////////////////////////////////
+
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Partial specialization for Ampere Architecture
+template <
+    /// Element type for A matrix operand
+    typename ElementA,
+    /// Layout type for A matrix operand
+    typename LayoutA,
+    /// Access granularity of A matrix in units of elements
+    int kAlignmentA,
+    /// Element type for B matrix operand
+    typename ElementB,
+    /// Layout type for B matrix operand
+    typename LayoutB,
+    /// Access granularity of B matrix in units of elements
+    int kAlignmentB,
+    /// Element type for quant scales
+    typename ElementQScale,
+    /// Element type for quant offsets
+    typename ElementQOffset,
+    /// Layout type for quant scales
+    typename LayoutQMeta,
+    /// Blocking dimensions for quantization
+    typename QuantBlocking,
+    /// Element type for C and D matrix operands
+    typename ElementC,
+    /// Layout type for C and D matrix operand
+    typename LayoutC,
+    /// Element type for internal accumulation
+    typename ElementAccumulator,
+    /// Threadblock-level tile size (concept: GemmShape)
+    typename ThreadblockShape,
+    /// Warp-level tile size (concept: GemmShape)
+    typename WarpShape,
+    /// Instruction-level tile size (concept: GemmShape)
+    typename InstructionShape,
+    /// Epilogue output operator
+    typename EpilogueOutputOp,
+    /// Threadblock-level swizzling operator
+    typename ThreadblockSwizzle,
+    /// Number of stages used in the pipelined mainloop
+    int Stages,
+    /// If true, kernel is configured to support serial reduction in the
+    /// epilogue
+    bool SplitKSerial,
+    /// Operation performed by GEMM
+    typename Operator,
+    /// Gather operand A by using an index array
+    bool GatherA,
+    /// Gather operand B by using an index array
+    bool GatherB,
+    /// Scatter result D by using an index array
+    bool ScatterD,
+    /// Permute result D
+    typename PermuteDLayout,
+    /// Permute operand A
+    typename PermuteALayout,
+    /// Permute operand B
+    typename PermuteBLayout
+>
+struct DefaultQuantBGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
+                         ElementQScale, ElementQOffset, LayoutQMeta, QuantBlocking,
+                         ElementC, LayoutC, ElementAccumulator,
+                         arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape,
+                         InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages,
+                         SplitKSerial, Operator, GatherA, GatherB, ScatterD,
+                         PermuteDLayout, PermuteALayout, PermuteBLayout> {
+
+  static_assert((platform::is_same<LayoutC, layout::RowMajor>::value
+             || platform::is_same<LayoutC, layout::AffineRankN<2>>::value),
+             "Epilogue in the kernel level must be row major");
+
+  /// Define the threadblock-scoped matrix multiply-accumulate
+  using Mma = typename cutlass::gemm::threadblock::DefaultQuantBMma<
+      ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
+      ElementQScale, ElementQOffset, LayoutQMeta, QuantBlocking,
+      ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80,
+      ThreadblockShape, WarpShape, InstructionShape, Stages,
+      Operator, false, GatherA, GatherB,
+      PermuteALayout, PermuteBLayout>::ThreadblockMma;
+
+  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
+
+  /// Define the epilogue
+  using RegularEpilogue =
+      typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
+          ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
+          EpilogueOutputOp::kCount, ScatterD, PermuteDLayout>::Epilogue;
+
+  using Affine2Epilogue =
+      typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOpAffineRankN<
+          2, ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
+          EpilogueOutputOp::kCount>::Epilogue;
+
+  using Epilogue = typename platform::conditional<platform::is_same<LayoutC, layout::RowMajor>::value,
+                                                  RegularEpilogue,
+                                                  Affine2Epilogue>::type;
+
+  /// Define the kernel-level GEMM operator.
+  using GemmKernel = kernel::QuantBGemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace kernel
+}  // namespace gemm
+}  // namespace cutlass
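
To make the tile partitioning concrete: with the non-SmallM configuration from blk_q4/f16_gemm_sm80.h (128x256x64 threadblock tile, 64x64x64 warp tile), kPartitionsK is 1 and the kernel runs with 8 warps per threadblock. A small stand-alone sketch of that arithmetic (the shape structs are illustrative stand-ins for cutlass::gemm::GemmShape):

#include <cstdio>

// Threadblock and warp tile shapes used by the large-M path in
// blk_q4/f16_gemm_sm80.h: 128x256x64 per threadblock, 64x64x64 per warp.
struct ThreadblockShape { static constexpr int kM = 128, kN = 256, kK = 64; };
struct WarpShape        { static constexpr int kM = 64,  kN = 64,  kK = 64; };

int main() {
  // Matches kPartitionsK in DefaultQuantBGemm: how many warps split the K dimension.
  constexpr int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;  // 1
  // Matches Mma::WarpCount / kThreadCount in the kernel-level QuantBGemm.
  constexpr int warp_count =
      (ThreadblockShape::kM / WarpShape::kM) *
      (ThreadblockShape::kN / WarpShape::kN) * kPartitionsK;          // 2 * 4 * 1 = 8
  constexpr int thread_count = 32 * warp_count;                       // 256
  std::printf("kPartitionsK=%d warps=%d threads=%d\n",
              kPartitionsK, warp_count, thread_count);
  return 0;
}
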
diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h
new file mode 100644
index 0000000000000..6e5ad8f406147
--- /dev/null
+++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h
@@ -0,0 +1,462 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/**
+ * Modifications Copyright (c) Microsoft.
+ * Licensed under the MIT license.
+ *
+ * @file quantb_gemm.h
+ * @brief Modified from cutlass/gemm/kernel/gemm.h.
+ *        Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
+ */
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/matrix_coord.h"
+#include "cutlass/semaphore.h"
+#include "cutlass/arch/arch.h"
+
+#include "cutlass/util/debug.h"
+#include "cutlass/util/device_dump.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace kernel {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  typename Mma_,                  ///! Threadblock-scoped matrix multiply-accumulate
+  typename Epilogue_,             ///! Epilogue
+  typename ThreadblockSwizzle_,   ///! Threadblock swizzling function
+  bool SplitKSerial               ///! If true, code supporting split-K via serial reduction is enabled.
+>
+struct QuantBGemm {
+
+  using Mma = Mma_;
+  using Epilogue = Epilogue_;
+  using OutputOp = typename Epilogue::OutputOp;
+  using ThreadblockSwizzle = ThreadblockSwizzle_;
+  static bool const kSplitKSerial = SplitKSerial;
+
+  static constexpr bool kHasQOffset = Mma::kHasQOffset;
+
+  /// Warp count (concept: GemmShape)
+  using WarpCount = typename Mma::WarpCount;
+  static int const kThreadCount = 32 * WarpCount::kCount;
+
+  /// Parameters structure
+  struct Params {
+    cutlass::gemm::GemmCoord problem_size;
+    cutlass::gemm::GemmCoord grid_tiled_shape;
+    int swizzle_log_tile;
+    typename Mma::IteratorA::Params params_A;
+    typename Mma::IteratorA::TensorRef ref_A;
+    typename Mma::IteratorB::Params params_B;
+    typename Mma::IteratorB::TensorRef ref_B;
+    typename Mma::IteratorQScale::Params params_QScale;
+    typename Mma::IteratorQScale::TensorRef ref_QScale;
+    typename Mma::IteratorQOffset::Params params_QOffset;
+    typename Mma::IteratorQOffset::TensorRef ref_QOffset;
+    typename Epilogue::OutputTileIterator::Params params_C;
+    typename Epilogue::OutputTileIterator::TensorRef ref_C;
+    typename Epilogue::OutputTileIterator::Params params_D;
+    typename Epilogue::OutputTileIterator::TensorRef ref_D;
+    typename OutputOp::Params output_op;
+    int *semaphore;
+    int gemm_k_size;  // how many k vectors are processed by this threadblock
+    // For gather+scatter operations
+    int const *gather_A_indices;
+    int const *gather_B_indices;
+    int const *scatter_D_indices;
+
+    //
+    // Methods
+    //
+
+    CUTLASS_HOST_DEVICE
+    Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { }
+
+    CUTLASS_HOST_DEVICE
+    Params(
+      cutlass::gemm::GemmCoord const & problem_size,
+      cutlass::gemm::GemmCoord const & grid_tiled_shape,
+      typename Mma::IteratorA::TensorRef ref_A,
+      typename Mma::IteratorB::TensorRef ref_B,
+      typename Mma::IteratorQScale::TensorRef ref_QScale,
+      typename Mma::IteratorQOffset::TensorRef ref_QOffset,
+      typename Epilogue::OutputTileIterator::TensorRef ref_C,
+      typename Epilogue::OutputTileIterator::TensorRef ref_D,
+      typename OutputOp::Params output_op = typename OutputOp::Params(),
+      int *workspace = nullptr,
+      int const *gather_A_indices = nullptr,
+      int const *gather_B_indices = nullptr,
+      int const *scatter_D_indices = nullptr
+    ):
+      problem_size(problem_size),
+      grid_tiled_shape(grid_tiled_shape),
+      swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
+      params_A(ref_A.layout()),
+      ref_A(ref_A),
+      params_B(ref_B.layout()),
+      ref_B(ref_B),
+      params_QScale(ref_QScale.layout()),
+      ref_QScale(ref_QScale),
+      params_QOffset(ref_QOffset.layout()),
+      ref_QOffset(ref_QOffset),
+      params_C(ref_C.layout()),
+      ref_C(ref_C),
+      params_D(ref_D.layout()),
+      ref_D(ref_D),
+      output_op(output_op),
+      gather_A_indices(gather_A_indices),
+      gather_B_indices(gather_B_indices),
+      scatter_D_indices(scatter_D_indices) {
+      int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
+      int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
+
+      gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
+
+      semaphore = workspace;
+    }
+  };
+
+  /// Shared memory storage structure
+  union SharedStorage {
+    typename Mma::SharedStorage main_loop;
+    typename Epilogue::SharedStorage epilogue;
+  };
+
+  //
+  // Methods
+  //
+
+  CUTLASS_HOST_DEVICE
+  QuantBGemm() { }
+
+  /// Determines whether kernel satisfies alignment
+  CUTLASS_HOST_DEVICE
+  static Status can_implement(
+    cutlass::gemm::GemmCoord const & problem_size,
+    typename Mma::IteratorA::TensorRef ref_A,
+    typename Mma::IteratorB::TensorRef ref_B,
+    typename Mma::IteratorQScale::TensorRef ref_QScale,
+    typename Mma::IteratorQOffset::TensorRef ref_QOffset,
+    typename Epilogue::OutputTileIterator::TensorRef ref_C,
+    typename Epilogue::OutputTileIterator::TensorRef ref_D) {
+
+    // TODO: check that problem_size K and N are multiples of QuantBlocking
+
+    static int const kAlignmentA = (platform::is_same<typename Mma::IteratorA::Layout,
+                                                      layout::ColumnMajorInterleaved<32>>::value)
+                                   ? 32
+                                   : (platform::is_same<typename Mma::IteratorA::Layout,
+                                                        layout::ColumnMajorInterleaved<64>>::value)
+                                     ? 64
+                                     : Mma::IteratorA::AccessType::kElements;
+    static int const kAlignmentB =  (platform::is_same<typename Mma::IteratorB::Layout,
+                                                       layout::RowMajorInterleaved<32>>::value)
+                                   ? 32
+                                   : (platform::is_same<typename Mma::IteratorB::Layout,
+                                                        layout::RowMajorInterleaved<64>>::value)
+                                     ? 64
+                                     : Mma::IteratorB::AccessType::kElements;
+    static int const kAlignmentC = (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
+                                                      layout::ColumnMajorInterleaved<32>>::value)
+                                   ? 32
+                                   : (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
+                                                        layout::ColumnMajorInterleaved<64>>::value)
+                                     ? 64
+                                     : Epilogue::OutputTileIterator::kElementsPerAccess;
+
+    if (!TensorRef_aligned(ref_A, kAlignmentA)) {
+      return Status::kErrorMisalignedOperand;
+    }
+
+    if (!TensorRef_aligned(ref_B, kAlignmentB)) {
+      return Status::kErrorMisalignedOperand;
+    }
+
+    if (problem_size.k() % Mma::Shape::kK != 0) {
+      // Currently we don't support this case due to the way the
+      // predicate iterator works: it loads the partial tile in the
+      // first iteration and then full tiles in the remaining
+      // iterations. This causes the blockwise quantization
+      // parameters to go out of step with the weights. It can be
+      // fixed by adding a predicate iterator that loads full tiles
+      // in the first iterations and the partial tile in the last
+      // iteration.
+      return Status::kErrorInvalidProblem;
+    }
+
+    int qscale_k = problem_size.k() / Mma::QuantBlocking::kRow;
+    int qscale_n = problem_size.n() / Mma::QuantBlocking::kColumn;
+    if ((qscale_k == 0) || (qscale_k * Mma::QuantBlocking::kRow != problem_size.k())) {
+      // partial block not supported
+      return Status::kErrorInvalidProblem;
+    }
+    if ((qscale_n == 0) || (qscale_n * Mma::QuantBlocking::kColumn != problem_size.n())) {
+      // partial block not supported
+      return Status::kErrorInvalidProblem;
+    }
+
+    if (!TensorRef_aligned(ref_QScale, Mma::IteratorQScale::AccessType::kElements)) {
+      return Status::kErrorMisalignedOperand;
+    }
+
+    if constexpr(kHasQOffset) {
+      if (!TensorRef_aligned(ref_QOffset, Mma::IteratorQOffset::AccessType::kElements)) {
+        return Status::kErrorMisalignedOperand;
+      }
+    }
+
+    if (!TensorRef_aligned(ref_C, kAlignmentC)) {
+      return Status::kErrorMisalignedOperand;
+    }
+
+    if (!TensorRef_aligned(ref_D, kAlignmentC)) {
+      return Status::kErrorMisalignedOperand;
+    }
+
+    return Status::kSuccess;
+  }
+
+  /// Executes one GEMM
+  CUTLASS_DEVICE
+  void operator()(Params const &params, SharedStorage &shared_storage) {
+
+    // Compute threadblock location
+    ThreadblockSwizzle threadblock_swizzle;
+
+    cutlass::gemm::GemmCoord threadblock_tile_offset =
+        threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
+
+    // Early exit if CTA is out of range
+    if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
+      params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
+
+      return;
+    }
+
+    // Compute initial location in logical coordinates
+    cutlass::MatrixCoord tb_offset_A{
+      threadblock_tile_offset.m() * Mma::Shape::kM,
+      threadblock_tile_offset.k() * params.gemm_k_size,
+    };
+
+    cutlass::MatrixCoord tb_offset_B{
+      (threadblock_tile_offset.k() * params.gemm_k_size) / 2,
+      (threadblock_tile_offset.n() * Mma::Shape::kN) / 2
+    };
+
+    // Problem size is a function of threadblock index in the K dimension
+    int problem_size_k = min(
+      params.problem_size.k(),
+      (threadblock_tile_offset.k() + 1) * params.gemm_k_size);
+
+    // Compute threadblock-scoped matrix multiply-add
+    int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
+
+    // Compute position within threadblock
+    int thread_idx = threadIdx.x;
+
+    // Construct iterators to A and B operands
+    typename Mma::IteratorA iterator_A(
+      params.params_A,
+      params.ref_A.data(),
+      {params.problem_size.m(), problem_size_k},
+      thread_idx,
+      tb_offset_A,
+      params.gather_A_indices);
+
+    typename Mma::IteratorB iterator_B(
+      params.params_B,
+      params.ref_B.data(),
+      {problem_size_k/2, params.problem_size.n()/2},
+      thread_idx,
+      tb_offset_B,
+      params.gather_B_indices);
+
+    const int qscale_k = problem_size_k / Mma::QuantBlocking::kRow;
+    const int qscale_n = params.problem_size.n() / Mma::QuantBlocking::kColumn;
+
+    // should have been verified by can_implement()
+    assert((qscale_k > 0) && (qscale_k * Mma::QuantBlocking::kRow == problem_size_k));
+    assert((qscale_n > 0) && (qscale_n * Mma::QuantBlocking::kColumn == params.problem_size.n()));
+
+    cutlass::MatrixCoord tb_offset_QScale{
+      threadblock_tile_offset.k() * (params.gemm_k_size/Mma::QuantBlocking::kRow),
+      threadblock_tile_offset.n() * (Mma::Shape::kN/Mma::QuantBlocking::kColumn)
+    };
+
+    typename Mma::IteratorQScale iterator_QScale(
+      params.params_QScale,
+      params.ref_QScale.data(),
+      {qscale_k, qscale_n},
+      thread_idx,
+      tb_offset_QScale,
+      nullptr);
+
+    typename Mma::IteratorQOffset iterator_QOffset(
+      params.params_QOffset,
+      params.ref_QOffset.data(),
+      {qscale_k, qscale_n},
+      thread_idx,
+      tb_offset_QScale);
+
+    // Broadcast the warp_id computed by lane 0 to ensure dependent code
+    // is compiled as warp-uniform.
+    const int warp_idx = canonical_warp_idx();
+    const int lane_idx = threadIdx.x % 32;
+
+    //
+    // Main loop
+    //
+
+    // Construct thread-scoped matrix multiply
+    Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
+
+    typename Mma::FragmentC accumulators;
+
+    accumulators.clear();
+
+    if (!kSplitKSerial || gemm_k_iterations > 0) {
+      // Compute threadblock-scoped matrix multiply-add
+      mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_QScale, iterator_QOffset, accumulators);
+    }
+
+    //
+    // Epilogue
+    //
+
+    OutputOp output_op(params.output_op);
+
+    //
+    // Masked tile iterators constructed from members
+    //
+
+    threadblock_tile_offset =
+        threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
+
+    //assume identity swizzle
+    MatrixCoord threadblock_offset(
+      threadblock_tile_offset.m() * Mma::Shape::kM,
+      threadblock_tile_offset.n() * Mma::Shape::kN
+    );
+
+    int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
+
+    // Construct the semaphore.
+    Semaphore semaphore(params.semaphore + block_idx, thread_idx);
+
+    // If performing a reduction via split-K, fetch the initial synchronization
+    if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
+
+      // Fetch the synchronization lock initially but do not block.
+      semaphore.fetch();
+
+      // Indicate which position in a serial reduction the output operator is currently updating
+      output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
+    }
+
+    // Tile iterator loading from source tensor.
+    typename Epilogue::OutputTileIterator iterator_C(
+      params.params_C,
+      params.ref_C.data(),
+      params.problem_size.mn(),
+      thread_idx,
+      threadblock_offset,
+      params.scatter_D_indices
+    );
+
+    // Tile iterator writing to destination tensor.
+    typename Epilogue::OutputTileIterator iterator_D(
+      params.params_D,
+      params.ref_D.data(),
+      params.problem_size.mn(),
+      thread_idx,
+      threadblock_offset,
+      params.scatter_D_indices
+    );
+
+    Epilogue epilogue(
+      shared_storage.epilogue,
+      thread_idx,
+      warp_idx,
+      lane_idx);
+
+    // Wait on the semaphore - this latency may have been covered by iterator construction
+    if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
+
+      // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
+      if (threadblock_tile_offset.k()) {
+        iterator_C = iterator_D;
+      }
+
+      semaphore.wait(threadblock_tile_offset.k());
+
+    }
+
+    // Execute the epilogue operator to update the destination tensor.
+    epilogue(output_op, iterator_D, accumulators, iterator_C);
+
+    //
+    // Release the semaphore
+    //
+
+    if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
+
+      int lock = 0;
+      if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
+
+        // The final threadblock resets the semaphore for subsequent grids.
+        lock = 0;
+      }
+      else {
+        // Otherwise, the semaphore is incremented
+        lock = threadblock_tile_offset.k() + 1;
+      }
+
+      semaphore.release(lock);
+    }
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace kernel
+} // namespace gemm
+} // namespace cutlass
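
As a worked example of the offset arithmetic in operator() above, assume the 128x256x64 threadblock tile and a 32x1 quantization block from blk_q4/f16_gemm_sm80.h, no split-K (so tile_k is 0 and gemm_k_size equals K), and an illustrative K of 4096. The sketch below reproduces where one threadblock starts reading A, the packed B, and the block scales (all values are illustrative):

#include <cstdio>

// Tile offset arithmetic from QuantBGemm::operator(), made concrete for the
// 128x256x64 threadblock tile and a 32x1 quantization block.
int main() {
  constexpr int kTileM = 128, kTileN = 256;
  constexpr int kQuantBlockRow = 32, kQuantBlockCol = 1;

  // Suppose this threadblock is tile (m, n, k) = (1, 2, 0).
  const int tile_m = 1, tile_n = 2, tile_k = 0;
  const int gemm_k_size = 4096;  // whole K handled by one threadblock (no split-K)

  // A is read in fp16 elements.
  const int a_row = tile_m * kTileM;                          // 128
  const int a_col = tile_k * gemm_k_size;                     // 0
  // B is prepacked: 4 weights fold into one 16b element, so in the packed
  // view both coordinates are halved.
  const int b_row = (tile_k * gemm_k_size) / 2;               // 0
  const int b_col = (tile_n * kTileN) / 2;                    // 256
  // Quantization scales: one scale per 32x1 block of weights.
  const int q_row = tile_k * (gemm_k_size / kQuantBlockRow);  // 0
  const int q_col = tile_n * (kTileN / kQuantBlockCol);       // 512

  std::printf("A(%d,%d) B(%d,%d) QScale(%d,%d)\n",
              a_row, a_col, b_row, b_col, q_row, q_col);
  return 0;
}
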
diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h
new file mode 100644
index 0000000000000..0af604f090e1f
--- /dev/null
+++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h
@@ -0,0 +1,248 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/**
+ * Modifications Copyright (c) Microsoft.
+ * Licensed under the MIT license.
+ *
+ * @file default_quantb_mma.h
+ * @brief Modified from cutlass/gemm/threadblock/default_mma.h.
+ *        Defining global memory data layout and iterators, combining them with the
+ *        mma core and the pipelined GEMM kernel.
+ */
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/numeric_types.h"
+#include "cutlass/arch/arch.h"
+#include "cutlass/arch/wmma.h"
+
+#include "cutlass/layout/matrix.h"
+#include "cutlass/layout/permute.h"
+#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
+#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h"
+#include "cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h"
+
+#include "cutlass/gemm/gemm.h"
+#include "cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h"
+#include "cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace threadblock {
+
+////////////////////////////////////////////////////////////////////////////////
+
+template <
+    /// Element type for A matrix operand
+    typename ElementA_,
+    /// Layout type for A matrix operand
+    typename LayoutA_,
+    /// Access granularity of A matrix in units of elements
+    int kAlignmentA,
+    /// Element type for B matrix operand
+    typename ElementB_,
+    /// Layout type for B matrix operand
+    typename LayoutB_,
+    /// Access granularity of B matrix in units of elements
+    int kAlignmentB,
+    /// Element type for quant scales
+    typename ElementQScale_,
+    /// Element type for quant offsets
+    typename ElementQOffset_,
+    /// Layout for quant scales and offsets
+    typename LayoutQMeta_,
+    /// Blocking size for quantization
+    typename QuantBlocking_,
+    /// Element type for internal accumulation
+    typename ElementAccumulator_,
+    /// Layout type for C and D matrix operands
+    typename LayoutC_,
+    /// Operator class tag
+    typename OperatorClass_,
+    /// Tag indicating architecture to tune for
+    typename ArchTag_,
+    /// Threadblock-level tile size (concept: GemmShape)
+    typename ThreadblockShape_,
+    /// Warp-level tile size (concept: GemmShape)
+    typename WarpShape_,
+    /// Instruction-level tile size (concept: GemmShape)
+    typename InstructionShape_,
+    /// Number of stages used in the pipelined mainloop
+    int Stages,
+    /// Operation performed by GEMM
+    typename Operator,
+    /// Store the accumulators in row major or column major.  Row major is used
+    /// when output layout is interleaved.
+    bool AccumulatorsInRowMajor = false,
+    /// Gather operand A by using an index array
+    bool GatherA = false,
+    /// Gather operand B by using an index array
+    bool GatherB = false,
+    /// Permute operand A
+    typename PermuteALayout = layout::NoPermute,
+    /// Permute operand B
+    typename PermuteBLayout = layout::NoPermute
+    >
+struct DefaultQuantBMma;
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Specialization for row-major output (OperatorClass TensorOp)
+template <
+    /// Element type for A matrix operand
+    typename ElementA,
+    /// Layout type for A matrix operand
+    typename LayoutA,
+    /// Access granularity of A matrix in units of elements
+    int kAlignmentA,
+    /// Element type for B matrix operand
+    typename ElementB,
+    /// Layout type for B matrix operand
+    typename LayoutB,
+    /// Access granularity of B matrix in units of elements
+    int kAlignmentB,
+    /// Element type for quant scales
+    typename ElementQScale,
+    /// Element type for quant offsets
+    typename ElementQOffset,
+    /// Layout for quant scales and offsets
+    typename LayoutQMeta,
+    /// Blocking size for quantization
+    typename QuantBlocking,
+    /// Element type for internal accumulation
+    typename ElementAccumulator,
+    /// Layout type for C and D matrix operand
+    typename LayoutC,
+    /// Tag indicating architecture to tune for
+    typename ArchTag,
+    /// Threadblock-level tile size (concept: GemmShape)
+    typename ThreadblockShape,
+    /// Warp-level tile size (concept: GemmShape)
+    typename WarpShape,
+    /// Instruction-level tile size (concept: GemmShape)
+    typename InstructionShape,
+    /// Number of stages used in the multistage mainloop
+    int Stages,
+    /// Operation performed by GEMM
+    typename Operator,
+    /// Gather operand A by using an index array
+    bool GatherA,
+    /// Gather operand B by using an index array
+    bool GatherB,
+    /// Permute operand A
+    typename PermuteALayout,
+    /// Permute operand B
+    typename PermuteBLayout
+    >
+struct DefaultQuantBMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
+                  kAlignmentB, ElementQScale, ElementQOffset,
+                  LayoutQMeta, QuantBlocking,
+                  ElementAccumulator, LayoutC,
+                  arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape,
+                  InstructionShape, Stages, Operator, false,
+                  GatherA, GatherB, PermuteALayout, PermuteBLayout> {
+
+  static_assert(platform::is_same<LayoutC, layout::RowMajor>::value
+             || platform::is_same<LayoutC, layout::AffineRankN<2>>::value,
+             "simt epilogue must be row major");
+
+  static cutlass::arch::CacheOperation::Kind const CacheOpA =
+      ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
+          ? cutlass::arch::CacheOperation::Global
+          : cutlass::arch::CacheOperation::Always;
+
+  static cutlass::arch::CacheOperation::Kind const CacheOpB =
+      ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
+          ? cutlass::arch::CacheOperation::Global
+          : cutlass::arch::CacheOperation::Always;
+
+  // Define the MmaCore components
+  using MmaCore = typename cutlass::gemm::threadblock::DefaultQuantBMmaCore<
+      ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
+      ElementB, LayoutB, ElementQScale, ElementQOffset, LayoutQMeta, QuantBlocking,
+      ElementAccumulator, LayoutC, arch::OpClassTensorOp,
+      Stages, Operator, false, CacheOpA, CacheOpB>;
+
+  // Define iterators over tiles from the A operand
+  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
+  using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
+  using IteratorA =
+      cutlass::transform::threadblock::PredicatedTileAccessIterator<
+          cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
+          ElementA, LayoutA, 1, ThreadMapA, AccessTypeA, GatherA, PermuteALayout>;
+
+  // Define iterators over tiles from the B operand
+  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
+  using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
+  using IteratorB =
+      cutlass::transform::threadblock::PredicatedTileAccessIterator<
+          cutlass::MatrixShape<ThreadblockShape::kK/2, ThreadblockShape::kN/2>,
+          ElementB, LayoutB, 0, ThreadMapB, AccessTypeB, GatherB, PermuteBLayout>;
+
+  // Define iterators over tiles from the quant scales
+  using ThreadMapQScale = typename MmaCore::IteratorThreadMapQScale;
+  using AccessTypeQScale =
+      cutlass::Array<ElementQScale, ThreadMapQScale::kElementsPerAccess>;
+  using IteratorQScale =
+      cutlass::transform::threadblock::PredicatedTileAccessIterator<
+          typename MmaCore::ThreadblockQShape,
+          ElementQScale, LayoutQMeta, 0, ThreadMapQScale, AccessTypeQScale>;
+
+  using ThreadMapQOffset = typename MmaCore::IteratorThreadMapQOffset;
+  using AccessTypeQOffset =
+      cutlass::Array<ElementQOffset, ThreadMapQOffset::kElementsPerAccess>;
+  using IteratorQOffset =
+      cutlass::transform::threadblock::OptionalPredicatedTileAccessIterator<
+            typename MmaCore::ThreadblockQShape, ElementQOffset, LayoutQMeta,
+            0, ThreadMapQOffset, AccessTypeQOffset, MmaCore::kThreads>;
+
+  // Define the threadblock-scoped multistage matrix multiply
+  using ThreadblockMma = cutlass::gemm::threadblock::QuantBMmaMultistage<
+      typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,
+      MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
+      MmaCore::kCacheOpB, IteratorQScale, typename MmaCore::SmemIteratorQScale,
+      cutlass::arch::CacheOperation::Global, IteratorQOffset,
+      typename MmaCore::SmemIteratorQOffset, cutlass::arch::CacheOperation::Global,
+      ElementAccumulator, LayoutC,
+      typename MmaCore::MmaPolicy, Stages>;
+};
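+
+// A minimal usage sketch with hypothetical element types and tile shapes (the
+// actual choices are made by the calling kernel); the nested ThreadblockMma is
+// the multistage mainloop that gets instantiated:
+//
+//   using Mma = DefaultQuantBMma<
+//       cutlass::half_t, layout::RowMajor, 8,           // A operand
+//       uint8_t, layout::ColumnMajor, 16,               // packed 4-bit B weights
+//       cutlass::half_t, uint8_t, layout::RowMajor,     // quant scales / offsets + meta layout
+//       MatrixShape<32, 1>,                             // one scale per 32x1 weight block
+//       float, layout::RowMajor,
+//       arch::OpClassTensorOp, arch::Sm80,
+//       GemmShape<128, 128, 64>, GemmShape<64, 64, 64>, GemmShape<16, 8, 16>,
+//       3, arch::OpMultiplyAdd>;
+//   using ThreadblockMma = typename Mma::ThreadblockMma;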
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace threadblock
+} // namespace gemm
+} // namespace cutlass
+
+////////////////////////////////////////////////////////////////////////////////
diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h
new file mode 100644
index 0000000000000..ad322f6505200
--- /dev/null
+++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h
@@ -0,0 +1,340 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/**
+ * Modifications Copyright (c) Microsoft.
+ * Licensed under the MIT license.
+ *
+ * @file default_quantb_mma_core.h
+ * @brief Modified from cutlass/gemm/threadblock/default_mma_core.h.
+ *        Defining the data layout in shared memory and its iterators.
+ */
+
+#pragma once
+
+#include "cutlass/array.h"
+#include "cutlass/cutlass.h"
+
+#include "cutlass/layout/tensor_op_multiplicand_sm75.h"
+#include "cutlass/layout/tensor_op_multiplicand_sm80.h"
+
+#include "cutlass/gemm/warp/mma_simt_policy.h"
+#include "cutlass/gemm/warp/mma_simt.h"
+#include "cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h"
+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
+
+#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core.h"
+#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h"
+
+#include "cutlass/matrix_shape.h"
+#include "cutlass/numeric_types.h"
+#include "cutlass/transform/pitch_linear_thread_map.h"
+#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h"
+#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h"
+#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h"
+#include "cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h"
+
+#include "cutlass/util/debug.h"
+#include "cutlass/util/device_dump.h"
+////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace threadblock {
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Template defining default matrix multiply operators inferred from threadblock tile size,
+/// global memory data layout, and target math instruction.
+template <
+    /// Shape of threadblock-scoped matrix multiply operator
+    typename Shape,
+    /// Shape of warp-level matrix multiply operator
+    typename WarpShape,
+    /// Shape of one matrix product operation (concept: GemmShape)
+    typename InstructionShape,
+    /// Element data type of A operand
+    typename ElementA,
+    /// Layout of operand A
+    typename LayoutA,
+    /// Element data type of B operand
+    typename ElementB,
+    /// Layout of operand B
+    typename LayoutB,
+    /// Element data type of quant scale
+    typename ElementQScale,
+    /// Element data type of quant offset
+    typename ElementQOffset,
+    /// Layout of quant scale
+    typename LayoutQMeta,
+    /// Blocking dimensions for quantization
+    typename QuantBlocking,
+    /// Data type of accumulator
+    typename ElementC,
+    /// Layout of accumulator
+    typename LayoutC,
+    /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp)
+    typename OperatorClass,
+    /// Number of stages
+    int Stages = 2,
+    /// Operation performed by MMA
+    typename Operator = typename platform::conditional<
+        (platform::is_same<OperatorClass,
+                           cutlass::arch::OpClassTensorOp>::value) &&
+            (platform::is_same<ElementA, int8_t>::value ||
+             platform::is_same<ElementA, int4b_t>::value ||
+             platform::is_same<ElementA, uint8_t>::value ||
+             platform::is_same<ElementA, uint4b_t>::value),
+        cutlass::arch::OpMultiplyAddSaturate,
+        cutlass::arch::OpMultiplyAdd>::type,
+    /// Store the accumulators in row major or column major.  Row major is used
+    /// when output layout is interleaved.
+    bool AccumulatorsInRowMajor = false,
+    /// Cache operation of operand A
+    cutlass::arch::CacheOperation::Kind CacheOpA =
+        cutlass::arch::CacheOperation::Global,
+    /// Cache operation of operand B
+    cutlass::arch::CacheOperation::Kind CacheOpB =
+        cutlass::arch::CacheOperation::Global,
+    /// per-element transformation for elements of A
+    ComplexTransform TransformA = ComplexTransform::kNone,
+    /// per-element transformation for elements of B
+    ComplexTransform TransformB = ComplexTransform::kNone,
+    bool IsComplex = false // (is_complex<ElementA>::value || is_complex<ElementB>::value)
+>
+struct DefaultQuantBMmaCore;
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Partial specialization:
+///
+///   A: row-major
+///   B: column-major
+///   Operator: tensor op class
+///
+/// This uses the default warp-level operator given tile sizes
+template <
+    /// Shape of threadblock-scoped matrix multiply operator (concept:
+    /// GemmShape)
+    typename Shape_,
+    /// Shape of warp-level matrix multiply operator (concept: GemmShape)
+    typename WarpShape_,
+    /// Shape of one matrix product operation (concept: GemmShape)
+    typename InstructionShape_,
+    /// Data type of A operand
+    typename ElementA_,
+    /// Data type of B operand
+    typename ElementB_,
+    /// Element data type of quant scale
+    typename ElementQScale_,
+    /// Element data type of quant offset
+    typename ElementQOffset_,
+    /// Layout of quant scale
+    typename LayoutQMeta_,
+    /// Blocking dimensions for quantization
+    typename QuantBlocking_,
+    /// Data type of accumulator
+    typename ElementC_,
+    /// Layout of accumulator
+    typename LayoutC_,
+    /// Number of stages
+    int Stages,
+    /// Operation performed by MMA
+    typename Operator_,
+    /// Cache operation of operand A
+    cutlass::arch::CacheOperation::Kind CacheOpA,
+    /// Cache operation of operand B
+    cutlass::arch::CacheOperation::Kind CacheOpB>
+struct DefaultQuantBMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
+                      layout::RowMajor, ElementB_, layout::ColumnMajor,
+                      ElementQScale_, ElementQOffset_, LayoutQMeta_, QuantBlocking_,
+                      ElementC_, LayoutC_, arch::OpClassTensorOp, Stages,
+                      Operator_, false, CacheOpA, CacheOpB> {
+  using Shape = Shape_;
+  using WarpShape = WarpShape_;
+  using InstructionShape = InstructionShape_;
+  using ElementA = ElementA_;
+  using LayoutA = layout::RowMajor;
+  using ElementB = ElementB_;
+  using LayoutB = layout::ColumnMajor;
+
+  using ElementQScale = ElementQScale_;
+  using ElementQOffset = ElementQOffset_;
+  using LayoutQMeta = LayoutQMeta_;
+  using QuantBlocking = QuantBlocking_;
+
+  using ElementC = ElementC_;
+  using LayoutC = LayoutC_;
+  static int const kStages = Stages;
+  static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
+  static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
+
+  /// Number of warps present
+  using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
+                              Shape::kN / WarpShape::kN,
+                              Shape::kK / WarpShape::kK>;
+
+  // Divisibility requirements
+  static_assert(
+      !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN),
+      "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size.");
+
+  /// Number of threads per warp
+  static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;
+
+  /// Number of threads total
+  static int const kThreads = WarpCount::kCount * kWarpSize;
+
+  /// Size of a threadblock-scoped access
+  static int const kAccessSizeInBits = 128;
+
+  /// Default Operator
+  using Operator = Operator_;
+
+  // Warp thread arrangement
+  static int const kWarpThreadArrangementContiguousA =
+      Shape::kK / (kAccessSizeInBits / sizeof_bits<ElementA>::value);
+
+  static int const kWarpThreadArrangementStridedA =
+      kWarpSize / kWarpThreadArrangementContiguousA;
+
+  static int const kWarpThreadArrangementContiguousB =
+      (Shape::kK / 2) / (kAccessSizeInBits / sizeof_bits<ElementB>::value);
+
+  static int const kWarpThreadArrangementStridedB =
+      kWarpSize / kWarpThreadArrangementContiguousB;
+
+  //
+  // Shared memory layouts
+  //
+
+  using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise<
+      sizeof_bits<ElementA>::value, Shape::kK>;
+
+  using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise<
+      sizeof_bits<ElementB>::value, Shape::kK/2>;
+
+  //
+  // Iterators to write to shared memory
+  //
+
+  /// ThreadMap of iterator A
+  using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap<
+      layout::PitchLinearShape<Shape::kK, Shape::kM>, kThreads,
+      layout::PitchLinearShape<kWarpThreadArrangementContiguousA,
+                               kWarpThreadArrangementStridedA>,
+      kAccessSizeInBits / sizeof_bits<ElementA>::value>;
+
+  /// Shared memory iterator to A operand
+  using SmemIteratorA = transform::threadblock::RegularTileAccessIterator<
+      MatrixShape<Shape::kM, Shape::kK>, ElementA, SmemLayoutA, 0,
+      IteratorThreadMapA>;
+
+  /// ThreadMap of iterator B
+  using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap<
+      layout::PitchLinearShape<Shape::kK/2, Shape::kN/2>, kThreads,
+      layout::PitchLinearShape<kWarpThreadArrangementContiguousB,
+                               kWarpThreadArrangementStridedB>,
+      kAccessSizeInBits / sizeof_bits<ElementB>::value>;
+
+  /// Shared memory iterator to B operand
+  using SmemIteratorB = transform::threadblock::RegularTileAccessIterator<
+      MatrixShape<Shape::kK/2, Shape::kN/2>, ElementB, SmemLayoutB, 1,
+      IteratorThreadMapB>;
+
+  using SmemLayoutQScale = LayoutQMeta;
+  using SmemLayoutQOffset = LayoutQMeta;
+
+  /// Threadblock-level quantization meta data shape
+  using ThreadblockQShape = MatrixShape<Shape::kK / QuantBlocking::kRow, Shape::kN / QuantBlocking::kColumn>;
+  static_assert(Shape::kK % QuantBlocking::kRow == 0, "K must be multiple of QuantBlocking::kRow");
+  static_assert(Shape::kN % QuantBlocking::kColumn == 0, "N must be multiple of QuantBlocking::kColumn");
+  static_assert(ThreadblockQShape::kCount > 0, "QuantBlocking too big to fit in a thread block!");
+  static_assert(QuantBlocking::kRow == 1 || QuantBlocking::kColumn == 1,
+        "Only support single column or row quantize blocking!");
+  static_assert(QuantBlocking::kColumn != 1 || std::is_same<LayoutQMeta, layout::RowMajor>::value,
+        "Quant scale matrix's major dimension must have more elements, to facilitate fast loading!");
+
+  /// Threadblock-level quantization meta data shape in pitch-linear layout
+  using TBQPitchLinearShape = typename std::conditional<
+      std::is_same<LayoutQMeta, layout::RowMajor>::value,
+      layout::PitchLinearShape<ThreadblockQShape::kColumn, ThreadblockQShape::kRow>,
+      layout::PitchLinearShape<ThreadblockQShape::kRow, ThreadblockQShape::kColumn>>::type;
+
+  /// By default we would like to use 128b load. However, we can't load more than
+  /// a column at a time in a column major layout.
+  static int const kElementsPerAccessQScale =
+      (kAccessSizeInBits / sizeof_bits<ElementQScale>::value) > TBQPitchLinearShape::kContiguous
+          ? TBQPitchLinearShape::kContiguous
+          : (kAccessSizeInBits / sizeof_bits<ElementQScale>::value);
+
+  /// The quant scale tile is tiny; not all threads are needed.
+  static int const kAccessCntQScale = ThreadblockQShape::kCount / kElementsPerAccessQScale;
+  static int const kThreadsQScale = (kAccessCntQScale > kThreads) ? kThreads : kAccessCntQScale;
+
+  using IteratorThreadMapQScale = transform::PitchLinearStripminedThreadMap<
+      TBQPitchLinearShape, kThreadsQScale, kElementsPerAccessQScale>;
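+
+  // Continuing the hypothetical example (128x128x64 tile, 32x1 blocking, half-precision
+  // scales): TBQPitchLinearShape = {128, 2}, kElementsPerAccessQScale = min(128 / 16, 128) = 8,
+  // kAccessCntQScale = 256 / 8 = 32, so only 32 of the threadblock's threads (e.g. 128)
+  // participate in loading quant scales.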
+
+  using SmemIteratorQScale = transform::threadblock::RegularTileAccessIterator<
+        ThreadblockQShape, ElementQScale, SmemLayoutQScale, 1, IteratorThreadMapQScale>;
+
+  static int const kElementsPerAccessQOffset =
+      (kAccessSizeInBits / sizeof_bits<ElementQOffset>::value) > TBQPitchLinearShape::kContiguous
+          ? TBQPitchLinearShape::kContiguous
+          : (kAccessSizeInBits / sizeof_bits<ElementQOffset>::value);
+  static int const kAccessCntQOffset = ThreadblockQShape::kCount / kElementsPerAccessQOffset;
+  static int const kThreadsQOffset = (kAccessCntQOffset > kThreads) ? kThreads : kAccessCntQOffset;
+
+  using IteratorThreadMapQOffset = transform::PitchLinearStripminedThreadMap<
+      TBQPitchLinearShape, kThreadsQOffset, kElementsPerAccessQOffset>;
+
+  using SmemIteratorQOffset = transform::threadblock::OptionalRegularTileAccessIterator<
+        ThreadblockQShape, ElementQOffset, SmemLayoutQOffset, 1, IteratorThreadMapQOffset, kThreads>;
+
+  //
+  // Warp-level matrix multiply operator
+  //
+
+  // Define the warp-level tensor op
+  using MmaTensorOp = typename cutlass::gemm::warp::DefaultQuantBMmaTensorOp<
+      WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB,
+      ElementQScale, SmemLayoutQScale, ElementQOffset, SmemLayoutQScale, QuantBlocking,
+      ElementC, LayoutC, Operator, WarpCount::kK>::Type;
+
+  /// Policy used to define MmaPipelined
+  using MmaPolicy = MmaPolicy<MmaTensorOp, MatrixShape<0, 0>,
+                                        MatrixShape<0, 0>, WarpCount::kK>;
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace threadblock
+}  // namespace gemm
+}  // namespace cutlass
diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h
new file mode 100644
index 0000000000000..6f27a692a3a2e
--- /dev/null
+++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h
@@ -0,0 +1,314 @@
+/**
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ *
+ * @file optional_predicated_tile_access_iter.h
+ * @brief Templates for loading and storing optional tiles of matrix data.
+ *   This iterator is just a wrapper of PredicatedTileAccessIterator, with
+ *   the option to turn it off at compile time and minimize its runtime
+ *   footprint. Also, it utilizes the higher numbered threads in the
+ *   threadblock when the iterator cannot utilize all the threads.
+ */
+
+#pragma once
+
+#include <variant>
+
+#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace transform {
+namespace threadblock {
+
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Optional 2-D matrix data loader. When the element type is std::monostate, the
+/// iterator becomes a no-op with minimal runtime footprint. Also, it utilizes the
+/// higher numbered threads in the threadblock when the iterator cannot utilize
+/// all the threads.
+///
+template <
+    /// Tile shape of the iterator
+    typename Shape_,
+    /// Element data type of the iterator, no-op when it is std::monostate
+    typename Element_,
+    /// Layout of the source matrix
+    typename Layout_,
+    int AdvanceRank_,
+    typename ThreadMap_,
+    typename AccessType_,
+    /// Number of threads in the threadblock, when provided, the iterator
+    /// will utilize the higher numbered threads
+    int kThreadBlockSize_ = -1>
+class OptionalPredicatedTileAccessIterator{
+ public:
+
+  using Shape = Shape_;
+  using Element = Element_;
+  using Layout = Layout_;
+  using ThreadMap = ThreadMap_;
+  using AccessType = AccessType_;
+
+  static constexpr int kAdvanceRank = AdvanceRank_;
+  static constexpr int kThreadblockSize = kThreadBlockSize_;
+
+  static_assert(!std::is_same<Element, std::monostate>::value,
+      "Disabled Iterator failed to match the specialized version below.");
+  static_assert(kThreadblockSize == -1 || kThreadblockSize >= ThreadMap::kThreads,
+      "kThreadblockSize must be no smaller than ThreadMap::kThreads");
+
+  using Base = PredicatedTileAccessIterator<Shape, Element, Layout, kAdvanceRank, ThreadMap, AccessType>;
+
+  using LongIndex = typename Base::LongIndex;
+  using Mask = typename Base::Mask;
+  using TensorCoord = typename Base::TensorCoord;
+  using TensorRef = typename Base::TensorRef;
+  using Params = typename Base::Params;
+  using Pointer = typename Base::Pointer;
+
+  static constexpr int kAccessesPerVector = Base::kAccessesPerVector;
+
+  CUTLASS_HOST_DEVICE
+  static int flip_thread_id(int thread_id){
+    if constexpr (kThreadblockSize > 0) {
+      return kThreadblockSize - 1 - thread_id;
+    }
+    return thread_id;
+  }
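+
+  // For example, with a hypothetical kThreadblockSize = 256 and a thread map that
+  // only needs 32 threads, flip_thread_id maps thread 255 -> 0, ..., 224 -> 31, so
+  // this iterator is serviced by the highest-numbered warp while the lower warps
+  // remain free to issue the larger A/B copies.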
+
+ public:
+  Base base_;
+
+  /// Default constructor
+  OptionalPredicatedTileAccessIterator(): base_() {};
+
+  /// Constructs a TileIterator from its precomputed state, threadblock offset,
+  /// and thread ID
+  CUTLASS_HOST_DEVICE
+  OptionalPredicatedTileAccessIterator(
+      /// Precomputed parameters object
+      Params const &params,
+      /// Pointer to start of tensor
+      Pointer pointer,
+      /// Extent of tensor
+      TensorCoord extent,
+      /// ID of each participating thread
+      int thread_id,
+      /// Initial offset of threadblock
+      TensorCoord const &threadblock_offset)
+      : base_(params, pointer, extent, flip_thread_id(thread_id), threadblock_offset) {}
+
+  /// Construct a PredicatedTileAccessIterator with zero threadblock offset
+  CUTLASS_HOST_DEVICE
+  OptionalPredicatedTileAccessIterator(
+      /// Precomputed parameters object
+      Params const &params,
+      /// Pointer to start of tensor
+      Pointer pointer,
+      /// Extent of tensor
+      TensorCoord extent,
+      ///< ID of each participating thread
+      int thread_id)
+      : OptionalPredicatedTileAccessIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) {}
+
+  /// Overrides the internal iteration index
+  CUTLASS_HOST_DEVICE
+  void set_iteration_index(int index) {
+    base_.set_iteration_index(index);
+  }
+
+  /// Adds a pointer offset in units of Element
+  CUTLASS_HOST_DEVICE
+  void add_pointer_offset(LongIndex pointer_offset) {
+    base_.add_pointer_offset(pointer_offset);
+  }
+
+  /// Advances an iterator along logical dimensions of matrix in units of whole tiles
+  CUTLASS_DEVICE
+  void add_tile_offset(
+      TensorCoord const &tile_offset) {
+    base_.add_tile_offset(tile_offset);
+  }
+
+  /// Returns a pointer
+  CUTLASS_HOST_DEVICE
+  AccessType *get() const {
+    return base_.get();
+  }
+
+  /// Increment and return an instance to self.
+  CUTLASS_HOST_DEVICE
+  OptionalPredicatedTileAccessIterator &operator++() {
+    ++base_;
+    return *this;
+  }
+
+  /// Increment and return an instance to self.
+  CUTLASS_HOST_DEVICE
+  OptionalPredicatedTileAccessIterator operator++(int) {
+    OptionalPredicatedTileAccessIterator self(*this);
+    operator++();
+    return self;
+  }
+
+  /// Clears the predicate set efficiently
+  CUTLASS_HOST_DEVICE
+  void clear_mask(bool enable = true) {
+    base_.clear_mask(enable);
+  }
+
+  /// Clears the predicate set efficiently
+  CUTLASS_HOST_DEVICE
+  void enable_mask() {
+    base_.enable_mask();
+  }
+
+  /// Sets the predicate mask, overriding value stored in predicate iterator
+  CUTLASS_HOST_DEVICE
+  void set_mask(Mask const &mask) {
+    base_.set_mask(mask);
+  }
+
+  /// Gets the mask
+  CUTLASS_HOST_DEVICE
+  void get_mask(Mask &mask) {
+    base_.get_mask(mask);
+  }
+
+  /// Returns whether access is valid or not
+  CUTLASS_HOST_DEVICE
+  bool valid() {
+    return base_.valid();
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Specialization for the disabled version
+/// Reduce runtime overhead
+///
+template <
+    /// Tile shape of the iterator
+    typename Shape_,
+    typename Layout_,
+    int AdvanceRank_,
+    typename ThreadMap_,
+    typename AccessType_,
+    int kThreadBlockSize_>
+class OptionalPredicatedTileAccessIterator<Shape_, std::monostate, Layout_, AdvanceRank_, ThreadMap_, AccessType_, kThreadBlockSize_>{
+ public:
+
+  using Shape = Shape_;
+  using Element = std::monostate;
+  using Layout = Layout_;
+  static int const kAdvanceRank = AdvanceRank_;
+  using ThreadMap = ThreadMap_;
+  using AccessType = AccessType_;
+
+  static constexpr int kThreadblockSize = kThreadBlockSize_;
+
+  using Base = PredicatedTileAccessIterator<Shape, Element, Layout, kAdvanceRank, ThreadMap, AccessType>;
+
+  using LongIndex = typename Base::LongIndex;
+  using Mask = typename Base::Mask;
+  using TensorCoord = typename Base::TensorCoord;
+  using TensorRef = typename Base::TensorRef;
+  using Params = typename Base::Params;
+  using Pointer = typename Base::Pointer;
+
+  static constexpr int kAccessesPerVector = Base::kAccessesPerVector;
+
+ public:
+  std::monostate base_;
+
+  /// Default constructor
+  OptionalPredicatedTileAccessIterator(): base_() {};
+
+  /// Constructs a TileIterator from its precomputed state, threadblock offset,
+  /// and thread ID
+  CUTLASS_HOST_DEVICE
+  OptionalPredicatedTileAccessIterator(
+      /// Precomputed parameters object
+      Params const &params,
+      /// Pointer to start of tensor
+      Pointer pointer,
+      /// Extent of tensor
+      TensorCoord extent,
+      /// ID of each participating thread
+      int thread_id,
+      /// Initial offset of threadblock
+      TensorCoord const &threadblock_offset)
+      : base_() {}
+
+  /// Construct a PredicatedTileAccessIterator with zero threadblock offset
+  CUTLASS_HOST_DEVICE
+  OptionalPredicatedTileAccessIterator(
+      /// Precomputed parameters object
+      Params const &params,
+      /// Pointer to start of tensor
+      Pointer pointer,
+      /// Extent of tensor
+      TensorCoord extent,
+      ///< ID of each participating thread
+      int thread_id)
+      : base_() {}
+
+  /// Overrides the internal iteration index
+  CUTLASS_HOST_DEVICE
+  void set_iteration_index(int index) {}
+
+  /// Adds a pointer offset in units of Element
+  CUTLASS_HOST_DEVICE
+  void add_pointer_offset(LongIndex pointer_offset) {}
+
+  /// Advances an iterator along logical dimensions of matrix in units of whole tiles
+  CUTLASS_DEVICE
+  void add_tile_offset(
+      TensorCoord const &tile_offset) {}
+
+  /// Returns a pointer
+  CUTLASS_HOST_DEVICE
+  AccessType *get() const {
+    return nullptr;
+  }
+
+  /// Increment and return an instance to self.
+  CUTLASS_HOST_DEVICE
+  OptionalPredicatedTileAccessIterator &operator++() {
+    return *this;
+  }
+
+  /// Increment and return an instance to self.
+  CUTLASS_HOST_DEVICE
+  OptionalPredicatedTileAccessIterator operator++(int) {
+    return *this;
+  }
+
+  /// Clears the predicate set efficiently
+  CUTLASS_HOST_DEVICE
+  void clear_mask(bool enable = true) {}
+
+  /// Clears the predicate set efficiently
+  CUTLASS_HOST_DEVICE
+  void enable_mask() {}
+
+  /// Sets the predicate mask, overriding value stored in predicate iterator
+  CUTLASS_HOST_DEVICE
+  void set_mask(Mask const &mask) {}
+
+  /// Gets the mask
+  CUTLASS_HOST_DEVICE
+  void get_mask(Mask &mask) {}
+
+  /// Returns whether access is valid or not
+  CUTLASS_HOST_DEVICE
+  bool valid() const { return false; }
+};
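+
+// A minimal selection sketch with hypothetical names: the element type is chosen at
+// compile time, so the no-op specialization above is picked whenever quant offsets
+// are absent:
+//
+//   using ElementQOffset = std::conditional_t<kHasOffset, uint8_t, std::monostate>;
+//   using IteratorQOffset = OptionalPredicatedTileAccessIterator<
+//       QMetaShape, ElementQOffset, layout::RowMajor, 0,
+//       ThreadMapQOffset, AccessTypeQOffset, kThreadblockSize>;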
+
+////////////////////////////////////////////////////////////////////////////////
+}  // namespace threadblock
+}  // namespace transform
+}  // namespace cutlass
diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h
new file mode 100644
index 0000000000000..4b0ae5317f8bb
--- /dev/null
+++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h
@@ -0,0 +1,224 @@
+/**
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ *
+ * @file optional_regular_tile_access_iter.h
+ * @brief Templates implementing the address computation of storing of tiles
+ *   from pitch-linear rank=2 tensors.
+ *
+ *   This iterator is just a wrapper of RegularTileAccessIterator, with the
+ *   option to turn it off at compile time and minimize its runtime footprint.
+ *   Also, it utilizes the higher numbered threads in the threadblock when the
+ *   iterator cannot utilize all the threads.
+ *
+ *   Must be used in conjunction with OptionalPredicatedTileAccessIterator,
+ *   with the same template parameters.
+ */
+
+#pragma once
+
+#include <variant>
+
+#include "cutlass/cutlass.h"
+#include "cutlass/numeric_types.h"
+#include "cutlass/transform/threadblock/regular_tile_access_iterator.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace transform {
+namespace threadblock {
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Optional 2-D tile iterator. When the element type is std::monostate, the iterator
+/// becomes a no-op with minimal runtime footprint. Also, it utilizes the higher
+/// numbered threads in the threadblock when the iterator cannot utilize all
+/// the threads.
+///
+template <
+    /// Tile shape of the iterator
+    typename Shape_,
+    typename Element_,
+    typename Layout_,
+    int AdvanceRank,
+    typename ThreadMap_,
+    /// Number of threads in the threadblock, when not -1, the iterator
+    /// will utilize the higher numbered threads
+    int ThreadblockSize_ = -1,
+    int Alignment =
+        sizeof_bits<Element_>::value * ThreadMap_::kElementsPerAccess / 8>
+class OptionalRegularTileAccessIterator{
+ public:
+
+  using Shape = Shape_;
+  using Element = Element_;
+  using Layout = Layout_;
+  using ThreadMap = ThreadMap_;
+  static constexpr int kAlignment = Alignment;
+  static constexpr int kThreadblockSize = ThreadblockSize_;
+
+  static_assert(!std::is_same<Element, std::monostate>::value,
+      "Disabled Iterator failed to match the specialized template");
+  static_assert(kThreadblockSize == -1 || kThreadblockSize >= ThreadMap::kThreads,
+      "kThreadblockSize must be no smaller than ThreadMap::kThreads");
+
+  using Base = RegularTileAccessIterator<Shape, Element, Layout, AdvanceRank, ThreadMap, Alignment>;
+
+  using LongIndex = typename Base::LongIndex;
+  using TensorRef = typename Base::TensorRef;
+  using TensorCoord = typename Base::TensorCoord;
+  using AccessType = typename Base::AccessType;
+
+  CUTLASS_HOST_DEVICE
+  static int flip_thread_id(int thread_id){
+    if constexpr (kThreadblockSize > 0) {
+      return kThreadblockSize - 1 - thread_id;
+    }
+    return thread_id;
+  }
+
+ private:
+
+  Base base_;
+
+ public:
+  /// Construct a TileIterator with zero threadblock offset
+  CUTLASS_HOST_DEVICE
+  OptionalRegularTileAccessIterator(TensorRef ref,  ///< Pointer to start of tensor
+                            int thread_id   ///< ID of each participating thread
+                            )
+      : base_(ref, flip_thread_id(thread_id)) {}
+
+  /// Overrides the internal iteration index
+  CUTLASS_HOST_DEVICE
+  void set_iteration_index(int index) {
+    base_.set_iteration_index(index);
+  }
+
+  /// Adds a pointer offset in units of Element
+  CUTLASS_HOST_DEVICE
+  void add_pointer_offset(LongIndex pointer_offset) {
+    base_.add_pointer_offset(pointer_offset);
+  }
+
+  /// Returns a pointer
+  CUTLASS_DEVICE
+  AccessType *get() const {
+    return base_.get();
+  }
+
+  /// Advances to the next tile in memory.
+  CUTLASS_HOST_DEVICE
+  OptionalRegularTileAccessIterator &operator++() {
+    ++base_;
+    return *this;
+  }
+
+  /// Advances to the next tile in memory.
+  CUTLASS_HOST_DEVICE
+  OptionalRegularTileAccessIterator operator++(int) {
+    OptionalRegularTileAccessIterator prev(*this);
+    this->operator++();
+
+    return prev;
+  }
+
+  /// Adds a tile offset in the unit of tile.
+  /// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory.
+  /// Below layouts are the shared memory layouts.  Current SM50 SIMT kernels only use col major A and row major B.
+  ///   For row major A operand, k dimension is contiguous dimension;
+  ///   For col major A operand, k dimension is strided dimension;
+  ///   For row major B operand, k dimension is strided dimension;
+  ///   For col major B operand, k dimension is contiguous dimension.
+  /// Below two classes map col/row major to the pitch linear coordinates used
+  /// in this base class.
+  CUTLASS_DEVICE
+  void add_tile_offset(TensorCoord const &coord) {
+    base_.add_tile_offset(coord);
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Specialization when Element is std::monostate, the iterator becomes no-op
+///
+template <
+    typename Shape_,
+    typename Layout_,
+    int AdvanceRank,
+    typename ThreadMap_,
+    int ThreadblockSize_,
+    int Alignment>
+class OptionalRegularTileAccessIterator<Shape_, std::monostate, Layout_,
+    AdvanceRank, ThreadMap_, ThreadblockSize_, Alignment>{
+ public:
+
+  using Shape = Shape_;
+  using Element = std::monostate;
+  using Layout = Layout_;
+  using ThreadMap = ThreadMap_;
+  static constexpr int kAlignment = Alignment;
+  static constexpr int kThreadblockSize = ThreadblockSize_;
+
+  using Base = RegularTileAccessIterator<Shape, Element, Layout, AdvanceRank, ThreadMap, Alignment>;
+
+  using LongIndex = typename Base::LongIndex;
+  using TensorRef = typename Base::TensorRef;
+  using TensorCoord = typename Base::TensorCoord;
+  using AccessType = typename Base::AccessType;
+
+ private:
+
+  std::monostate base_;
+
+ public:
+  /// Construct a TileIterator with zero threadblock offset
+  CUTLASS_HOST_DEVICE
+  OptionalRegularTileAccessIterator(TensorRef ref,  ///< Pointer to start of tensor
+                            int thread_id   ///< ID of each participating thread
+                            )
+      : base_() {}
+
+  /// Overrides the internal iteration index
+  CUTLASS_HOST_DEVICE
+  void set_iteration_index(int index) {}
+
+  /// Adds a pointer offset in units of Element
+  CUTLASS_HOST_DEVICE
+  void add_pointer_offset(LongIndex pointer_offset) {}
+
+  /// Returns a pointer
+  CUTLASS_DEVICE
+  AccessType *get() const {
+    return nullptr;
+  }
+
+  /// Advances to the next tile in memory.
+  CUTLASS_HOST_DEVICE
+  OptionalRegularTileAccessIterator &operator++() {
+    return *this;
+  }
+
+  /// Advances to the next tile in memory.
+  CUTLASS_HOST_DEVICE
+  OptionalRegularTileAccessIterator operator++(int) {
+    return *this;
+  }
+
+  /// Adds a tile offset in the unit of tile.
+  /// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory.
+  /// Below layouts are the shared memory layouts.  Current SM50 SIMT kernels only use col major A and row major B.
+  ///   For row major A operand, k dimension is contiguous dimension;
+  ///   For col major A operand, k dimension is strided dimension;
+  ///   For row major B operand, k dimension is strided dimension;
+  ///   For col major B operand, k dimension is contiguous dimension.
+  /// Below two classes map col/row major to the pitch linear coordinates used
+  /// in this base class.
+  CUTLASS_DEVICE
+  void add_tile_offset(TensorCoord const &coord) {}
+};
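+
+// As noted in the file header, this wrapper is meant to be paired with
+// OptionalPredicatedTileAccessIterator over the same element type, e.g. with
+// hypothetical names:
+//
+//   using GmemIteratorQOffset = OptionalPredicatedTileAccessIterator<
+//       QMetaShape, ElementQOffset, layout::RowMajor, 0, GmemThreadMap, AccessType, kThreads>;
+//   using SmemIteratorQOffset = OptionalRegularTileAccessIterator<
+//       QMetaShape, ElementQOffset, layout::RowMajor, 1, SmemThreadMap, kThreads>;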
+
+}  // namespace threadblock
+}  // namespace transform
+}  // namespace cutlass
diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h
new file mode 100644
index 0000000000000..8b6bac8c5099a
--- /dev/null
+++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h
@@ -0,0 +1,1290 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/**
+ * Modifications Copyright (c) Microsoft.
+ * Licensed under the MIT license.
+ *
+ * @file quantb_mma_multistage.h
+ * @brief Modified from cutlass/gemm/threadblock/mma_multistage.h.
+ * Added the quantized weight memory pipeline, dequantization, and feeding of
+ * the dequantized weights to the tensor cores. The mainloop pipeline is heavily modified.
+ */
+
+#pragma once
+
+#include "cutlass/aligned_buffer.h"
+#include "cutlass/arch/memory.h"
+#include "cutlass/array.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/matrix_shape.h"
+#include "cutlass/numeric_types.h"
+
+#include "cutlass/gemm/threadblock/mma_base.h"
+
+#include "cutlass/util/debug.h"
+#include "cutlass/util/device_dump.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+namespace{
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+/// Utilities for printing layout for the prepacked weights and quantization parameters
+///
+template<
+    /// Data type of the prepacked weights
+    typename ElementWeight,
+    /// Data type of the quant scales
+    typename ElementQScale,
+    /// Data type of the quant offsets
+    typename ElementQOffset>
+struct QuantBLayoutDebug{
+  static constexpr bool debug_smem = true;
+  static constexpr bool debug_fragment = true;
+  ElementWeight* smem_b_ptr_;
+  ElementQScale* smem_qscale_ptr_;
+  ElementQOffset* smem_qoffset_ptr_;
+  int warp_id_;
+  int lane_id_;
+  int block_id_;
+
+  template<typename Element, int Size>
+  CUTLASS_DEVICE
+  static void print_fragment(cutlass::Array<Element, Size> const& frag, char label, int block_id, int warp_id, int lane_id){
+    static_assert(Size % 4 == 0, "Size must be multiple of 4");
+    if constexpr (debug_fragment){
+      if (block_id == 1 && warp_id == 0){
+        const Element* ptr = reinterpret_cast<const Element*>(&frag);
+        for (int i = 0; i < Size/4; i++, ptr+=4){
+          if constexpr(std::is_integral<Element>::value){
+            printf("T%.2d%c%d, %3d, %3d, %3d, %3d\n",
+                   threadIdx.x, label, i,
+                   ptr[0], ptr[1], ptr[2], ptr[3]);
+          } else {
+            printf("T%.2d%c%d, %.3f, %.3f, %.3f, %.3f\n",
+                   threadIdx.x, label, i,
+                   float(ptr[0]), float(ptr[1]), float(ptr[2]), float(ptr[3]));
+          }
+        }
+      }
+    }
+  }
+
+  template<typename Element, int Size>
+  CUTLASS_DEVICE
+  static void print_as_int4(cutlass::Array<Element, Size> const& frag, char label, int block_id, int warp_id, int lane_id){
+    constexpr int I8Size = Size * cutlass::sizeof_bits<Element>::value / 8;
+    static_assert(I8Size % 2 == 0, "Fragment must pack into an even number of bytes");
+    if constexpr (debug_fragment){
+      if (block_id == 1 && warp_id == 0){
+        const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&frag);
+        for (int i = 0; i < I8Size/2; i++, ptr+=2){
+          printf("T%.2dW%d, %d, %d, %d, %d\n", threadIdx.x, i, ptr[0] & 0x0f, ptr[0] >> 4, ptr[1] & 0x0f, ptr[1] >> 4);
+        }
+      }
+    }
+  }
+
+};
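+
+// For reference, the 4-bit unpacking printed above takes two weights per byte,
+// low nibble first; e.g. a byte with value 0xB4 prints as the pair (4, 11):
+//
+//   int lo = byte & 0x0f;  // first 4-bit weight
+//   int hi = byte >> 4;    // second 4-bit weight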
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Dummy type used when quant offsets are not present, to avoid compilation errors
+/// and reduce the runtime footprint
+///
+struct DummyType{
+  std::monostate dummy_;
+ public:
+  DummyType() = default;
+
+  CUTLASS_HOST_DEVICE
+  void* data() const {
+    return nullptr;
+  }
+
+  CUTLASS_HOST_DEVICE
+  std::monostate& operator[](int idx) {
+    return dummy_;
+  }
+};
+
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace threadblock {
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Base structure for the threadblock-scoped matrix product with a blockwise
+/// quantized B operand, holding the shared memory buffers and warp-level iterators.
+template <
+    /// Size of the Gemm problem - concept: gemm::GemmShape<>
+    typename Shape_,
+    /// Policy describing tuning details (concept: MmaPolicy)
+    typename Policy_,
+    /// Number of stages,
+    int Stages,
+    /// Used for partial specialization
+    typename Enable = bool>
+class QuantBMmaBase {
+ public:
+  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
+  using Shape = Shape_;
+
+  ///< Policy describing tuning details
+  using Policy = Policy_;
+
+  //
+  // Dependent types
+  //
+
+  /// Warp-level Mma
+  using Operator = typename Policy::Operator;
+
+  /// Shape describing the overall GEMM computed from shared memory
+  /// by each warp.
+  using WarpGemm = typename Policy::Operator::Shape;
+
+  /// Shape describing the number of warps filling the CTA
+  using WarpCount = GemmShape<Shape::kM / WarpGemm::kM,
+                              Shape::kN / WarpGemm::kN,
+                              Shape::kK / WarpGemm::kK>;
+
+  /// Number of warp-level GEMM operations
+  static int const kWarpGemmIterations =
+      (WarpGemm::kK / Operator::Policy::MmaShape::kK);
+
+  /// Number of stages
+  static int const kStages = Stages;
+
+  static constexpr bool kHasQOffset = !std::is_same<typename Operator::ElementQOffset, std::monostate>::value;
+
+  /// Tensor reference to the A operand
+  using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
+
+  /// Tensor reference to the prepacked weights
+  using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
+
+  static_assert(kWarpGemmIterations > 1,
+                "The pipelined structure requires at least two warp-level "
+                "GEMM operations.");
+
+  static_assert((kWarpGemmIterations % 2) == 0,
+                "Inner loop iteration must be an even number.");
+
+  // Tensor reference to the quantization scales
+  using TensorRefQScale = TensorRef<typename Operator::ElementQScale, typename Operator::SmemLayoutQScale>;
+  using TensorRefQOffset = TensorRef<typename Operator::ElementQOffset, typename Operator::SmemLayoutQOffset>;
+
+  // Block size of the quantization (one set of quantization parameters per block of weights)
+  using QuantBlocking = typename Operator::QuantBlocking;
+
+  //
+  // Nested structs
+  //
+
+  /// Shared storage object needed by threadblock-scoped GEMM
+  class SharedStorage {
+   public:
+    //
+    // Type definitions
+    //
+
+    /// Shape of the A matrix operand in shared memory
+    using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
+                               Shape::kK * kStages +
+                                   Policy::SmemPaddingA::kColumn>;
+
+    /// Shape of the prepacked weights in shared memory
+    using ShapeB =
+        MatrixShape<Shape::kK / 2 * kStages + Policy::SmemPaddingB::kRow,
+                    Shape::kN / 2 + Policy::SmemPaddingB::kColumn>;
+
+    /// Shape of the quantization parameter matrix in shared memory
+    /// Validation done in mma core class ThreadblockQShape
+    using ShapeQScale =
+        MatrixShape<(Shape::kK / QuantBlocking::kRow) * kStages,
+                    Shape::kN / QuantBlocking::kColumn>;
+
+    using BufTypeQOffset = std::conditional_t<kHasQOffset,
+          AlignedBuffer<typename Operator::ElementQOffset, ShapeQScale::kCount>,
+          DummyType>;
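+
+    // Worked example with hypothetical sizes: for Shape = GemmShape<128, 128, 64>,
+    // kStages = 3 and QuantBlocking = MatrixShape<32, 1> (padding ignored):
+    //   ShapeA      = 128 x 192  (A operand, 3 stages along k)
+    //   ShapeB      =  96 x  64  (4-bit weights packed two per byte, both dims halved)
+    //   ShapeQScale =   6 x 128  (one scale per 32x1 weight block, 3 stages)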
+   public:
+    //
+    // Data members
+    //
+
+    /// Buffer for A operand
+    AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
+
+    /// Buffer for prepacked weights
+    AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
+
+    /// Buffer for quantization scales
+    AlignedBuffer<typename Operator::ElementQScale, ShapeQScale::kCount> operand_QScale;
+
+    /// Buffer for quantization offsets
+    BufTypeQOffset operand_QOffset;
+
+   public:
+
+    //
+    // Methods
+    //
+
+    /// Returns a layout object for the A matrix
+    CUTLASS_DEVICE
+    static typename Operator::LayoutA LayoutA() {
+      return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
+    }
+
+    /// Returns a layout object for the B matrix
+    CUTLASS_HOST_DEVICE
+    static typename Operator::LayoutB LayoutB() {
+      return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
+    }
+
+    CUTLASS_HOST_DEVICE
+    static typename Operator::SmemLayoutQScale LayoutQMeta() {
+      return Operator::SmemLayoutQScale::packed({ShapeQScale::kRow, ShapeQScale::kColumn});
+    }
+
+    CUTLASS_HOST_DEVICE
+    static typename Operator::SmemLayoutQOffset LayoutQOffset() {
+      return Operator::SmemLayoutQOffset::packed({ShapeQScale::kRow, ShapeQScale::kColumn});
+    }
+
+    /// Returns a TensorRef to the A operand
+    CUTLASS_HOST_DEVICE
+    TensorRefA operand_A_ref() {
+      return TensorRefA{operand_A.data(), LayoutA()};
+    }
+
+    /// Returns a TensorRef to the prepacked weights
+    CUTLASS_HOST_DEVICE
+    TensorRefB operand_B_ref() {
+      return TensorRefB{operand_B.data(), LayoutB()};
+    }
+
+    /// Returns a TensorRef to the quantization scales
+    CUTLASS_HOST_DEVICE
+    TensorRefQScale operand_QScale_ref() {
+      return TensorRefQScale{operand_QScale.data(), LayoutQMeta()};
+    }
+
+    CUTLASS_HOST_DEVICE
+    TensorRefQOffset operand_QOffset_ref() {
+      if constexpr (!kHasQOffset){
+        return TensorRefQOffset();
+      } else {
+        return TensorRefQOffset{operand_QOffset.data(), LayoutQOffset()};
+      }
+    }
+  };
+
+ protected:
+
+  //
+  // Data members
+  //
+
+  /// Iterator to load a warp-scoped tile of A operand from shared memory
+  typename Operator::IteratorA warp_tile_iterator_A_;
+
+  /// Iterator to load a warp-scoped tile of B operand from shared memory
+  typename Operator::IteratorB warp_tile_iterator_B_;
+
+  /// Iterator to load a warp-scoped tile of quant scales from shared memory
+  typename Operator::IteratorQMeta warp_tile_iterator_QScale_;
+
+public:
+
+  /// Construct from tensor references
+  CUTLASS_DEVICE
+  QuantBMmaBase(
+      ///< Shared storage needed for internal use by threadblock-scoped GEMM
+      SharedStorage &shared_storage,
+      ///< ID within the threadblock
+      int thread_idx,
+      ///< ID of warp
+      int warp_idx,
+      ///< ID of each thread within a warp
+      int lane_idx
+    ):
+      warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
+      warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx),
+      warp_tile_iterator_QScale_(shared_storage.operand_QScale_ref(),
+             shared_storage.operand_QOffset_ref(), lane_idx)
+  {}
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Structure to compute the threadblock-scoped matrix product on tensor cores,
+/// with the B operand stored as blockwise quantized weights that are dequantized
+/// in the mainloop.
+template <
+    /// Size of the Gemm problem - concept: gemm::GemmShape<>
+    typename Shape_,
+    /// Iterates over tiles of A operand in global memory
+    //  (concept: ReadableTileIterator | ForwardTileIterator |
+    //  MaskedTileIterator)
+    typename IteratorA_,
+    /// Iterates over tiles of A operand in shared memory
+    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
+    typename SmemIteratorA_,
+    /// Cache operation for operand A
+    cutlass::arch::CacheOperation::Kind CacheOpA,
+    /// Iterates over tiles of B operand in global memory
+    //  (concept: ReadableTileIterator | ForwardTileIterator |
+    //  MaskedTileIterator)
+    typename IteratorB_,
+    /// Iterates over tiles of B operand in shared memory
+    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
+    typename SmemIteratorB_,
+    /// Cache operation for operand B
+    cutlass::arch::CacheOperation::Kind CacheOpB,
+    /// Iterates over tiles of quant scales in global memory
+    typename IteratorQScale_,
+    /// Iterates over tiles of quant scales in shared memory
+    typename SmemIteratorQScale_,
+    /// Cache operation for quant scales
+    cutlass::arch::CacheOperation::Kind CacheOpQScale,
+    /// Iterates over tiles of quant offsets in global memory
+    typename IteratorQOffset_,
+    /// Iterates over tiles of quant offsets in shared memory
+    typename SmemIteratorQOffset_,
+    /// Cache operation for quant offsets
+    cutlass::arch::CacheOperation::Kind CacheOpQOffset,
+    /// Data type of accumulator matrix
+    typename ElementC_,
+    /// Layout of accumulator matrix
+    typename LayoutC_,
+    /// Policy describing tuning details (concept: MmaPolicy)
+    typename Policy_,
+    /// Number of stages,
+    int Stages,
+    /// Used for partial specialization
+    typename Enable = bool>
+class QuantBMmaMultistage :
+  public QuantBMmaBase<Shape_, Policy_, Stages> {
+public:
+  ///< Base class
+  using Base = QuantBMmaBase<Shape_, Policy_, Stages>;
+  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
+  using Shape = Shape_;
+  ///< Iterates over tiles of A operand in global memory
+  using IteratorA = IteratorA_;
+  ///< Iterates over tiles of B operand in global memory
+  using IteratorB = IteratorB_;
+  ///< Data type of accumulator matrix
+  using ElementC = ElementC_;
+  ///< Layout of accumulator matrix
+  using LayoutC = LayoutC_;
+  ///< Policy describing tuning details
+  using Policy = Policy_;
+
+  using SmemIteratorA = SmemIteratorA_;
+  using SmemIteratorB = SmemIteratorB_;
+
+  static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
+  static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
+
+  using IteratorQScale = IteratorQScale_;
+  using IteratorQOffset = IteratorQOffset_;
+  using SmemIteratorQScale = SmemIteratorQScale_;
+  using SmemIteratorQOffset = SmemIteratorQOffset_;
+  using QuantBlocking = typename Base::QuantBlocking;
+
+  static cutlass::arch::CacheOperation::Kind const kCacheOpQScale = CacheOpQScale;
+  static cutlass::arch::CacheOperation::Kind const kCacheOpQOffset = CacheOpQOffset;
+  static constexpr bool kHasQOffset = Base::kHasQOffset;
+
+  //
+  // Dependent types
+  //
+
+  /// Fragment of accumulator tile
+  using FragmentC = typename Policy::Operator::FragmentC;
+
+  /// Warp-level Mma
+  using Operator = typename Policy::Operator;
+
+  /// Minimum architecture is Sm80 to support cp.async
+  using ArchTag = arch::Sm80;
+
+  /// Complex transform on A operand
+  static ComplexTransform const kTransformA = Operator::kTransformA;
+
+  /// Complex transform on B operand
+  static ComplexTransform const kTransformB = Operator::kTransformB;
+
+  /// Internal structure exposed for introspection.
+  struct Detail {
+
+    /// Number of cp.async instructions to load one stage of operand A
+    static int const AsyncCopyIterationsPerStageA =
+        IteratorA::ThreadMap::Iterations::kCount;
+
+    /// Number of cp.async instructions to load one stage of packed weights
+    static int const AsyncCopyIterationsPerStageB =
+        IteratorB::ThreadMap::Iterations::kCount;
+
+    /// Number of stages
+    static int const kStages = Stages;
+
+    /// Number of cp.async instructions to load one group of operand A
+    static int const kAccessesPerGroupA =
+        (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
+
+    /// Number of cp.async instructions to load one group of operand B
+    static int const kAccessesPerGroupB =
+        (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
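+
+    // For example (illustrative only): with AsyncCopyIterationsPerStageA == 8 and
+    // Base::kWarpGemmIterations == 4, each warp-tile "group" issues 2 of the 8
+    // cp.async copies for operand A, spreading a stage's loads across the mainloop.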
+
+    /// Number of cp.async instructions to load one stage of quant scales
+    static int const AsyncCopyIterationsPerStageQScale =
+        IteratorQScale::ThreadMap::Iterations::kCount;
+
+    /// Number of cp.async instructions to load one group of quant scales
+    static int const kAccessesPerGroupQScale =
+        (AsyncCopyIterationsPerStageQScale + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
+
+    /// Number of cp.async instructions to load one stage of quant offsets
+    static int const AsyncCopyIterationsPerStageQOffset =
+        IteratorQOffset::ThreadMap::Iterations::kCount;
+
+    /// Number of cp.async instructions to load one group of quant offsets
+    static int const kAccessesPerGroupQOffset =
+        (AsyncCopyIterationsPerStageQOffset + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
+
+    // Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical
+    // accuracy, where each mainloop iteration first accumulates into a temporary
+    // set of freshly-cleared accumulators, which are subsequently added to the
+    // final accumulator set.
+    static bool const kStagedAccumulation = arch::UseStagedAccumulation<typename Operator::MathOperator>::value;
+  };
+
+ private:
+
+
+  // Structure encapsulating pipeline state live from one iteration to the next
+  struct PipeState {
+
+    using WarpLoadedFragmentA = typename Operator::FragmentA;
+    using WarpLoadedFragmentB = typename Operator::FragmentB;
+    using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
+    using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
+
+    /// Temporary accumulator to facilitate staged-accumulation
+    FragmentC tmp_accum_;
+
+    /// Pair of A fragments used to overlap shared memory loads and math instructions
+    WarpLoadedFragmentA warp_loaded_frag_A_[2];
+
+    /// B fragments used to overlap shared memory loads and math instructions
+    WarpLoadedFragmentB warp_loaded_frag_B_;
+    WarpTransformedFragmentB warp_transformed_frag_B_[2];
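+    // Illustrative note: the two-element arrays above double-buffer fragments across
+    // consecutive warp_mma_k iterations, so shared memory loads for iteration k+1
+    // overlap the mma math of iteration k.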
+
+    using WarpLoadedFragmentQScale = typename Operator::FragmentQScale;
+    WarpLoadedFragmentQScale warp_loaded_frag_QScale_;
+
+    using WarpLoadedFragmentQOffset = typename std::conditional<kHasQOffset,
+            typename Operator::FragmentQOffset,
+            std::monostate>::type;
+    WarpLoadedFragmentQOffset warp_loaded_frag_QOffset_;
+  };
+
+
+ private:
+
+  //
+  // Data members
+  //
+
+  /// Warp-level MMA operator
+  Operator warp_mma_;
+
+  /// Iterator to write threadblock-scoped tile of A operand to shared memory
+  SmemIteratorA smem_iterator_A_;
+
+  /// Iterator to write threadblock-scoped tile of B operand to shared memory
+  SmemIteratorB smem_iterator_B_;
+
+  /// Iterator to write threadblock-scoped tile of quant meta data to shared memory
+  SmemIteratorQScale smem_iterator_QScale_;
+  SmemIteratorQOffset smem_iterator_QOffset_;
+
+  /// Shared memory write stage index
+  int smem_write_stage_idx_;
+
+  /// Shared memory read stage index
+  int smem_read_stage_idx_;
+
+  /// The quant meta data tensors are very small, so only a subset of the threads
+  /// is needed to load them.
+  bool const should_load_qscale_;
+  bool const should_load_qoffset_;
+
+  /// Shared memory pointers for debug dumping
+  static constexpr bool debug_layout = false;
+  using LayoutDebugType = typename std::conditional<debug_layout,
+      QuantBLayoutDebug<typename IteratorB::Element, typename IteratorQScale::Element, typename IteratorQOffset::Element>,
+      std::monostate>::type;
+  LayoutDebugType layout_debug_;
+
+public:
+
+  /// Construct from tensor references
+  CUTLASS_DEVICE
+  QuantBMmaMultistage(
+      ///< Shared storage needed for internal use by threadblock-scoped GEMM
+      typename Base::SharedStorage &shared_storage,
+      ///< ID within the threadblock
+      int thread_idx,
+      ///< ID of warp
+      int warp_idx,
+      ///< ID of each thread within a warp
+      int lane_idx
+    ):
+      Base(shared_storage, thread_idx, warp_idx, lane_idx),
+      smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
+      smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
+      smem_iterator_QScale_(shared_storage.operand_QScale_ref(), thread_idx),
+      smem_iterator_QOffset_(shared_storage.operand_QOffset_ref(), thread_idx),
+      should_load_qscale_(thread_idx < IteratorQScale::ThreadMap::kThreads),
+      should_load_qoffset_(thread_idx >= IteratorQOffset::kThreadblockSize - IteratorQOffset::ThreadMap::kThreads),
+      smem_write_stage_idx_(0),
+      smem_read_stage_idx_(0)
+  {
+    if constexpr(debug_layout){
+      layout_debug_.smem_b_ptr_ = shared_storage.operand_B_ref().data();
+      layout_debug_.smem_qscale_ptr_ = shared_storage.operand_QScale_ref().data();
+      if constexpr(kHasQOffset){
+        layout_debug_.smem_qoffset_ptr_ = shared_storage.operand_QOffset_ref().data();
+      } else {
+        layout_debug_.smem_qoffset_ptr_ = nullptr;
+      }
+      layout_debug_.warp_id_ = warp_idx;
+      layout_debug_.lane_id_ = lane_idx;
+      layout_debug_.block_id_ = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z;
+    }
+
+    // Compute warp location within threadblock tile by mapping the warp_id to
+    // three coordinates:
+    //   _m: the warp's position within the threadblock along the M dimension
+    //   _n: the warp's position within the threadblock along the N dimension
+    //   _k: the warp's position within the threadblock along the K dimension
+    int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
+    int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
+
+    int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
+    int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
+
+    // Add per-warp offsets in units of warp-level tiles
+    this->warp_tile_iterator_A_.add_tile_offset(
+        {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
+    this->warp_tile_iterator_B_.add_tile_offset(
+        {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
+    this->warp_tile_iterator_QScale_.add_tile_offset(
+        {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
+  }
+
+  /// Advance shared memory read-iterators to the next stage
+  CUTLASS_DEVICE
+  void advance_smem_read_stage()
+  {
+    ++smem_read_stage_idx_;
+
+    if (smem_read_stage_idx_ == Base::kStages) {
+      // Wrap back around to the 'start' of the circular buffer in shared memory
+      this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
+      this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
+      this->warp_tile_iterator_QScale_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
+
+      smem_read_stage_idx_ = 0;
+    }
+  }
+
+  /// Advance global memory read-iterators and shared memory write-iterators to the next stage
+  CUTLASS_DEVICE
+  void advance_smem_write_stage(
+    IteratorA &iterator_A,
+    IteratorB &iterator_B,
+    IteratorQScale &iterator_QScale,
+    IteratorQOffset &iterator_QOffset)
+  {
+    // Advance global iterators
+    iterator_A.add_tile_offset({0, 1});
+    iterator_B.add_tile_offset({1, 0});
+    iterator_QScale.add_tile_offset({1, 0});
+
+    // Advance shared iterators
+    smem_iterator_A_.add_tile_offset({0, 1});
+    smem_iterator_B_.add_tile_offset({1, 0});
+    smem_iterator_QScale_.add_tile_offset({1, 0});
+
+    if constexpr (kHasQOffset) {
+      iterator_QOffset.add_tile_offset({1, 0});
+      smem_iterator_QOffset_.add_tile_offset({1, 0});
+    }
+
+    // Increment shared memory write stage index
+    ++smem_write_stage_idx_;
+
+    if (smem_write_stage_idx_ == Base::kStages) {
+      // Wrap back around to the 'start' of the circular buffer in shared memory
+      smem_iterator_A_.add_tile_offset({0, -Base::kStages});
+      smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
+      smem_iterator_QScale_.add_tile_offset({-Base::kStages, 0});
+      if constexpr (kHasQOffset) {
+        smem_iterator_QOffset_.add_tile_offset({-Base::kStages, 0});
+      }
+      smem_write_stage_idx_ = 0;
+    }
+  }
+
+  CUTLASS_DEVICE
+  void copy_qscale_tiles(IteratorQScale &iterator_QScale){
+    // The quant scale matrix is 1/block_size the size of the B matrix: for a 64x64
+    // warp tile it is only 64x64/block_size elements. For blocking sizes of 16 ~ 64,
+    // loading it takes only 4 ~ 16 cp.async instructions. Since a warp has 32 threads,
+    // that is less than one cp.async instruction per thread, and the quant offset
+    // matrix is even smaller.
+    static_assert(Detail::AsyncCopyIterationsPerStageQScale == 1,
+                  "Quant scale should be loaded in one shot!");
+    static_assert(IteratorQScale::kAccessesPerVector == 1,
+                  "Quant scale should have 1 access per vector!");
+
+    // Async Copy for quantization scale
+    typename IteratorQScale::AccessType *dst_ptr =
+        reinterpret_cast<typename IteratorQScale::AccessType *>(
+            this->smem_iterator_QScale_.get());
+
+    constexpr int kSrcBytes =
+        sizeof_bits<typename IteratorQScale::Element>::value *
+            IteratorQScale::ThreadMap::kElementsPerAccess / 8;
+
+    cutlass::arch::cp_async<kSrcBytes, kCacheOpQScale>(
+        dst_ptr, iterator_QScale.get(), iterator_QScale.valid());
+  }
+
+  CUTLASS_DEVICE
+  void copy_qoffset_tiles(IteratorQOffset & iterator_QOffset) {
+    static_assert(Detail::AsyncCopyIterationsPerStageQOffset == 1,
+                  "Quant offset should be loaded in one shot!");
+    static_assert(IteratorQOffset::kAccessesPerVector == 1,
+                  "Quant offset should have 1 access per vector!");
+
+    if constexpr(kHasQOffset) {
+      // Async Copy for quantization offset
+      typename IteratorQOffset::AccessType *dst_ptr =
+          reinterpret_cast<typename IteratorQOffset::AccessType *>(
+              this->smem_iterator_QOffset_.get());
+
+      constexpr int kSrcBytes = sizeof_bits<typename IteratorQOffset::Element>::value *
+                                IteratorQOffset::ThreadMap::kElementsPerAccess / 8;
+
+      cutlass::arch::cp_async<kSrcBytes, kCacheOpQOffset>(
+            dst_ptr, iterator_QOffset.get(), iterator_QOffset.valid());
+    }
+  }
+
+  CUTLASS_DEVICE
+  void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B,
+                              int group_start = 0) {
+    auto group_start_A = group_start * Detail::kAccessesPerGroupA;
+    iterator_A.set_iteration_index(group_start_A *
+                                   IteratorA::kAccessesPerVector);
+    this->smem_iterator_A_.set_iteration_index(group_start_A);
+
+    // Async Copy for operand A
+    CUTLASS_PRAGMA_UNROLL
+    for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
+      if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
+        typename IteratorA::AccessType *dst_ptr =
+            reinterpret_cast<typename IteratorA::AccessType *>(
+                this->smem_iterator_A_.get());
+
+        int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
+                              IteratorA::ThreadMap::kElementsPerAccess /
+                              IteratorA::kAccessesPerVector / 8;
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
+          auto gmem_ptr = iterator_A.get();
+
+          cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(
+              dst_ptr + v, gmem_ptr, iterator_A.valid());
+
+          ++iterator_A;
+        }
+
+        ++this->smem_iterator_A_;
+      }
+    }
+
+    auto group_start_B = group_start * Detail::kAccessesPerGroupB;
+    iterator_B.set_iteration_index(group_start_B *
+                                   IteratorB::kAccessesPerVector);
+    this->smem_iterator_B_.set_iteration_index(group_start_B);
+
+    // Async Copy for operand B
+    CUTLASS_PRAGMA_UNROLL
+    for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
+      if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
+        typename IteratorB::AccessType *dst_ptr =
+            reinterpret_cast<typename IteratorB::AccessType *>(
+                this->smem_iterator_B_.get());
+
+        int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
+                              IteratorB::ThreadMap::kElementsPerAccess /
+                              IteratorB::kAccessesPerVector / 8;
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
+          auto gmem_ptr = iterator_B.get();
+
+          cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
+              dst_ptr + v, gmem_ptr, iterator_B.valid());
+
+          ++iterator_B;
+        }
+        ++this->smem_iterator_B_;
+      }
+    }
+  }
+
+  /// GEMM prologue.  Bootstrap the global->shared memory pipeline by fetching
+  /// the global fragments needed by the first kStages-1 threadblock mainloop iterations
+  CUTLASS_DEVICE
+  void prologue(
+    IteratorA &iterator_A,      ///< [in|out] iterator over A operand in global memory
+    IteratorB &iterator_B,      ///< [in|out] iterator over B operand in global memory
+    IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory
+    IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory
+    int &gemm_k_iterations)     ///< [in|out] number of threadblock mainloop iterations remaining
+  {
+    // Issue several complete stages
+    CUTLASS_PRAGMA_UNROLL
+    for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) {
+
+      // Disable global fetching if done with global fetch iterations
+      iterator_A.clear_mask(gemm_k_iterations == 0);
+      iterator_B.clear_mask(gemm_k_iterations == 0);
+      iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_);
+
+      iterator_A.set_iteration_index(0);
+      this->smem_iterator_A_.set_iteration_index(0);
+
+      // Async Copy for operand A
+      CUTLASS_PRAGMA_UNROLL
+      for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
+        typename IteratorA::AccessType *dst_ptr =
+            reinterpret_cast<typename IteratorA::AccessType *>(
+                this->smem_iterator_A_.get());
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
+          int const kSrcBytes =
+              sizeof_bits<typename IteratorA::Element>::value *
+              IteratorA::ThreadMap::kElementsPerAccess /
+              IteratorA::kAccessesPerVector / 8;
+
+          cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
+              dst_ptr + v, iterator_A.get(), iterator_A.valid());
+
+          ++iterator_A;
+        }
+
+        ++this->smem_iterator_A_;
+      }
+
+      iterator_B.set_iteration_index(0);
+      this->smem_iterator_B_.set_iteration_index(0);
+
+      // Async Copy for operand B
+      CUTLASS_PRAGMA_UNROLL
+      for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
+        typename IteratorB::AccessType *dst_ptr =
+            reinterpret_cast<typename IteratorB::AccessType *>(
+                this->smem_iterator_B_.get());
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
+          int const kSrcBytes =
+              sizeof_bits<typename IteratorB::Element>::value *
+              IteratorB::ThreadMap::kElementsPerAccess /
+              IteratorB::kAccessesPerVector / 8;
+
+          cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
+              dst_ptr + v, iterator_B.get(), iterator_B.valid());
+
+          ++iterator_B;
+        }
+
+        ++this->smem_iterator_B_;
+      }
+
+      // Async Copy for quantization scale
+      static_assert(Detail::AsyncCopyIterationsPerStageQScale == 1, "Quant scale should be loaded in one shot!");
+      static_assert(IteratorQScale::kAccessesPerVector == 1, "Quant scale should have 1 access per vector!");
+
+      typename IteratorQScale::AccessType *dst_ptr =
+          reinterpret_cast<typename IteratorQScale::AccessType *>(
+              this->smem_iterator_QScale_.get());
+
+      constexpr int kSrcBytes =
+          sizeof_bits<typename IteratorQScale::Element>::value *
+          IteratorQScale::ThreadMap::kElementsPerAccess / 8;
+
+      auto gmem_ptr = iterator_QScale.get();
+
+      cutlass::arch::cp_async<kSrcBytes, kCacheOpQScale>(
+          dst_ptr, gmem_ptr, iterator_QScale.valid());
+
+      if constexpr (kHasQOffset) {
+        iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_);
+
+        // Async Copy for quantization offset
+        static_assert(Detail::AsyncCopyIterationsPerStageQOffset == 1, "Quant offset should be loaded in one shot!");
+        static_assert(IteratorQOffset::kAccessesPerVector == 1, "Quant offset should have 1 access per vector!");
+        typename IteratorQOffset::AccessType *dst_ptr =
+            reinterpret_cast<typename IteratorQOffset::AccessType *>(
+                this->smem_iterator_QOffset_.get());
+
+        constexpr int kSrcBytes =
+            sizeof_bits<typename IteratorQOffset::Element>::value *
+                IteratorQOffset::ThreadMap::kElementsPerAccess / 8;
+
+        cutlass::arch::cp_async<kSrcBytes, kCacheOpQOffset>(
+            dst_ptr, iterator_QOffset.get(), iterator_QOffset.valid());
+      }
+
+      // Move to the next write stage
+      advance_smem_write_stage(iterator_A, iterator_B, iterator_QScale, iterator_QOffset);
+
+      // Defines the boundary of a stage of cp.async.
+      cutlass::arch::cp_async_fence();
+    }
+  }
+
+
+  /// Wait until we have at least one completed global fetch stage
+  CUTLASS_DEVICE
+  void gmem_wait()
+  {
+    // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed)
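+    // For example (illustrative): with kStages == 3 this is cp_async_wait<1>, i.e.
+    // at most one cp.async commit group may still be in flight, so at least one
+    // prefetched stage is guaranteed to have landed in shared memory.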
+    cutlass::arch::cp_async_wait<Base::kStages - 2>();
+    __syncthreads();
+
+    if constexpr(debug_layout) {
+      if (LayoutDebugType::debug_smem && layout_debug_.block_id_ == 1) {
+        if (threadIdx.x == 0){
+          printf("stage: %d\n", smem_write_stage_idx_);
+        }
+        cutlass::debug::dump_shmem(layout_debug_.smem_qscale_ptr_, Base::SharedStorage::ShapeQScale::kCount);
+        if constexpr(kHasQOffset){
+          cutlass::debug::dump_shmem(layout_debug_.smem_qoffset_ptr_, Base::SharedStorage::ShapeQScale::kCount);
+        }
+      }
+    }
+  }
+
+  /// Perform a threadblock mainloop iteration of matrix multiply-accumulate
+  CUTLASS_DEVICE
+  void mac_loop_iter(
+    PipeState &pipe_state,          ///< [in|out] loop-carried pipeline state
+    FragmentC &accum,               ///< [in|out] destination accumulator tile
+    IteratorA &iterator_A,          ///< [in|out] iterator over A operand in global memory
+    IteratorB &iterator_B,          ///< [in|out] iterator over B operand in global memory
+    IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory
+    IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory
+    int &gemm_k_iterations)         ///< [in|out] number of threadblock mainloop iterations remaining
+  {
+    // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration
+    CUTLASS_PRAGMA_UNROLL
+    for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
+      // Load the next warp-level tiles from shared memory. This could be skipped on the
+      // very last iteration, where:
+      //   (gemm_k_iterations == (1 - Base::kStages)) && (warp_mma_k == (Base::kWarpGemmIterations - 1))
+      // However, evaluating that condition appears more expensive than simply loading the tiles.
+      this->warp_tile_iterator_QScale_.load(
+          pipe_state.warp_loaded_frag_QScale_,
+          pipe_state.warp_loaded_frag_QOffset_);
+      ++this->warp_tile_iterator_QScale_;
+
+      this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
+      this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_);
+      ++this->warp_tile_iterator_B_;
+
+      this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
+      this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]);
+      ++this->warp_tile_iterator_A_;
+
+      // All warp-tiles issue their share of global->shared fragment copies
+      copy_tiles_and_advance(
+          iterator_A,
+          iterator_B,
+          (warp_mma_k + 1) % Base::kWarpGemmIterations);
+
+      if constexpr(debug_layout) {
+        if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){
+          printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, warp_mma_k % Base::kWarpGemmIterations);
+        }
+        LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_);
+        LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_);
+        if constexpr(kHasQOffset){
+          LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_);
+        }
+      }
+
+      warp_mma_.transform(
+        pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2],
+        pipe_state.warp_loaded_frag_B_,
+        pipe_state.warp_loaded_frag_QScale_,
+        pipe_state.warp_loaded_frag_QOffset_);
+
+      if constexpr(debug_layout) {
+        LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_);
+      }
+
+      // Execute the current warp-tile of MMA operations
+      if (Detail::kStagedAccumulation) {
+        warp_mma_(
+          pipe_state.tmp_accum_,
+          pipe_state.warp_loaded_frag_A_[warp_mma_k % 2],
+          pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
+          pipe_state.tmp_accum_
+        );
+
+        if (warp_mma_k == 0) {
+          plus<FragmentC> plus_accum;
+          accum = plus_accum(accum, pipe_state.tmp_accum_);
+          pipe_state.tmp_accum_.clear();
+        }
+      } else {
+        warp_mma_(
+          accum,
+          pipe_state.warp_loaded_frag_A_[warp_mma_k % 2],
+          pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
+          accum
+        );
+      }
+
+      if (warp_mma_k == 0) {
+        copy_qscale_tiles(iterator_QScale);
+      }
+      if (warp_mma_k == 1) {
+        copy_qoffset_tiles(iterator_QOffset);
+      }
+
+      // The second-to-last warp-tile also moves to the next global fetch stage
+      if (warp_mma_k == Base::kWarpGemmIterations - 2) {
+        // Inserts a memory fence between stages of cp.async instructions.
+        cutlass::arch::cp_async_fence();
+
+        // Move to the next global fetch stage
+        advance_smem_write_stage(iterator_A, iterator_B, iterator_QScale, iterator_QOffset);
+        advance_smem_read_stage();
+
+        // Disable global fetching when done with global fetch iterations
+        --gemm_k_iterations;
+        iterator_A.clear_mask(gemm_k_iterations == 0);
+        iterator_B.clear_mask(gemm_k_iterations == 0);
+        iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_);
+        if constexpr(kHasQOffset){
+          iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_);
+        }
+
+        // Wait until we have at least one completed global fetch stage
+        gmem_wait();
+      }
+
+    }
+  }
+
+  /// Specialized mainloop iteration of matrix multiply-accumulate, for small M
+  CUTLASS_DEVICE
+  void mac_loop_iter_small_m(
+    PipeState &pipe_state,          ///< [in|out] loop-carried pipeline state
+    FragmentC &accum,               ///< [in|out] destination accumulator tile
+    IteratorA &iterator_A,          ///< [in|out] iterator over A operand in global memory
+    IteratorB &iterator_B,          ///< [in|out] iterator over B operand in global memory
+    IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory
+    IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory
+    int &gemm_k_iterations)         ///< [in|out] number of threadblock mainloop iterations remaining
+  {
+    // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration
+    CUTLASS_PRAGMA_UNROLL
+    for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
+      // In the case of small M, memory latency dominates. We try to move uses far
+      // from their definitions to hide latency.
+      if constexpr(debug_layout) {
+        if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){
+          printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, warp_mma_k % Base::kWarpGemmIterations);
+        }
+        LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_);
+        LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_);
+        if constexpr(kHasQOffset){
+          LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_);
+        }
+      }
+
+      warp_mma_.transform(
+        pipe_state.warp_transformed_frag_B_[(warp_mma_k) % 2],
+        pipe_state.warp_loaded_frag_B_,
+        pipe_state.warp_loaded_frag_QScale_,
+        pipe_state.warp_loaded_frag_QOffset_);
+
+      if constexpr(debug_layout) {
+        LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[(warp_mma_k) % 2], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_);
+      }
+
+      // Loading next warp-level tiles from shared memory.
+      this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
+      this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_);
+      ++this->warp_tile_iterator_B_;
+
+      this->warp_tile_iterator_QScale_.load(
+          pipe_state.warp_loaded_frag_QScale_,
+          pipe_state.warp_loaded_frag_QOffset_);
+      ++this->warp_tile_iterator_QScale_;
+
+      this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
+      this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]);
+      ++this->warp_tile_iterator_A_;
+
+      // All warp-tiles issue their share of global->shared fragment copies
+      copy_tiles_and_advance(
+          iterator_A,
+          iterator_B,
+          (warp_mma_k + 1) % Base::kWarpGemmIterations);
+
+      // Execute the current warp-tile of MMA operations
+      if (Detail::kStagedAccumulation) {
+        warp_mma_(
+          pipe_state.tmp_accum_,
+          pipe_state.warp_loaded_frag_A_[warp_mma_k % 2],
+          pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
+          pipe_state.tmp_accum_
+        );
+
+        if (warp_mma_k == 0) {
+          plus<FragmentC> plus_accum;
+          accum = plus_accum(accum, pipe_state.tmp_accum_);
+          pipe_state.tmp_accum_.clear();
+        }
+      } else {
+        warp_mma_(
+          accum,
+          pipe_state.warp_loaded_frag_A_[warp_mma_k % 2],
+          pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
+          accum
+        );
+      }
+
+      // The second-to-last warp-tile also moves to the next global fetch stage
+      if (warp_mma_k == Base::kWarpGemmIterations - 2) {
+        // Inserts a memory fence between stages of cp.async instructions.
+        cutlass::arch::cp_async_fence();
+
+        // Move to the next global fetch stage
+        advance_smem_write_stage(iterator_A, iterator_B, iterator_QScale, iterator_QOffset);
+        advance_smem_read_stage();
+
+        // Disable global fetching when done with global fetch iterations
+        --gemm_k_iterations;
+        iterator_A.clear_mask(gemm_k_iterations == 0);
+        iterator_B.clear_mask(gemm_k_iterations == 0);
+        iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_);
+        if constexpr(kHasQOffset){
+          iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_);
+        }
+
+        copy_qscale_tiles(iterator_QScale);
+        copy_qoffset_tiles(iterator_QOffset);
+
+        // Wait until we have at least one completed global fetch stage
+        gmem_wait();
+      }
+
+    }
+  }
+
+
+  /// Perform the specified number of threadblock mainloop iterations of matrix
+  /// multiply-accumulate.  Assumes prologue has been initiated.
+  CUTLASS_DEVICE
+  void gemm_iters(
+      int gemm_k_iterations,        ///< number of threadblock mainloop iterations
+      FragmentC &accum,             ///< [in|out] accumulator tile
+      IteratorA &iterator_A,        ///< [in|out] iterator over A operand in global memory
+      IteratorB &iterator_B,        ///< [in|out] iterator over B operand in global memory
+      IteratorQScale &iterator_QScale, ///< [in|out] iterator over QScale operand in global memory
+      IteratorQOffset &iterator_QOffset) ///< [in|out] iterator over QOffset operand in global memory
+  {
+    PipeState pipe_state;
+
+    // Disable global fetching if done with global fetch iterations
+    iterator_A.clear_mask(gemm_k_iterations == 0);
+    iterator_B.clear_mask(gemm_k_iterations == 0);
+    iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_);
+    if constexpr(kHasQOffset) {
+      iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_);
+    }
+
+    // Load first warp-tile's B fragment from shared memory
+    this->warp_tile_iterator_QScale_.load(
+        pipe_state.warp_loaded_frag_QScale_,
+        pipe_state.warp_loaded_frag_QOffset_);
+    ++this->warp_tile_iterator_QScale_;
+
+    this->warp_tile_iterator_B_.set_kgroup_index(0);
+    this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_);
+    ++this->warp_tile_iterator_B_;
+
+    // Load first warp-tile's A fragment from shared memory
+    this->warp_tile_iterator_A_.set_kgroup_index(0);
+    this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]);
+    ++this->warp_tile_iterator_A_;
+
+    copy_tiles_and_advance(iterator_A, iterator_B, 0);
+
+    if constexpr(Shape::kM > 32) {
+      // the case of larger M
+      if constexpr(debug_layout) {
+        if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){
+          printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, 0);
+        }
+        LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_);
+        LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_);
+        if constexpr(kHasQOffset){
+          LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_);
+        }
+      }
+
+      warp_mma_.transform(
+        pipe_state.warp_transformed_frag_B_[0],
+        pipe_state.warp_loaded_frag_B_,
+        pipe_state.warp_loaded_frag_QScale_,
+        pipe_state.warp_loaded_frag_QOffset_);
+
+      if constexpr(debug_layout) {
+        LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[0], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_);
+      }
+    } else {
+      // the case of small M
+      copy_qscale_tiles(iterator_QScale);
+      copy_qoffset_tiles(iterator_QOffset);
+    }
+
+    if (Detail::kStagedAccumulation) {
+      pipe_state.tmp_accum_.clear();
+    }
+
+    // Mainloop
+    CUTLASS_GEMM_LOOP
+    for (; gemm_k_iterations > (-Base::kStages + 1);) {
+      if constexpr(Shape::kM > 32) {
+        mac_loop_iter(
+          pipe_state,
+          accum,
+          iterator_A,
+          iterator_B,
+          iterator_QScale,
+          iterator_QOffset,
+          gemm_k_iterations);
+      } else {
+        mac_loop_iter_small_m(
+          pipe_state,
+          accum,
+          iterator_A,
+          iterator_B,
+          iterator_QScale,
+          iterator_QOffset,
+          gemm_k_iterations);
+      }
+    }
+
+    if (Detail::kStagedAccumulation) {
+      plus<FragmentC> plus_accum;
+      accum = plus_accum(accum, pipe_state.tmp_accum_);
+    }
+
+    // Commit and drain all pending and predicated cp.async operations from the GEMM mainloop
+    cutlass::arch::cp_async_fence();
+    cutlass::arch::cp_async_wait<0>();
+    __syncthreads();
+
+  }
+
+
+  /// Perform a threadblock-scoped matrix multiply-accumulate
+  CUTLASS_DEVICE
+  void operator()(
+      ///< problem size of GEMM
+      int gemm_k_iterations,
+      ///< destination accumulator tile
+      FragmentC &accum,
+      ///< iterator over A operand in global memory
+      IteratorA iterator_A,
+      ///< iterator over B operand in global memory
+      IteratorB iterator_B,
+      ///< iterator over quant scales in global memory
+      IteratorQScale iterator_QScale,
+      ///< Iterator over quant offsets in global memory
+      IteratorQOffset iterator_QOffset,
+      ///< initial value of accumulator
+      FragmentC const &src_accum) {
+
+    // Prologue (start fetching iterations of global fragments into shared memory)
+    prologue(iterator_A, iterator_B, iterator_QScale, iterator_QOffset, gemm_k_iterations);
+
+    // Wait until we have at least one completed global fetch stage
+    gmem_wait();
+
+    // Initialize destination accumulators with source accumulators
+    accum = src_accum;
+
+    // Perform the MAC-iterations
+    gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, iterator_QScale, iterator_QOffset);
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace threadblock
+}  // namespace gemm
+}  // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h
new file mode 100644
index 0000000000000..2c49888c94504
--- /dev/null
+++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h
@@ -0,0 +1,112 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/**
+ * Modifications Copyright (c) Microsoft.
+ * Licensed under the MIT license.
+ *
+ * @file default_quantb_mma_tensor_op.h
+ * @brief Modified from cutlass/gemm/warp/default_mma_tensor_op.h
+ * Default warp-level GEMM operators selected by data type, size, and layouts of operands.
+ */
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h"
+
+namespace cutlass {
+namespace gemm {
+namespace warp {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Partial specialization for m-by-n-by-kgroup
+template <
+    /// Shape of the warp-level GEMM tile (concept: GemmShape)
+    typename WarpShape_,
+    /// Shape of one matrix multiply-accumulate instruction (concept: GemmShape)
+    typename InstructionShape_,
+    /// Data type of A elements
+    typename ElementA,
+    /// Layout of A matrix (concept: MatrixLayout)
+    typename LayoutA,
+    /// Data type of B elements
+    typename ElementB,
+    /// Layout of B matrix (concept: MatrixLayout)
+    typename LayoutB,
+    /// Data type of quant scales
+    typename ElementQScale,
+    /// Layout of quant scales (concept: MatrixLayout)
+    typename SmemLayoutQScale,
+    /// Data type of quant offsets
+    typename ElementQOffset,
+    /// Layout of quant offsets (concept: MatrixLayout)
+    typename SmemLayoutQOffset,
+    /// Blocking size of quantization
+    typename QuantBlocking,
+    /// Element type of C matrix
+    typename ElementC,
+    /// Layout of C matrix (concept: MatrixLayout)
+    typename LayoutC,
+    /// Operator describing the tensor operation
+    typename Operator_ = arch::OpMultiplyAdd,
+    /// Number of partitions along K dimension
+    int PartitionsK = 1,
+    /// Store the accumulators in row major or column major.  Row major is used
+    /// when output layout is interleaved.
+    bool AccumulatorsInRowMajor = false>
+struct DefaultQuantBMmaTensorOp {
+  using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
+      cutlass::arch::Mma<InstructionShape_, 32, ElementA,
+                         cutlass::layout::RowMajor, ElementB,
+                         cutlass::layout::ColumnMajor, ElementC,
+                         cutlass::layout::RowMajor, Operator_>,
+      cutlass::MatrixShape<1, 1> >;
+
+  // Define the warp-level tensor op
+  using Type = cutlass::gemm::warp::QuantBMmaTensorOp<
+      WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementQScale, SmemLayoutQScale,
+      ElementQOffset, SmemLayoutQOffset, QuantBlocking, ElementC, LayoutC,
+      Policy, PartitionsK, AccumulatorsInRowMajor>;
+};
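+
+// Illustrative sketch only (the shapes and element types below are assumptions, not
+// part of this header): the warp-level operator is obtained through the Type alias, e.g.
+//
+//   using WarpMma = typename cutlass::gemm::warp::DefaultQuantBMmaTensorOp<
+//       cutlass::gemm::GemmShape<64, 64, 32>,           // warp shape (assumed)
+//       cutlass::gemm::GemmShape<16, 8, 16>,            // instruction shape (assumed)
+//       cutlass::half_t, cutlass::layout::RowMajor,     // operand A (assumed)
+//       uint8_t, cutlass::layout::ColumnMajor,          // packed operand B (assumed)
+//       cutlass::half_t, cutlass::layout::ColumnMajor,  // quant scales (assumed)
+//       uint8_t, cutlass::layout::ColumnMajor,          // quant offsets (assumed)
+//       cutlass::MatrixShape<1, 32>,                    // quant blocking (assumed)
+//       float, cutlass::layout::RowMajor>::Type;        // accumulator (assumed)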
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace warp
+} // namespace gemm
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+#include "cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h
new file mode 100644
index 0000000000000..4ba39dda3db8d
--- /dev/null
+++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h
@@ -0,0 +1,883 @@
+/**
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ *
+ * @file quantb_meta_mma_tensor_op_tile_iterator.h
+ * @brief Templates for loading quantization meta data for operand B
+ *        from shared memory to fragments. This is meant to be used in
+ *        lock step with the operand B tile iterator. Containing logic
+ *        to figure out the operand B layout in the tensor core,
+ *        and deliver each meta data element to its corresponding
+ *        operand B element for dequantization.
+ */
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+
+#include "cutlass/array.h"
+#include "cutlass/numeric_types.h"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/matrix_shape.h"
+
+#include "cutlass/arch/memory_sm75.h"
+#include "cutlass/gemm/gemm.h"
+
+#include "cutlass/layout/matrix.h"
+#include "cutlass/layout/tensor.h"
+#include "cutlass/layout/pitch_linear.h"
+#include "cutlass/layout/tensor_op_multiplicand_sm75.h"
+
+#include "cutlass/platform/platform.h"
+#include "cutlass/fast_math.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+namespace{
+
+struct b32_pair{
+  uint32_t a;
+  uint32_t b;
+};
+
+struct fp16_quad{
+  cutlass::half_t a;
+  cutlass::half_t b;
+  cutlass::half_t c;
+  cutlass::half_t d;
+};
+
+struct b16_quad{
+  int16_t a;
+  int16_t b;
+  int16_t c;
+  int16_t d;
+};
+
+union b64 {
+  uint64_t single;
+  b32_pair pair;
+  b16_quad quard;
+  fp16_quad fp16_quad;
+};
+
+static_assert(sizeof(b64) == 8, "b64 should be 64 bits");
+
+/// Convert packed 4b weights into fp16 values of (weight + 16).
+/// The current bit hack only supports fp16; bf16 support needs to be added later.
+///
+template<int Size>
+CUTLASS_DEVICE
+void weights2Half(cutlass::Array<uint8_t,Size/2> const &weights,
+                 cutlass::Array<cutlass::half_t, Size>& dest)
+{
+  static_assert(Size % 8 == 0, "Weights should have been prepacked by 2x2 tiles, 2 weights per tile.");
+  uint32_t* dest_pair = reinterpret_cast<uint32_t*>(dest.data());
+  const uint32_t* w_oct = reinterpret_cast<const uint32_t*>(weights.data());
+
+  CUTLASS_PRAGMA_UNROLL
+  for (int oct_idx = 0; oct_idx < Size/8; oct_idx++, w_oct++, dest_pair += 4){
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
+
+    // static_cast<cutlass::half_t>(16 + weight)
+    // 4b weights are prepacked into [0, 2, 4, 6, 1, 3, 5, 7], so that adjacent weights
+    // are in different 16b half words, making it easier to convert to fp16.
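+    //
+    // Why this yields fp16(16 + w) (illustrative, assuming the IEEE fp16 layout):
+    // 0x4c00 encodes 16.0 (exponent 2^4, zero mantissa). A 4b weight w shifted into
+    // mantissa bits 6..9 (mask 0x03c0) contributes 2^4 * (64*w / 1024) = w, so
+    // (x & 0x03c0) | 0x4c00 applied to each 16b lane produces the value 16 + w.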
+    asm volatile(
+        "{\n\t"
+        "  shl.b32       %0, %4, 6;\n"
+        "  shl.b32       %1, %4, 2;\n"
+        "  shr.u32       %2, %4, 2;\n"
+        "  shr.u32       %3, %4, 6;\n"
+        "  lop3.b32      %0, %0, 0x03c003c0, 0x4c004c00, 0xea;\n" // a & 0x03c0 | 0x4c00
+        "  lop3.b32      %1, %1, 0x03c003c0, 0x4c004c00, 0xea;\n"
+        "  lop3.b32      %2, %2, 0x03c003c0, 0x4c004c00, 0xea;\n"
+        "  lop3.b32      %3, %3, 0x03c003c0, 0x4c004c00, 0xea;\n"
+        "}\n"
+        : "=r"(dest_pair[0]), "=r"(dest_pair[1]),
+          "=r"(dest_pair[2]), "=r"(dest_pair[3])
+        : "r"(*w_oct));
+#else
+    assert(0);
+#endif
+  }
+
+}
+
+} // namespace
+
+////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace warp {
+
+////////////////////////////////////////////////////////////////////////////////
+
+// Traits to describe the layout of quantization meta data layout in a MMA fragment
+// Since operand B is quantized on a per block basis, it's one meta data per block.
+
+template <
+  /// Shape of the operand B matrix to load in a warp (concept: MatrixShape<kK, kN>)
+  typename WarpShapeB_,
+  /// Block dimensions of the blockwise quantization. So the actual meta data
+  /// warp shape is WarpShapeB_ / BlockingShape_
+  typename BlockingShape_,
+  /// Underlying matrix multiply operator (concept: arch::Mma)
+  typename ArchMmaOperator_,
+  /// Number of threads participating in one matrix operation
+  int Threads>
+class QuantBMetaMmaTile{
+public:
+
+  using WarpShapeB = WarpShapeB_;
+  using BlockingShape = BlockingShape_;
+  using ArchMmaOperator = ArchMmaOperator_;
+
+  static_assert(Threads == 32, "This iterator should work in a warp only.");
+
+  /// Shape of the corresponding operand B tile iterator <instruction_k, warp_n>
+  using TileShapeB = MatrixShape<ArchMmaOperator::Shape::kK, WarpShapeB::kColumn>;
+
+  // Tensor core operand B layout is a column major 4x8 tile, divided
+  // into 32 threads (T0 ~ T31) as shown below. Each element of the tile is 32b,
+  // so for fp16 it becomes 8 x 8, and for int8 it becomes 16 x 8.
+  //  T0 |  T4 |  T8 | T12 | T16 | T20 | T24 | T28
+  //  T1 |  T5 |  T9 | T13 | T17 | T21 | T25 | T29
+  //  T2 |  T6 | T10 | T14 | T18 | T22 | T26 | T30
+  //  T3 |  T7 | T11 | T15 | T19 | T23 | T27 | T31
+  using CoreTile = layout::PitchLinearShape<4, 8>;
+
+  /// Each thread holds a 32b fragment per tile: for half precision, it's 2 elements, 4 elements for int8
+  static int const kNumBsPerCoreTileFragement = 32 / sizeof_bits<typename ArchMmaOperator::ElementB>::value;
+
+  /// Each mma instruction can process either 1 or 2 tensor core operand B tiles (stacked on the k dimension)
+  static int const kBTilesPerMma =
+      sizeof_bits<typename ArchMmaOperator::ElementB>::value * ArchMmaOperator::FragmentB::kElements / 32;
+  static_assert(kBTilesPerMma == 1 || kBTilesPerMma == 2, "Only support 1 or 2 operand B tiles per mma.");
+
+  /// Each operand B tile iterator load covers a number of mma instructions
+  static int const kMmaIterationsB = WarpShapeB::kColumn / ArchMmaOperator::Shape::kN;
+
+  /// Number of B elements a fragment of meta data should cover
+  static int const kExpandedSize = kNumBsPerCoreTileFragement * kBTilesPerMma * kMmaIterationsB;
+
+  // Now we figure out how many meta data elements to load for each TileShapeB
+
+  /// Number of meta elements per CoreTile.
+  static int const kCoreTileFragementSize = (kNumBsPerCoreTileFragement + BlockingShape::kRow - 1) / BlockingShape::kRow;
+
+  /// Number of core tiles per mma instruction, different from kBTilesPerMma when blocking size on K dimension
+  /// exceeds the tile depth, so two tiles share the same meta data
+  static int const kTilesPerMma = ((kBTilesPerMma == 2) &&
+                                  (BlockingShape::kRow <= kNumBsPerCoreTileFragement * CoreTile::kContiguous))
+                                  ? 2 : 1;
+
+  /// stride to reach the meta data for the next CoreTile on the K dimension
+  static int const kKTileStride = (kNumBsPerCoreTileFragement * CoreTile::kContiguous + BlockingShape::kRow - 1) / BlockingShape::kRow;
+
+  /// Stride on N dimension should be the tile width, shrunk by blocking size on this dimension.
+  static int const kNStride = (CoreTile::kStrided + BlockingShape::kColumn - 1) / BlockingShape::kColumn;
+
+  /// On N dimension, how many tiles share the same meta data
+  static int const kNRepeats = (BlockingShape::kColumn + CoreTile::kStrided - 1) / CoreTile::kStrided;
+
+  /// Each fragment should cover kMmaIterationsB mma instructions on the N dimension.
+  /// When blocking size on this dimension exceeds the tile width, multiple iterations
+  /// would share the same data.
+  static int const kMmaIterations = (kMmaIterationsB + kNRepeats - 1) / kNRepeats;
+
+  static int const kFragementSize = kCoreTileFragementSize * kTilesPerMma * kMmaIterations;
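+
+  // Worked example (illustrative only, assuming fp16 operands and a 16x8x16 tensor core
+  // mma): each thread holds 32b of operand B per core tile, so kNumBsPerCoreTileFragement
+  // is 2; FragmentB has 4 half elements, so kBTilesPerMma is 2; with a 64-wide warp tile,
+  // kMmaIterationsB is 64 / 8 = 8 and kExpandedSize is 2 * 2 * 8 = 32.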
+
+  CUTLASS_DEVICE
+  static MatrixCoord lane_position(int lane_id) {
+    if constexpr(kNumBsPerCoreTileFragement == 2
+                 && kBTilesPerMma == 2
+                 && BlockingShape::kRow == 1){
+      // Optimize for a special case of:
+      //    16b gemm (kNumBsPerCoreTileFragement == 2)
+      //    2 B operand tiles per mma (kBTilesPerMma == 2)
+      //    (1,n) quantization blocking
+      // The scale and offset tensors are prepacked to reduce the number of load instructions.
+      return make_Coord((lane_id % CoreTile::kContiguous) * 4,
+         lane_id / CoreTile::kContiguous);
+    } else {
+      return make_Coord((lane_id % CoreTile::kContiguous) * kNumBsPerCoreTileFragement,
+         lane_id / CoreTile::kContiguous);
+    }
+  }
+};
+
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// This tile iterator is to load quantization meta data for operand B from
+/// shared memory to fragments (hopefully allocated to registers by compilers).
+/// Examples of meta data include scale or offsets. The operand B matrix is
+/// quantized on a per block basis, meaning one element of meta data per block.
+///
+/// This is meant to be used in lock step with the operand B tile iterator.
+/// So all parameters are logical positions in the operand B tiles.
+/// The goal here is to deliver each meta data element to its corresponding
+/// operand B element for dequantization. As a result, we need to figure
+/// out the operand B layout in the tensor core.
+///
+template <
+  /// Shape of the operand B matrix to load in a warp (concept: MatrixShape<kK, kN>)
+  typename WarpShapeB_,
+  /// Block dimensions of the blockwise quantization. So the actual meta data
+  /// warp shape is WarpShapeB_ / BlockingShape_
+  typename BlockingShape_,
+  /// Data type of the quant scales
+  typename ElementScale_,
+  /// Layout of the quant scales
+  typename LayoutScale_,
+  /// Data type of quant offsets
+  typename ElementOffset_,
+  /// Layout of quant offsets
+  typename LayoutOffset_,
+  /// Underlying matrix multiply operator (concept: arch::Mma)
+  typename ArchMmaOperator_,
+  /// Number of threads participating in one matrix operation
+  int Threads,
+  /// Number of partitions along K dimension
+  int PartitionsK_ = 1>
+class QuantBMetaMmaTensorOpTileIterator;
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Specialization for column major layout
+
+template <
+  /// Shape of the operand B matrix to load in a warp (concept: MatrixShape<kK, kN>)
+  typename WarpShapeB_,
+  /// Block dimensions of the blockwise quantization. So the actual meta data
+  /// warp shape is WarpShapeB_ / BlockingShape_
+  typename BlockingShape_,
+  /// Data type of the meta data elements
+  typename ElementScale_,
+  /// Data type of quant offsets
+  typename ElementOffset_,
+  /// Underlying matrix multiply operator (concept: arch::Mma)
+  typename ArchMmaOperator_,
+  /// Number of threads participating in one matrix operation
+  int Threads>
+class QuantBMetaMmaTensorOpTileIterator<WarpShapeB_, BlockingShape_,
+    ElementScale_, cutlass::layout::ColumnMajor,
+    ElementOffset_, cutlass::layout::ColumnMajor,
+    ArchMmaOperator_, Threads, 1>{
+public:
+
+  using WarpShapeB = WarpShapeB_;
+  using BlockingShape = BlockingShape_;
+  using ElementScale = ElementScale_;
+  using Layout = cutlass::layout::ColumnMajor;
+  using ElementOffset = ElementOffset_;
+  using ArchMmaOperator = ArchMmaOperator_;
+
+  static constexpr bool kHasOffset = !(std::is_same<ElementOffset, std::monostate>::value);
+
+  static_assert(BlockingShape::kRow == 1 && BlockingShape::kColumn > 1,
+          "Only support row blocking for column major layout");
+
+  using MetaTile = QuantBMetaMmaTile<WarpShapeB, BlockingShape, ArchMmaOperator, Threads>;
+
+  /// Number of MMA instructions for this tile
+  static constexpr int kMmaIterationsB = MetaTile::kMmaIterationsB;
+
+  /// Number of B elements per mma tile fragment (32b), 2 for half precision, 4 for int8
+  static constexpr int kNumBsPerCoreTileFragement = MetaTile::kNumBsPerCoreTileFragement;
+
+  /// Each mma instruction can process either 1 or 2 operand B tiles (stacked on the k dimension)
+  static constexpr int kBTilesPerMma = MetaTile::kBTilesPerMma;
+
+  /// Number of B elements a fragment of meta data should cover
+  static constexpr int kExpandedSize = MetaTile::kExpandedSize;
+
+  /// Number of meta elements per core tile fragment
+  static constexpr int kCoreTileFragementSize = MetaTile::kCoreTileFragementSize;
+
+  /// stride for reaching the next core tile (if there is one) on the K dimension
+  static constexpr int kKTileStride = MetaTile::kKTileStride;
+
+  /// Number of core tiles on the K dimension that need separate meta data loads per mma (1 or 2)
+  static constexpr int kTilesPerMma = MetaTile::kTilesPerMma;
+
+  static constexpr int kNStride = MetaTile::kNStride;
+  static constexpr int kNRepeats = MetaTile::kNRepeats;
+  static constexpr int kMmaIterations = MetaTile::kMmaIterations;
+
+  using TensorRefScale = TensorRef<ElementScale, Layout>;
+  using TensorRefOffset = TensorRef<ElementOffset, Layout>;
+  using TensorCoord = typename Layout::TensorCoord;
+
+  using Index = typename Layout::Index;
+  using LongIndex = typename Layout::LongIndex;
+  using StrideIndex = typename Layout::Stride::Index;
+
+  using FragmentScale = Array<ElementScale, MetaTile::kFragementSize>;
+  using FragmentOffset = typename std::conditional<kHasOffset,
+          Array<ElementOffset, MetaTile::kFragementSize>,
+          std::monostate>::type;
+
+  using AccessTypeScale = Array<ElementScale, kCoreTileFragementSize>;
+  using AccessTypeOffset = Array<ElementOffset, kCoreTileFragementSize>;
+
+private:
+
+  ElementScale *pointer_;
+  Layout layout_;
+
+  ElementOffset *pointer_offset_;
+  Layout layout_offset_;
+
+  TensorCoord lane_position_;
+
+public:
+
+  CUTLASS_DEVICE
+  QuantBMetaMmaTensorOpTileIterator() { }
+
+  CUTLASS_DEVICE
+  QuantBMetaMmaTensorOpTileIterator(
+    TensorRefScale const &ref,
+    TensorRefOffset const &ref_offset,
+    int lane_idx
+  ):
+    pointer_(ref.data()),
+    layout_(ref.layout()),
+    pointer_offset_(ref_offset.data()),
+    layout_offset_(ref_offset.layout()),
+    lane_position_(MetaTile::lane_position(lane_idx)){}
+
+  /// Loads a fragment
+  CUTLASS_HOST_DEVICE
+  void load(FragmentScale &frag, FragmentOffset &frag_offset) {
+    if constexpr(kNumBsPerCoreTileFragement == 2
+                 && kBTilesPerMma == 2){
+      // Optimize for a special case of:
+      //    16b gemm (kNumBsPerCoreTileFragement == 2)
+      //    2 B operand tiles per mma (kBTilesPerMma == 2)
+      //    (1,n) quantization blocking (BlockingShape::kRow == 1)
+      // The scale and offset tensors are prepacked to reduce the number of load instructions needed
+      const int row = lane_position_.row();
+      const int column = lane_position_.column() / BlockingShape::kColumn;
+
+      Array<ElementScale, 4> *dst_ptr = reinterpret_cast<Array<ElementScale, 4>*>(frag.data());
+      CUTLASS_PRAGMA_UNROLL
+      for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){
+        Array<ElementScale, 4> *src_ptr = reinterpret_cast<Array<ElementScale, 4>*>(pointer_ + layout_({row, c}));
+        *dst_ptr = *src_ptr;
+        dst_ptr++;
+      }
+
+      if constexpr(kHasOffset){
+        Array<ElementOffset, 4> *dst_ptr_offset = reinterpret_cast<Array<ElementOffset, 4>*>(frag_offset.data());
+        CUTLASS_PRAGMA_UNROLL
+        for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){
+          Array<ElementOffset, 4> *src_ptr_offset = reinterpret_cast<Array<ElementOffset, 4>*>(pointer_offset_ + layout_offset_({row, c}));
+          *dst_ptr_offset = *src_ptr_offset;
+          dst_ptr_offset++;
+        }
+      }
+
+    } else {
+      // Other cases, offsets and scales are not prepacked.
+
+      const int row = lane_position_.row() / BlockingShape::kRow;
+      const int column = lane_position_.column() / BlockingShape::kColumn;
+
+      AccessTypeScale* dst_ptr = reinterpret_cast<AccessTypeScale*>(frag.data());
+      CUTLASS_PRAGMA_UNROLL
+      for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){
+        CUTLASS_PRAGMA_UNROLL
+        for (int mma_tile_idx = 0, r = row; mma_tile_idx < kTilesPerMma; mma_tile_idx++, r += kKTileStride){
+          AccessTypeScale* src_ptr = reinterpret_cast<AccessTypeScale*>(pointer_ + layout_({r, c}));
+          *dst_ptr = *src_ptr;
+          dst_ptr++;
+        }
+      }
+
+      if constexpr(kHasOffset){
+        AccessTypeOffset* dst_ptr = reinterpret_cast<AccessTypeOffset*>(frag_offset.data());
+        CUTLASS_PRAGMA_UNROLL
+        for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){
+          CUTLASS_PRAGMA_UNROLL
+          for (int mma_tile_idx = 0, r = row; mma_tile_idx < kTilesPerMma; mma_tile_idx++, r += kKTileStride){
+            AccessTypeOffset* src_ptr = reinterpret_cast<AccessTypeOffset*>(pointer_offset_ + layout_offset_({r, c}));
+            *dst_ptr = *src_ptr;
+            dst_ptr++;
+          }
+        }
+      }
+    }
+  }
+
+  template <typename ElementT>
+  CUTLASS_HOST_DEVICE
+  static Array<ElementT, kExpandedSize> debug_expand(Array<ElementT, MetaTile::kFragementSize> const &frag){
+    Array<ElementT, kExpandedSize> ret;
+    int out_idx = 0;
+    CUTLASS_PRAGMA_UNROLL
+    for (int n_out = 0; n_out < kMmaIterationsB; n_out++){
+      int n_idx = n_out / kNRepeats;
+      CUTLASS_PRAGMA_UNROLL
+      for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){
+        int mma_tile_idx = mma_tile_out_idx / (kBTilesPerMma / kTilesPerMma);
+        CUTLASS_PRAGMA_UNROLL
+        for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++){
+          int elem_idx = elem_out_idx / BlockingShape::kRow;
+          int idx = elem_idx + mma_tile_idx * kCoreTileFragementSize + n_idx * kCoreTileFragementSize * kTilesPerMma;
+          ret[out_idx] = frag[idx];
+          out_idx++;
+        }
+      }
+    }
+    return ret;
+  }
+
+  CUTLASS_HOST_DEVICE
+  static void dequant(FragmentScale const &scales,
+                      FragmentOffset const &offsets,
+                      Array<uint8_t,kExpandedSize/2> const &weights,
+                      Array<ElementScale, kExpandedSize>& dest){
+    static_assert(kNumBsPerCoreTileFragement == 2, "Only for 16b gemm.");
+    static_assert(kExpandedSize % 8 == 0, "Weights should have been prepacked by 2x2 tiles, 2 weights per tile.");
+
+    // First convert 4b weight into fp16(weight + 16)
+    weights2Half(weights, dest);
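+    // weights2Half yields fp16(w + 16) for each 4-bit weight w. Combined with the
+    // scale * (-16 - offset) term computed below, a single FMA then gives
+    //   scale * (w + 16) + scale * (-16 - offset) == scale * (w - offset).
+    // e.g. w = 5, offset = 8, scale = 0.1:  0.1 * 21 + 0.1 * (-24) = -0.3 = 0.1 * (5 - 8).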
+
+    if constexpr(kBTilesPerMma == 2){
+      // Optimize for a special case of:
+      //    2 B operand tiles per mma (kBTilesPerMma == 2)
+      //    (1,n) quantization blocking (BlockingShape::kRow == 1)
+
+      uint32_t* dest_pair = reinterpret_cast<uint32_t*>(dest.data());
+      const b64* scales_ptr = reinterpret_cast<const b64*>(scales.data());
+      const ElementOffset* offsets_ptr = nullptr;
+      if constexpr(kHasOffset) { offsets_ptr = offsets.data(); }
+
+      CUTLASS_PRAGMA_UNROLL
+      for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){
+        // dequantize: d = scale * (weight - offset)
+        // to use FMA, d = scale * weight + (scale * (-offset))
+
+        b64 offsets;
+        if constexpr(kHasOffset){
+          const uint32_t* p = reinterpret_cast<const uint32_t*>(offsets_ptr);
+
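+          // The PTX below builds fp16(-(16 + offset)) directly from the packed 4-bit offsets:
+          // OR-ing a nibble v into mantissa bits [9:6] of 0xcc00 (fp16 -16.0, exponent 2^4)
+          // yields the bit pattern of -(16 + v); lop3 with immLut 0xea computes (a & b) | c.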
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
+          asm volatile(
+              "{\n\t"
+              "  .reg  .b32    rb0, rb1;\n"      // b32 regs for fp16x2 mul operands
+
+              // static_cast<cutlass::half_t>(-16 - offset)
+              // input [d, b, c, a],
+              "  shl.b32       rb0, %4, 6;\n"     // rb0 = [x, b, x, a] << 6
+              "  shr.u32       rb1, %4, 2;\n"     // rb1 = [x, d, x, c] << 6
+              "  lop3.b32      rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00
+              "  lop3.b32      rb1, rb1, 0x03c003c0, 0xcc00cc00, 0xea;\n"
+              "  mul.rn.f16x2  %0, %2, rb0;\n"    // offset = scale * (-16 - offset)
+              "  mul.rn.f16x2  %1, %3, rb1;\n"
+              "}\n"
+              : "=r"(offsets.pair.a), "=r"(offsets.pair.b)
+              : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b),
+                "r"(p[0]));
+#else
+          assert(0);
+#endif
+
+          offsets_ptr += 4;
+        } else {
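+          // No offset tensor: the implicit zero point is 8, so the addon term is
+          // scale * (-16 - 8); 0xce00 is the fp16 bit pattern of -24, replicated to both halves.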
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
+          asm volatile(
+              "{\n\t"
+              "  .reg  .b32    rb0;\n"
+              "  mov.u32       rb0, 0xce00ce00;\n"
+              "  mul.rn.f16x2  %0, %2, rb0;\n"    // offset = scale * (-16 - 8)
+              "  mul.rn.f16x2  %1, %3, rb0;\n"
+              "}\n"
+              : "=r"(offsets.pair.a), "=r"(offsets.pair.b)
+              : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b));
+#else
+          offsets.fp16_quad.a = scales_ptr->fp16_quad.a * static_cast<cutlass::half_t>(-16-8);
+          offsets.fp16_quad.b = scales_ptr->fp16_quad.b * static_cast<cutlass::half_t>(-16-8);
+          offsets.fp16_quad.c = scales_ptr->fp16_quad.c * static_cast<cutlass::half_t>(-16-8);
+          offsets.fp16_quad.d = scales_ptr->fp16_quad.d * static_cast<cutlass::half_t>(-16-8);
+#endif
+        }
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int n_r = 0; n_r < kNRepeats; n_r++){
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
+          asm volatile(
+              "{\n\t"
+              "  fma.rn.f16x2  %0, %2, %0, %4;\n" // dest = scale * (16 + weight) +  (scale * (-16 - offset))
+              "  fma.rn.f16x2  %1, %3, %1, %5;\n"
+              "}\n"
+              : "+r"(dest_pair[0]), "+r"(dest_pair[1])
+              : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b),
+                "r"(offsets.pair.a), "r"(offsets.pair.b));
+#else
+          assert(0);
+#endif
+          dest_pair += 2;
+        }
+        scales_ptr++;
+      }
+
+    } else {
+      // unoptimized path for other cases, very slow
+      int out_idx = 0;
+      ElementScale offset;
+      CUTLASS_PRAGMA_UNROLL
+      for (int n_out = 0; n_out < kMmaIterationsB; n_out++){
+        int n_idx = n_out / kNRepeats;
+        CUTLASS_PRAGMA_UNROLL
+        for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){
+          int mma_tile_idx = mma_tile_out_idx / (kBTilesPerMma / kTilesPerMma);
+          CUTLASS_PRAGMA_UNROLL
+          for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++){
+            int elem_idx = elem_out_idx / BlockingShape::kRow;
+            int idx = elem_idx + mma_tile_idx * kCoreTileFragementSize + n_idx * kCoreTileFragementSize * kTilesPerMma;
+            ElementScale s = scales[idx];
+            if constexpr(kHasOffset){
+              offset = s * static_cast<ElementScale>(-16 - int(offsets[idx]));
+            } else {
+              offset = s * static_cast<ElementScale>(-16-8);
+            }
+            dest[out_idx] = s * dest[out_idx] + offset;
+            out_idx++;
+          }
+        }
+      }
+
+    }
+
+  }
+
+  /// Advances the pointer
+  CUTLASS_HOST_DEVICE
+  QuantBMetaMmaTensorOpTileIterator &operator++() {
+    // This is for operand B, so advance on the K dimension
+    lane_position_ += make_Coord(MetaTile::TileShapeB::kRow, 0);
+    return *this;
+  }
+
+  CUTLASS_DEVICE
+  QuantBMetaMmaTensorOpTileIterator &add_tile_offset(
+      TensorCoord const &tile_offset) {
+    int rows = tile_offset.row() * MetaTile::TileShapeB::kRow;
+    int columns = tile_offset.column() * MetaTile::TileShapeB::kColumn;
+    lane_position_ += TensorCoord(rows, columns);
+    return *this;
+  }
+
+};
+
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Specialization for row major layout
+
+template <
+  /// Shape of the operand B matrix to load in a warp (concept: MatrixShape<kK, kN>)
+  typename WarpShapeB_,
+  /// Block dimensions of the blockwise quantization. So the actual meta data
+  /// warp shape is WarpShapeB_ / BlockingShape_
+  typename BlockingShape_,
+  /// Data type of the meta data elements
+  typename ElementScale_,
+  /// Data type of quant offsets
+  typename ElementOffset_,
+  /// Underlying matrix multiply operator (concept: arch::Mma)
+  typename ArchMmaOperator_,
+  /// Number of threads participating in one matrix operation
+  int Threads>
+class QuantBMetaMmaTensorOpTileIterator<WarpShapeB_, BlockingShape_,
+    ElementScale_, cutlass::layout::RowMajor,
+    ElementOffset_, cutlass::layout::RowMajor,
+    ArchMmaOperator_, Threads, 1>{
+public:
+
+  using WarpShapeB = WarpShapeB_;
+  using BlockingShape = BlockingShape_;
+  using ElementScale = ElementScale_;
+  using ElementOffset = ElementOffset_;
+  using Layout = cutlass::layout::RowMajor;
+  using ArchMmaOperator = ArchMmaOperator_;
+
+  static constexpr bool kHasOffset = !(std::is_same<ElementOffset, std::monostate>::value);
+
+  static_assert(BlockingShape::kColumn == 1 && BlockingShape::kRow > 1,
+          "Only support column blocking for row major layout");
+
+  using MetaTile = QuantBMetaMmaTile<WarpShapeB, BlockingShape, ArchMmaOperator, Threads>;
+
+  /// Number of MMA instructions for this tile
+  static constexpr int kMmaIterationsB = MetaTile::kMmaIterationsB;
+
+  /// Number of B elements per mma tile fragment (32b), 2 for half precision, 4 for int8
+  static constexpr int kNumBsPerCoreTileFragement = MetaTile::kNumBsPerCoreTileFragement;
+
+  /// Each mma instruction can process either 1 or 2 operand B tiles (stacked on the k dimension)
+  static constexpr int kBTilesPerMma = MetaTile::kBTilesPerMma;
+
+  /// Number of B elements a fragment of meta data should cover
+  static constexpr int kExpandedSize = MetaTile::kExpandedSize;
+
+  /// Number of meta elements per core tile fragment
+  static constexpr int kCoreTileFragementSize = MetaTile::kCoreTileFragementSize;
+
+  /// stride for reaching the next core tile (if there is one) on the K dimension
+  static constexpr int kKTileStride = MetaTile::kKTileStride;
+
+  /// do we need to load meta data for the next core tile on the K dimension?
+  static constexpr int kTilesPerMma = MetaTile::kTilesPerMma;
+
+  static constexpr int kNStride = MetaTile::kNStride;
+  static constexpr int kNRepeats = MetaTile::kNRepeats;
+  static constexpr int kMmaIterations = MetaTile::kMmaIterations;
+
+  using TensorRefScale = TensorRef<ElementScale, Layout>;
+  using TensorRefOffset = TensorRef<ElementOffset, Layout>;
+  using TensorCoord = typename Layout::TensorCoord;
+
+  using Index = typename Layout::Index;
+  using LongIndex = typename Layout::LongIndex;
+  using StrideIndex = typename Layout::Stride::Index;
+
+  using FragmentScale = Array<ElementScale, MetaTile::kFragementSize>;
+  using FragmentOffset = typename std::conditional<kHasOffset,
+          Array<ElementOffset, MetaTile::kFragementSize>,
+          std::monostate>::type;
+
+private:
+
+  ElementScale *pointer_;
+  Layout layout_;
+
+  ElementOffset *pointer_offset_;
+  Layout layout_offset_;
+
+  TensorCoord lane_position_;
+
+public:
+
+  CUTLASS_DEVICE
+  QuantBMetaMmaTensorOpTileIterator() { }
+
+  CUTLASS_DEVICE
+  QuantBMetaMmaTensorOpTileIterator(
+    TensorRefScale const &ref,
+    TensorRefOffset const &ref_offset,
+    int lane_idx
+  ):
+    pointer_(ref.data()),
+    layout_(ref.layout()),
+    pointer_offset_(ref_offset.data()),
+    layout_offset_(ref_offset.layout()),
+    lane_position_(MetaTile::lane_position(lane_idx)) {}
+
+  /// Loads a fragment
+  CUTLASS_HOST_DEVICE
+  void load(FragmentScale &frag, FragmentOffset &frag_offset) {
+    const int row = lane_position_.row() / BlockingShape::kRow;
+    const int column = lane_position_.column() / BlockingShape::kColumn;
+    static_assert(kTilesPerMma * kCoreTileFragementSize == 1, "Only support one meta data per core tile");
+
+    ElementScale* src_ptr = pointer_ + layout_({row, column});
+    ElementScale* dst_ptr = frag.data();
+    CUTLASS_PRAGMA_UNROLL
+    for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){
+      dst_ptr[n_idx] = src_ptr[n_idx * kNStride];
+    }
+
+    if constexpr(kHasOffset){
+      ElementOffset* src_ptr_offset = pointer_offset_ + layout_offset_({row, column});
+      ElementOffset* dst_ptr_offset = frag_offset.data();
+      CUTLASS_PRAGMA_UNROLL
+      for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){
+        dst_ptr_offset[n_idx] = src_ptr_offset[n_idx * kNStride];
+      }
+    }
+  }
+
+  template <typename ElementT>
+  CUTLASS_HOST_DEVICE
+  static Array<ElementT, kExpandedSize> debug_expand(Array<ElementT, MetaTile::kFragementSize> const &frag){
+    Array<ElementT, kExpandedSize> ret;
+
+    int out_idx = 0;
+    CUTLASS_PRAGMA_UNROLL
+    for (int n_out = 0; n_out < kMmaIterationsB; n_out++){
+      int n_idx = n_out / kNRepeats;
+      CUTLASS_PRAGMA_UNROLL
+      for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){
+        int mma_tile_idx = mma_tile_out_idx / (kBTilesPerMma / kTilesPerMma);
+        CUTLASS_PRAGMA_UNROLL
+        for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++){
+          int elem_idx = elem_out_idx / BlockingShape::kRow;
+          int col = elem_idx + mma_tile_idx * kCoreTileFragementSize;
+          int idx = col * kMmaIterations + n_idx;
+          ret[out_idx] = frag[idx];
+          out_idx++;
+        }
+      }
+    }
+    return ret;
+  }
+
+  CUTLASS_HOST_DEVICE
+  static void dequant(FragmentScale const &scales,
+                      FragmentOffset const &offsets,
+                      Array<uint8_t,kExpandedSize/2> const &weights,
+                      Array<ElementScale, kExpandedSize>& dest){
+    static_assert(kNRepeats == 1, "This is implied by BlockingShape::kColumn == 1");
+    static_assert(kNumBsPerCoreTileFragement == 2, "Only for 16b gemm now.");
+
+    // First convert 4b weight into fp16(weight + 16)
+    weights2Half(weights, dest);
+
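+    // addon[n] caches scale[n] * (-16 - offset[n]) (or scale[n] * (-16 - 8) without offsets),
+    // so the final loop reduces to dest = scale * fp16(w + 16) + addon == scale * (w - offset).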
+    ElementScale addon[kMmaIterationsB];
+    if constexpr (kMmaIterationsB % 4 == 0) {
+      const b64* scales_ptr = reinterpret_cast<const b64*>(scales.data());
+      uint32_t* addon_ptr = reinterpret_cast<uint32_t*>(addon);
+      if constexpr(kHasOffset){
+        const uint32_t* p = reinterpret_cast<const uint32_t*>(offsets.data());
+        CUTLASS_PRAGMA_UNROLL
+        for (int n_idx = 0; n_idx < kMmaIterationsB; n_idx += 4){
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
+          asm volatile(
+            "{\n\t"
+            "  .reg  .b32    rb0, rb1, rb2;\n"
+
+            // offset from [d, c, b, a] --> [d, b, c, a]
+            "  prmt.b32      rb2, %4, rb0, 0x3120;\n"
+
+            // static_cast<cutlass::half_t>(-16 - offset)
+            // input [d, b, c, a],
+            "  shl.b32       rb0, rb2, 6;\n"     // rb0 = [x, b, x, a] << 6
+            "  shr.u32       rb1, rb2, 2;\n"     // rb1 = [x, d, x, c] << 6
+            "  lop3.b32      rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00
+            "  lop3.b32      rb1, rb1, 0x03c003c0, 0xcc00cc00, 0xea;\n"
+            "  mul.rn.f16x2  %0, %2, rb0;\n"    // offset = scale * (-16 - offset)
+            "  mul.rn.f16x2  %1, %3, rb1;\n"
+            "}\n"
+            : "=r"(addon_ptr[0]), "=r"(addon_ptr[1])
+            : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b),
+              "r"(p[0]));
+#else
+          assert(0);
+#endif
+          scales_ptr++;
+          p++;
+          addon_ptr += 2;
+        }
+      } else {
+        CUTLASS_PRAGMA_UNROLL
+        for (int n_idx = 0; n_idx < kMmaIterationsB; n_idx += 4){
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
+          asm volatile(
+            "{\n\t"
+            "  .reg  .b32    rb0;\n"
+            "  mov.u32       rb0, 0xce00ce00;\n"
+            "  mul.rn.f16x2  %0, %2, rb0;\n"    // offset = scale * (-16 - 8)
+            "  mul.rn.f16x2  %1, %3, rb0;\n"
+            "}\n"
+            : "=r"(addon_ptr[0]), "=r"(addon_ptr[1])
+            : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b));
+#else
+          assert(0);
+#endif
+          scales_ptr++;
+          addon_ptr += 2;
+        }
+      }
+    } else if constexpr (kMmaIterationsB % 2 == 0) {
+      const uint32_t* scales_ptr = reinterpret_cast<const uint32_t*>(scales.data());
+      uint32_t* addon_ptr = reinterpret_cast<uint32_t*>(addon);
+
+      if constexpr (kHasOffset){
+        // Possible 2-byte buffer over-read here.
+        const uint32_t* p = reinterpret_cast<const uint32_t*>(offsets.data());
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
+        asm volatile(
+          "{\n\t"
+          "  .reg  .b32    rb0, rb1, rb2;\n"
+
+          // offset from [?, ?, b, a] --> [?, b, ?, a]
+          "  prmt.b32      rb2, %2, rb0, 0x3120;\n"
+
+          // static_cast<cutlass::half_t>(-16 - offset)
+          // input [d, b, c, a],
+          "  shl.b32       rb0, rb2, 6;\n"     // rb0 = [x, b, x, a] << 6
+          "  lop3.b32      rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00
+          "  mul.rn.f16x2  %0, %1, rb0;\n"    // offset = scale * (-16 - offset)
+          "}\n"
+          : "=r"(addon_ptr[0])
+          : "r"(scales_ptr[0])
+            "r"(p[0]));
+#else
+        assert(0);
+#endif
+      } else {
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
+        asm volatile(
+          "{\n\t"
+          "  .reg  .b32    rb0;\n"
+          "  mov.u32       rb0, 0xce00ce00;\n"
+          "  mul.rn.f16x2  %0, %1, rb0;\n"    // offset = scale * (-16 - 8)
+          "}\n"
+          : "=r"(addon_ptr[0])
+          : "r"(scales_ptr[0]));
+#else
+        assert(0);
+#endif
+      }
+    } else {
+      // kMmaIterationsB == 1
+      if constexpr(kHasOffset){
+        uint8_t zp = offsets[0];
+        addon[0] = scales[0] * static_cast<ElementScale>(-16 - static_cast<int>(zp));
+      } else {
+        addon[0] = scales[0] * static_cast<ElementScale>(-16-8);
+      }
+    }
+
+    int out_idx = 0;
+    CUTLASS_PRAGMA_UNROLL
+    for (int n_out = 0; n_out < kMmaIterationsB; n_out++){
+      CUTLASS_PRAGMA_UNROLL
+      for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){
+        dest[out_idx] = scales[n_out] * dest[out_idx] + addon[n_out];
+        dest[out_idx + 1] = scales[n_out] * dest[out_idx + 1] + addon[n_out];
+        out_idx += 2;
+      }
+    }
+  }
+
+  /// Advances the pointer
+  CUTLASS_HOST_DEVICE
+  QuantBMetaMmaTensorOpTileIterator &operator++() {
+    // This is for operand B, so advance on the K dimension
+    lane_position_ += make_Coord(MetaTile::TileShapeB::kRow, 0);
+    return *this;
+  }
+
+  CUTLASS_DEVICE
+  QuantBMetaMmaTensorOpTileIterator &add_tile_offset(
+      TensorCoord const &tile_offset) {
+    int rows = tile_offset.row() * MetaTile::TileShapeB::kRow;
+    int columns = tile_offset.column() * MetaTile::TileShapeB::kColumn;
+    lane_position_ += TensorCoord(rows, columns);
+    return *this;
+  }
+
+};
+
+
+////////////////////////////////////////////////////////////////////////////////
+}  // namespace warp
+}  // namespace gemm
+}  // namespace cutlass
diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h
new file mode 100644
index 0000000000000..f29cedf326a44
--- /dev/null
+++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h
@@ -0,0 +1,361 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/**
+ * Modifications Copyright (c) Microsoft.
+ * Licensed under the MIT license.
+ *
+ * @file quantb_mma_tensor_op.h
+ * @brief Modified from cutlass/gemm/warp/mma_tensor_op.h
+ * Templates implementing warp-level matrix multiply-accumulate operations
+ * targeting tensor cores.
+ */
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/array.h"
+#include "cutlass/platform/platform.h"
+
+#include "cutlass/numeric_conversion.h"
+#include "cutlass/numeric_types.h"
+#include "cutlass/matrix_shape.h"
+
+#include "cutlass/arch/memory_sm75.h"
+#include "cutlass/arch/mma_sm75.h"
+#include "cutlass/arch/mma_sm80.h"
+
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/gemm/warp/mma.h"
+#include "cutlass/gemm/warp/mma_tensor_op_policy.h"
+#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
+
+#include "cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace warp {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Structure to compute the matrix product targeting Tensor Cores, with a quantized B operand.
+template <
+  /// Size of the Gemm problem - concept: gemm::GemmShape<>
+  typename Shape_,
+  /// Data type of A elements
+  typename ElementA_,
+  /// Layout of A matrix (concept: MatrixLayout)
+  typename LayoutA_,
+  /// Data type of B elements
+  typename ElementB_,
+  /// Layout of B matrix (concept: MatrixLayout)
+  typename LayoutB_,
+  /// Data type of quant scales
+  typename ElementQScale_,
+  /// Layout of quant scales (concept: MatrixLayout)
+  typename SmemLayoutQScale_,
+  /// Data type of quant offsets
+  typename ElementQOffset_,
+  /// Layout of quant offsets (concept: MatrixLayout)
+  typename SmemLayoutQOffset_,
+  /// Blocking dimensions of quantization
+  typename QuantBlocking_,
+  /// Element type of C matrix
+  typename ElementC_,
+  /// Layout of C matrix (concept: MatrixLayout)
+  typename LayoutC_,
+  /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
+  typename Policy_,
+  /// Number of partitions along K dimension
+  int PartitionsK_ = 1,
+  /// Store the accumulators in row major or column major.  Row major is used
+  /// when output layout is interleaved.
+  bool AccumulatorsInRowMajor = false,
+  /// Used for partial specialization
+  typename Enable = bool
+>
+class QuantBMmaTensorOp {
+public:
+  /// Shape of warp-level matrix operation (concept: GemmShape)
+  using Shape = Shape_;
+
+  /// Data type of multiplicand A
+  using ElementA = ElementA_;
+
+  /// Layout of multiplicand A
+  using LayoutA = LayoutA_;
+
+  /// Data type of multiplicand B
+  using ElementB = ElementB_;
+
+  /// Layout of multiplicand B
+  using LayoutB = LayoutB_;
+
+  /// Data type of accumulator matrix C
+  using ElementC = ElementC_;
+
+  /// Layout of accumulator matrix C
+  using LayoutC = LayoutC_;
+
+  /// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
+  using Policy = Policy_;
+
+  /// Underlying matrix multiply operator (concept: arch::Mma)
+  using ArchMmaOperator = typename Policy::Operator;
+
+  /// Indicates math operator
+  using MathOperator = typename ArchMmaOperator::Operator;
+
+  /// Architecture tag from underlying instruction
+  using ArchTag = typename ArchMmaOperator::ArchTag;
+
+  /// Indicates class of matrix operator
+  using OperatorClass = arch::OpClassTensorOp;
+
+  /// Shape of underlying instruction
+  using InstructionShape = typename ArchMmaOperator::Shape;
+
+  /// Complex transform on A operand
+  static ComplexTransform const kTransformA = ComplexTransform::kNone;
+
+  /// Complex transform on B operand
+  static ComplexTransform const kTransformB = ComplexTransform::kNone;
+
+  /// Number of threads participating in warp-level matrix product
+  static int const kThreadCount = 32;
+
+  /// Number of partitions along K dimension
+  static int const kPartitionsK = PartitionsK_;
+
+public:
+
+  /// Iterates over the A operand in memory
+  using IteratorA = MmaTensorOpMultiplicandTileIterator<
+     MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,
+     MatrixShape<ArchMmaOperator::Shape::kM, ArchMmaOperator::Shape::kK>,
+     Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
+
+  /// Storage for A tile
+  using FragmentA = typename IteratorA::Fragment;
+
+  /// Storage for transformed A tile
+  using TransformedFragmentA =
+      Array<typename ArchMmaOperator::ElementA, FragmentA::kElements>;
+
+  /// Iterates over the B operand in memory
+  using IteratorB = MmaTensorOpMultiplicandTileIterator<
+      MatrixShape<Shape::kK/2, Shape::kN/2>, Operand::kB, ElementB, LayoutB,
+      MatrixShape<ArchMmaOperator::Shape::kK/2, ArchMmaOperator::Shape::kN/2>,
+      Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
+  // warp B MatrixShape<64, 64>,
+  // layout B cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<16, 64>,
+  // instruction op shape cutlass::MatrixShape<16, 8>,
+  // kPartitionsK 1
+  // FragmentB::kElements 32
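+  // Both dimensions are halved here because the 4-bit weights are pre-packed four (2x2)
+  // to each 16-bit ElementB; see TransformedFragmentB below.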
+
+  /// Storage for B tile
+  using FragmentB = typename IteratorB::Fragment; // cutlass::Array<cutlass::half_t, 8>
+
+  /// Storage for transformed B tile
+  /// When loading weights, we pack 4 int4 weights into one 2-byte element; when expanded,
+  /// the number of elements is therefore multiplied by 4.
+  /// TODO: make sure ArchMmaOperator::ElementB is the same as the dequantized ElementB,
+  /// and change the transform function below to perform dequantization
+  using TransformedFragmentB =
+      Array<typename ArchMmaOperator::ElementB, FragmentB::kElements * 4>;
+
+  /// Iterates over the C operand in memory
+  using IteratorC = MmaTensorOpAccumulatorTileIterator<
+     MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
+     typename ArchMmaOperator::Shape, typename Policy::OpDelta>;
+
+  /// Storage for C tile
+  using FragmentC = typename IteratorC::Fragment;
+
+  using ElementQScale = ElementQScale_;
+  using SmemLayoutQScale = SmemLayoutQScale_;
+  using QuantBlocking = QuantBlocking_;
+
+  using ElementQOffset = ElementQOffset_;
+  using SmemLayoutQOffset = SmemLayoutQOffset_;
+
+  /// Iterates over the quantization parameters in memory
+  using WarpQScaleShape = MatrixShape<(Shape::kK / QuantBlocking::kRow), (Shape::kN / QuantBlocking::kColumn)>;
+  static_assert(Shape::kK % QuantBlocking::kRow == 0, "K must be multiple of QuantBlocking::kRow");
+  static_assert(Shape::kN % QuantBlocking::kColumn == 0, "N must be multiple of QuantBlocking::kColumn");
+  static_assert(WarpQScaleShape::kCount > 0, "QuantBlocking too big to fit in a warp block!");
+
+  // TODO: This is an expanding iterator; it needs to replicate the quantization parameters
+  // to all threads in the warp.
+  using IteratorQMeta = QuantBMetaMmaTensorOpTileIterator<
+    MatrixShape<Shape::kK, Shape::kN>, QuantBlocking, ElementQScale, SmemLayoutQScale,
+    ElementQOffset, SmemLayoutQOffset,
+    ArchMmaOperator, kThreadCount, kPartitionsK>;
+
+  using FragmentQScale = typename IteratorQMeta::FragmentScale;
+  using FragmentQOffset = typename IteratorQMeta::FragmentOffset;
+
+  /// Number of mma operations performed
+  using MmaIterations = MatrixShape<
+    (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM,
+    (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN
+  >;
+
+public:
+
+  /// Underlying matrix multiply operator (concept: arch::Mma)
+  ArchMmaOperator mma;
+
+public:
+
+  //
+  // Methods
+  //
+
+  /// Ctor
+  CUTLASS_DEVICE
+  QuantBMmaTensorOp() {}
+
+  /// Performs a warp-level matrix multiply-accumulate operation
+  CUTLASS_DEVICE
+  void operator()(
+    FragmentC &D,
+    TransformedFragmentA const &A,
+    TransformedFragmentB const &B,
+    FragmentC const &C
+  ) const {
+
+    using MmaOperandA = typename ArchMmaOperator::FragmentA;
+    using MmaOperandB = typename ArchMmaOperator::FragmentB;
+    using MmaOperandC = typename ArchMmaOperator::FragmentC;
+
+    D = C;
+
+    MmaOperandA const *ptr_A = reinterpret_cast<MmaOperandA const *>(&A);
+    MmaOperandB const *ptr_B = reinterpret_cast<MmaOperandB const *>(&B);
+    MmaOperandC *ptr_D = reinterpret_cast<MmaOperandC *>(&D);
+
+    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
+      // Serpentine visitation order maximizing reuse of Rb
+      // The visitation order is like
+      //      _
+      //   | | | |
+      //   | | | |
+      //   |_| |_|
+      //
+      // Down Up Down Up
+
+      CUTLASS_PRAGMA_UNROLL
+      for (int n = 0; n < MmaIterations::kColumn; ++n) {
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int m = 0; m < MmaIterations::kRow; ++m) {
+
+          int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m);
+
+          if (AccumulatorsInRowMajor) {  // matrix B is reordered
+            mma(
+              ptr_D[n + m_serpentine * MmaIterations::kColumn],
+              ptr_A[m_serpentine],
+              ptr_B[n],
+              ptr_D[n + m_serpentine * MmaIterations::kColumn]);
+          } else {
+            mma(
+              ptr_D[m_serpentine + n * MmaIterations::kRow],
+              ptr_A[m_serpentine],
+              ptr_B[n],
+              ptr_D[m_serpentine + n * MmaIterations::kRow]);
+          }
+        }
+      }
+    #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
+      // Serpentine visitation order maximizing reuse of Ra
+      // The visitation order is like
+      //   _________
+      //   _________|
+      //  |_________
+      //  __________|
+      //
+      // Right Left Right Left
+
+      CUTLASS_PRAGMA_UNROLL
+      for (int m = 0; m < MmaIterations::kRow; ++m) {
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int n = 0; n < MmaIterations::kColumn; ++n) {
+
+          int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n);
+
+          if (AccumulatorsInRowMajor) {  // matrix B is reordered
+            mma(
+              ptr_D[n_serpentine + m * MmaIterations::kColumn],
+              ptr_A[m],
+              ptr_B[n_serpentine],
+              ptr_D[n_serpentine + m * MmaIterations::kColumn]);
+          } else {
+            mma(ptr_D[m + n_serpentine * MmaIterations::kRow],
+                ptr_A[m],
+                ptr_B[n_serpentine],
+                ptr_D[m + n_serpentine * MmaIterations::kRow]);
+          }
+        }
+      }
+    #else
+      assert(0);
+    #endif
+  }
+
+  /// Transform the mma operands to the required types
+  CUTLASS_DEVICE
+  void transform(TransformedFragmentB &dst_B,
+                 FragmentB const &B,
+                 FragmentQScale const &scales,
+                 FragmentQOffset const &offsets) const {
+
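+    // FragmentB's 16-bit elements each hold 4 packed int4 weights, so viewing the fragment as
+    // bytes doubles the element count; dequant() then expands every byte into two fp16 values,
+    // filling TransformedFragmentB's FragmentB::kElements * 4 outputs.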
+    Array<uint8_t, FragmentB::kElements * 2> const *ptr_B =
+        reinterpret_cast<Array<uint8_t, FragmentB::kElements * 2> const *>(&B);
+    IteratorQMeta::dequant(scales, offsets, *ptr_B, dst_B);
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace warp
+} // namespace gemm
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+//#include "cutlass/gemm/warp/mma_tensor_op_fast_f32.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp
index c4c54a9be34d8..9d7b0ae06e220 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp
@@ -687,171 +687,314 @@ QuantizeARow_CompInt8(
     }
 }
 
-template <size_t NCols, size_t SubBlkLen, bool HasZeroPoint>
-MLAS_FORCEINLINE void
-ComputeDotProducts_BlkBitWidth4_CompInt8(
-    size_t BlkLen,
-    const std::byte* QuantARowPtr,
-    const std::byte* QuantBDataColPtr,
-    const float* QuantBScaleColPtr,
-    const std::byte* QuantBZeroPointColPtr,
-    float* SumPtr,
-    size_t CountK,
-    size_t StrideQuantBData,
-    size_t StrideQuantBScale,
-    size_t StrideQuantBZeroPoint,
-    const float* BiasPtr
+template <bool HasZeroPoint>
+void
+SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen16(
+    const std::byte* QuantA,
+    const std::byte* QuantBData,
+    const float* QuantBScale,
+    const std::byte* QuantBZeroPoint,
+    float* C,
+    size_t CountN,
+    size_t BlockCountK,
+    const float* Bias
 )
 {
     constexpr size_t BlkBitWidth = 4;
+    constexpr size_t BlkLen = 16;
 
-    static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4");
-    static_assert(SubBlkLen == 16 || SubBlkLen == 32, "SubBlkLen must be 16 or 32");
+    float* CRowPtr = C;
 
-    assert(BlkLen >= SubBlkLen && BlkLen % SubBlkLen == 0);
+    const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
+    const size_t StrideQuantBScale = BlockCountK;
+    const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(BlockCountK);
 
-    [[maybe_unused]] const uint8x8_t LowMaskU8x8 = vdup_n_u8(0x0F);     // only used if SubBlkLen == 16
-    [[maybe_unused]] const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F);  // only used if SubBlkLen == 32
+    const float* BiasPtr = Bias;
 
-    const std::byte* QuantA = QuantARowPtr;
+    const std::byte* QuantBDataColPtr = QuantBData;
+    const float* QuantBScaleColPtr = QuantBScale;
+    const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint;
 
-    const std::byte* QuantBData = QuantBDataColPtr;
-    const float* QuantBScale = QuantBScaleColPtr;
-    [[maybe_unused]] size_t QuantBZeroPointIdx = 0;  // track half byte increments with this index instead of a pointer
-                                                     // only used if HasZeroPoint == true
+    float* SumPtr = CRowPtr;
 
-    float32x4_t acc[NCols]{};
+    const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F);
+    const uint8x8_t LowMaskU8x8 = vdup_n_u8(0x0F);
 
-    for (size_t k = 0; k < CountK; k += BlkLen) {
-        const size_t k_blk_len = std::min(CountK - k, BlkLen);
+    for (size_t n = 0; n < CountN; ++n) {
+        const std::byte* QuantAPtr = QuantA;
+        const std::byte* QuantBDataPtr = QuantBDataColPtr;
+        const float* QuantBScalePtr = QuantBScaleColPtr;
+        const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr;
 
-        const float a_scale = Q8BlkScale(QuantA);
-        const int8_t* a_data = Q8BlkData(QuantA);
+        float32x4_t acc0{}, acc1{};
 
-        float b_scale[NCols];
-        UnrolledLoop<NCols>([&](size_t i) { b_scale[i] = QuantBScale[i * StrideQuantBScale]; });
+        size_t k_blks_remaining = BlockCountK;
+        for (; k_blks_remaining > 1; k_blks_remaining -= 2) {
+            const std::byte* QuantABlk0 = QuantAPtr;
+            const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen);
 
-        [[maybe_unused]] int8_t b_zp[NCols];  // only used if HasZeroPoint == true
-        if constexpr (HasZeroPoint) {
-            UnrolledLoop<NCols>([&](size_t i) {
-                const std::byte zp_packed =
-                    QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2];
-                b_zp[i] = ((QuantBZeroPointIdx & 1) == 1)
-                              ? std::to_integer<int8_t>(zp_packed >> 4)
-                              : std::to_integer<int8_t>(zp_packed & std::byte{0x0F});
-            });
-        }
+            // compute combined scale
+            const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]);
+            const float32x4_t scale1 = vdupq_n_f32(Q8BlkScale(QuantABlk1) * QuantBScalePtr[1]);
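+            // the int8 dot products below operate on quantized values, so multiplying by
+            // (A block scale * B block scale) recovers each block's float contribution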
 
-        for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) {
-            // load A row vector
-            int8x16_t av[SubBlkLen / 16];
-            UnrolledLoop<SubBlkLen / 16>([&](size_t i) {
-                av[i] = vld1q_s8(a_data + k_idx_in_blk + i * 16);
-            });
+            // load B zero point
+            const int8x16_t bzp0 = vdupq_n_s8(
+                HasZeroPoint ? std::to_integer<int8_t>(QuantBZeroPointPtr[0] & std::byte{0x0F}) : 8
+            );
+            const int8x16_t bzp1 = vdupq_n_s8(
+                HasZeroPoint ? std::to_integer<int8_t>(QuantBZeroPointPtr[0] >> 4) : 8
+            );
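+            // with BlkLen == 16, two adjacent blocks share one packed zero-point byte:
+            // block 0 takes the low nibble, block 1 the high nibble (implicit 8 when absent)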
 
-            // load B column vectors
-            int8x16_t bv[NCols][SubBlkLen / 16];
+            // load A
+            const int8x16_t av0 = vld1q_s8(Q8BlkData(QuantABlk0));
+            const int8x16_t av1 = vld1q_s8(Q8BlkData(QuantABlk1));
 
-            const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8;
+            // load B
+            const uint8x16_t bv_packed01 = vld1q_u8(reinterpret_cast<const uint8_t*>(QuantBDataPtr));
 
-            if constexpr (SubBlkLen == 16) {
-                uint8x8_t bv_packed[NCols];
-                UnrolledLoop<NCols>([&](size_t i) {
-                    bv_packed[i] = vld1_u8(
-                        reinterpret_cast<const uint8_t*>(QuantBData) + i * StrideQuantBData + b_data_block_offset
-                    );
-                });
+            const uint8x16_t bv_lo01 = vandq_u8(bv_packed01, LowMaskU8x16);
+            const uint8x16_t bv_hi01 = vshrq_n_u8(bv_packed01, 4);
 
-                UnrolledLoop<NCols>([&](size_t i) {
-                    const int8x8_t lo = vreinterpret_s8_u8(vand_u8(bv_packed[i], LowMaskU8x8));
-                    const int8x8_t hi = vreinterpret_s8_u8(vshr_n_u8(bv_packed[i], 4));
-                    bv[i][0] = vcombine_s8(lo, hi);
-                });
-            } else {
-                static_assert(SubBlkLen == 32);
+            int8x16_t bv0 = vreinterpretq_s8_u8(vcombine_u8(vget_low_u8(bv_lo01), vget_low_u8(bv_hi01)));
+            int8x16_t bv1 = vreinterpretq_s8_u8(vcombine_u8(vget_high_u8(bv_lo01), vget_high_u8(bv_hi01)));
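+            // bv_packed01 holds two 8-byte blocks; within a block, the low nibbles are
+            // elements 0-7 and the high nibbles elements 8-15, so recombining the low/high
+            // halves of bv_lo01/bv_hi01 rebuilds block 0 in bv0 and block 1 in bv1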
 
-                uint8x16_t bv_packed[NCols];
-                UnrolledLoop<NCols>([&](size_t i) {
-                    bv_packed[i] = vld1q_u8(
-                        reinterpret_cast<const uint8_t*>(QuantBData) + i * StrideQuantBData + b_data_block_offset
-                    );
-                });
+            // subtract B zero point
+            bv0 = vsubq_s8(bv0, bzp0);
+            bv1 = vsubq_s8(bv1, bzp1);
 
-                UnrolledLoop<NCols>([&](size_t i) {
-                    bv[i][0] = vreinterpretq_s8_u8(vandq_u8(bv_packed[i], LowMaskU8x16));
-                    bv[i][1] = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed[i], 4));
-                });
+            // quantized dot product
+            const int32x4_t dot0 = vdotq_s32(vdupq_n_s32(0), av0, bv0);
+            const int32x4_t dot1 = vdotq_s32(vdupq_n_s32(0), av1, bv1);
+
+            // convert to float
+            const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0);
+            const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1);
+
+            // multiply by scale and update accumulator
+            acc0 = vfmaq_f32(acc0, dot_f32_0, scale0);
+            acc1 = vfmaq_f32(acc1, dot_f32_1, scale1);
+
+            // increment block pointers
+
+            QuantAPtr += Q8BlkSize(BlkLen) * 2;
+            QuantBDataPtr += 8 * 2;
+            QuantBScalePtr += 2;
+            if constexpr (HasZeroPoint) {
+                QuantBZeroPointPtr += 1;
             }
+        }
+
+        if (k_blks_remaining > 0) {
+            const std::byte* QuantABlk0 = QuantAPtr;
+
+            // compute combined scale
+            const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * (*QuantBScalePtr));
+
+            // load B zero point
+            const int8x16_t bzp0 = vdupq_n_s8(
+                HasZeroPoint ? std::to_integer<int8_t>(QuantBZeroPointPtr[0] & std::byte{0x0F}) : 8
+            );
+
+            // load A
+            const int8x16_t av0 = vld1q_s8(Q8BlkData(QuantABlk0));
+
+            // load B
+            const uint8x8_t bv_packed0 = vld1_u8(reinterpret_cast<const uint8_t*>(QuantBDataPtr));
+
+            const uint8x8_t bv_lo0 = vand_u8(bv_packed0, LowMaskU8x8);
+            const uint8x8_t bv_hi0 = vshr_n_u8(bv_packed0, 4);
+
+            int8x16_t bv0 = vreinterpretq_s8_u8(vcombine_u8(bv_lo0, bv_hi0));
 
             // subtract B zero point
-            if constexpr (HasZeroPoint) {
-                UnrolledLoop<NCols>([&](size_t i) {
-                    const int8x16_t zp_v = vdupq_n_s8(b_zp[i]);
-                    UnrolledLoop<SubBlkLen / 16>([&](size_t j) {
-                        bv[i][j] = vsubq_s8(bv[i][j], zp_v);
-                    });
-                });
-            } else {
-                const int8x16_t zp_v = vdupq_n_s8(8);
+            bv0 = vsubq_s8(bv0, bzp0);
 
-                UnrolledLoop<NCols>([&](size_t i) {
-                    UnrolledLoop<SubBlkLen / 16>([&](size_t j) {
-                        bv[i][j] = vsubq_s8(bv[i][j], zp_v);
-                    });
-                });
-            }
+            // quantized dot product
+            const int32x4_t dot0 = vdotq_s32(vdupq_n_s32(0), av0, bv0);
 
-            // compute quantized dot product
-            int32x4_t dot[NCols]{};
-            UnrolledLoop<NCols>([&](size_t i) {
-                UnrolledLoop<SubBlkLen / 16>([&](size_t j) {
-                    dot[i] = vdotq_s32(dot[i], av[j], bv[i][j]);
-                });
-            });
+            // convert to float
+            const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0);
 
-            // convert dot product result to float
-            float32x4_t dot_f32[NCols];
-            UnrolledLoop<NCols>([&](size_t i) {
-                dot_f32[i] = vcvtq_f32_s32(dot[i]);
-            });
+            // multiply by scale and update accumulator
+            acc0 = vfmaq_f32(acc0, dot_f32_0, scale0);
+        }
 
-            // multiply dot product result by scale and update accumulator
-            UnrolledLoop<NCols>([&](size_t i) {
-                const float32x4_t scale_v = vdupq_n_f32(a_scale * b_scale[i]);
-                acc[i] = vfmaq_f32(acc[i], dot_f32[i], scale_v);
-            });
+        *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1);
+        if (BiasPtr) {
+            *SumPtr += *BiasPtr;
         }
 
-        // increment pointers to next block
-        QuantA += Q8BlkSize(BlkLen);
-        QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
-        QuantBScale += 1;
+        // move to next column
+
+        QuantBDataColPtr += StrideQuantBData;
+        QuantBScaleColPtr += StrideQuantBScale;
         if constexpr (HasZeroPoint) {
-            QuantBZeroPointIdx += 1;
+            QuantBZeroPointColPtr += StrideQuantBZeroPoint;
         }
+
+        BiasPtr += BiasPtr != nullptr ? 1 : 0;
+        SumPtr += 1;
     }
+}
 
-    if constexpr (NCols == 4) {
-        float32x4_t sum = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]);
+template <bool HasZeroPoint>
+void
+SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen32(
+    const std::byte* QuantA,
+    const std::byte* QuantBData,
+    const float* QuantBScale,
+    const std::byte* QuantBZeroPoint,
+    float* C,
+    size_t CountN,
+    size_t BlockCountK,
+    const float* Bias
+)
+{
+    constexpr size_t BlkBitWidth = 4;
+    constexpr size_t BlkLen = 32;
 
-        if (BiasPtr != nullptr) {
-            sum = vaddq_f32(sum, vld1q_f32(BiasPtr));
-        }
+    float* CRowPtr = C;
 
-        vst1q_f32(SumPtr, sum);
-    } else {
-        for (size_t i = 0; i < NCols; ++i) {
-            SumPtr[i] = vaddvq_f32(acc[i]);
-            if (BiasPtr != nullptr) {
-                SumPtr[i] += BiasPtr[i];
+    const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
+    const size_t StrideQuantBScale = BlockCountK;
+    const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(BlockCountK);
+
+    const float* BiasPtr = Bias;
+
+    const std::byte* QuantBDataColPtr = QuantBData;
+    const float* QuantBScaleColPtr = QuantBScale;
+    const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint;
+
+    float* SumPtr = CRowPtr;
+
+    const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F);
+
+    for (size_t n = 0; n < CountN; ++n) {
+        const std::byte* QuantAPtr = QuantA;
+        const std::byte* QuantBDataPtr = QuantBDataColPtr;
+        const float* QuantBScalePtr = QuantBScaleColPtr;
+        const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr;
+
+        float32x4_t acc0{}, acc1{};
+
+        size_t k_blks_remaining = BlockCountK;
+        for (; k_blks_remaining > 1; k_blks_remaining -= 2) {
+            const std::byte* QuantABlk0 = QuantAPtr;
+            const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen);
+
+            // compute combined scale
+            const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * QuantBScalePtr[0]);
+            const float32x4_t scale1 = vdupq_n_f32(Q8BlkScale(QuantABlk1) * QuantBScalePtr[1]);
+
+            // load B zero point
+            const int8x16_t bzp0 = vdupq_n_s8(
+                HasZeroPoint ? std::to_integer<int8_t>((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8
+            );
+            const int8x16_t bzp1 = vdupq_n_s8(
+                HasZeroPoint ? std::to_integer<int8_t>((*QuantBZeroPointPtr) >> 4) : 8
+            );
+
+            // load A
+            const int8x16_t av_lo0 = vld1q_s8(Q8BlkData(QuantABlk0));
+            const int8x16_t av_hi0 = vld1q_s8(Q8BlkData(QuantABlk0) + 16);
+            const int8x16_t av_lo1 = vld1q_s8(Q8BlkData(QuantABlk1));
+            const int8x16_t av_hi1 = vld1q_s8(Q8BlkData(QuantABlk1) + 16);
+
+            // load B
+            const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast<const uint8_t*>(QuantBDataPtr));
+            const uint8x16_t bv_packed1 = vld1q_u8(reinterpret_cast<const uint8_t*>(QuantBDataPtr) + 16);
+
+            int8x16_t bv_lo0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16));
+            int8x16_t bv_hi0 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4));
+            int8x16_t bv_lo1 = vreinterpretq_s8_u8(vandq_u8(bv_packed1, LowMaskU8x16));
+            int8x16_t bv_hi1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed1, 4));
+
+            // subtract B zero point
+            bv_lo0 = vsubq_s8(bv_lo0, bzp0);
+            bv_hi0 = vsubq_s8(bv_hi0, bzp0);
+            bv_lo1 = vsubq_s8(bv_lo1, bzp1);
+            bv_hi1 = vsubq_s8(bv_hi1, bzp1);
+
+            // quantized dot product
+            int32x4_t dot0{}, dot1{};
+            dot0 = vdotq_s32(vdotq_s32(dot0, av_lo0, bv_lo0), av_hi0, bv_hi0);
+            dot1 = vdotq_s32(vdotq_s32(dot1, av_lo1, bv_lo1), av_hi1, bv_hi1);
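+            // each vdotq_s32 adds four 4-element int8 dot products (one per 32-bit lane);
+            // chaining the low and high halves covers all 32 elements of a block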
+
+            // convert to float
+            const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0);
+            const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1);
+
+            // multiply by scale and update accumulator
+            acc0 = vfmaq_f32(acc0, dot_f32_0, scale0);
+            acc1 = vfmaq_f32(acc1, dot_f32_1, scale1);
+
+            // increment block pointers
+
+            QuantAPtr += Q8BlkSize(BlkLen) * 2;
+            QuantBDataPtr += 16 * 2;
+            QuantBScalePtr += 2;
+            if constexpr (HasZeroPoint) {
+                QuantBZeroPointPtr += 1;
             }
         }
+
+        if (k_blks_remaining > 0) {
+            const std::byte* QuantABlk0 = QuantAPtr;
+
+            // compute combined scale
+            const float32x4_t scale0 = vdupq_n_f32(Q8BlkScale(QuantABlk0) * (*QuantBScalePtr));
+
+            // load B zero point
+            const int8x16_t bzp0 = vdupq_n_s8(
+                HasZeroPoint ? std::to_integer<int8_t>((*QuantBZeroPointPtr) & std::byte{0x0F}) : 8
+            );
+
+            // load A
+            const int8x16_t av_lo0 = vld1q_s8(Q8BlkData(QuantABlk0));
+            const int8x16_t av_hi0 = vld1q_s8(Q8BlkData(QuantABlk0) + 16);
+
+            // load B
+            const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast<const uint8_t*>(QuantBDataPtr));
+
+            int8x16_t bv_lo0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16));
+            int8x16_t bv_hi0 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4));
+
+            // subtract B zero point
+            bv_lo0 = vsubq_s8(bv_lo0, bzp0);
+            bv_hi0 = vsubq_s8(bv_hi0, bzp0);
+
+            // quantized dot product
+            int32x4_t dot0{};
+            dot0 = vdotq_s32(vdotq_s32(dot0, av_lo0, bv_lo0), av_hi0, bv_hi0);
+
+            // convert to float
+            const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0);
+
+            // multiply by scale and update accumulator
+            acc0 = vfmaq_f32(acc0, dot_f32_0, scale0);
+        }
+
+        *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1);
+        if (BiasPtr) {
+            *SumPtr += *BiasPtr;
+        }
+
+        // move to next column
+
+        QuantBDataColPtr += StrideQuantBData;
+        QuantBScaleColPtr += StrideQuantBScale;
+        if constexpr (HasZeroPoint) {
+            QuantBZeroPointColPtr += StrideQuantBZeroPoint;
+        }
+
+        BiasPtr += BiasPtr != nullptr ? 1 : 0;
+        SumPtr += 1;
     }
 }
 
-template <size_t NCols, size_t SubBlkLen, bool HasZeroPoint>
+template <bool HasZeroPoint>
 void
-SQ4BitGemmM1Kernel_CompInt8_Impl(
+SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLenGreaterThan32(
     size_t BlkLen,
     const std::byte* QuantA,
     const std::byte* QuantBData,
@@ -859,17 +1002,16 @@ SQ4BitGemmM1Kernel_CompInt8_Impl(
     const std::byte* QuantBZeroPoint,
     float* C,
     size_t CountN,
-    size_t CountK,
-    size_t BlockStrideQuantB,
+    size_t BlockCountK,
     const float* Bias
 )
 {
     constexpr size_t BlkBitWidth = 4;
 
-    const std::byte* QuantARowPtr = QuantA;
-    float* CRowPtr = C;
+    assert(BlkLen > 32);
+    assert(BlkLen % 32 == 0);
 
-    const size_t BlockCountK = BlockStrideQuantB;
+    float* CRowPtr = C;
 
     const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
     const size_t StrideQuantBScale = BlockCountK;
@@ -883,39 +1025,91 @@ SQ4BitGemmM1Kernel_CompInt8_Impl(
 
     float* SumPtr = CRowPtr;
 
-    int64_t nblk = static_cast<int64_t>(CountN) - NCols;
+    const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F);
 
-    while (nblk >= 0) {
-        ComputeDotProducts_BlkBitWidth4_CompInt8<NCols, SubBlkLen, HasZeroPoint>(
-            BlkLen,
-            QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
-            StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
-            BiasPtr
-        );
+    // process blocks in 32-element sub-blocks
+    const size_t SubBlksPerBlk = BlkLen / 32;
 
-        // move to next `NCols` columns
+    for (size_t n = 0; n < CountN; ++n) {
+        const std::byte* QuantAPtr = QuantA;
+        const std::byte* QuantBDataPtr = QuantBDataColPtr;
+        const float* QuantBScalePtr = QuantBScaleColPtr;
+        const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr;
 
-        QuantBDataColPtr += NCols * StrideQuantBData;
-        QuantBScaleColPtr += NCols * StrideQuantBScale;
-        if constexpr (HasZeroPoint) {
-            QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint;
-        }
+        float32x4_t acc0{}, acc1{};
 
-        BiasPtr += BiasPtr != nullptr ? NCols : 0;
-        SumPtr += NCols;
+        for (size_t k_blk_idx = 0; k_blk_idx < BlockCountK; ++k_blk_idx) {
+            // compute combined scale
+            const float32x4_t scale = vdupq_n_f32(Q8BlkScale(QuantAPtr) * (*QuantBScalePtr));
 
-        nblk -= NCols;
-    }
+            // load B zero point
+            const int8x16_t bzp = [&]() -> int8x16_t {
+                if constexpr (HasZeroPoint) {
+                    return vdupq_n_s8(
+                        ((k_blk_idx & 1) == 0) ? std::to_integer<int8_t>((*QuantBZeroPointPtr) & std::byte{0x0F})
+                                               : std::to_integer<int8_t>((*QuantBZeroPointPtr) >> 4)
+                    );
+                } else {
+                    return vdupq_n_s8(8);
+                }
+            }();
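+            // zero points are packed two blocks per byte: even k_blk_idx reads the low nibble,
+            // odd the high nibble, and QuantBZeroPointPtr advances only after odd blocks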
+
+            const int8_t* QuantADataPtr = Q8BlkData(QuantAPtr);
+
+            for (size_t sub_blk_idx = 0; sub_blk_idx < SubBlksPerBlk; sub_blk_idx += 2) {
+                // load A
+                const int8x16_t av0 = vld1q_s8(QuantADataPtr + 0);
+                const int8x16_t av1 = vld1q_s8(QuantADataPtr + 16);
+                const int8x16_t av2 = vld1q_s8(QuantADataPtr + 32);
+                const int8x16_t av3 = vld1q_s8(QuantADataPtr + 48);
+
+                // load B
+                const uint8x16_t bv_packed0 = vld1q_u8(reinterpret_cast<const uint8_t*>(QuantBDataPtr));
+                const uint8x16_t bv_packed1 = vld1q_u8(reinterpret_cast<const uint8_t*>(QuantBDataPtr) + 16);
+
+                int8x16_t bv0 = vreinterpretq_s8_u8(vandq_u8(bv_packed0, LowMaskU8x16));
+                int8x16_t bv1 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed0, 4));
+                int8x16_t bv2 = vreinterpretq_s8_u8(vandq_u8(bv_packed1, LowMaskU8x16));
+                int8x16_t bv3 = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed1, 4));
+
+                // subtract B zero point
+                bv0 = vsubq_s8(bv0, bzp);
+                bv1 = vsubq_s8(bv1, bzp);
+                bv2 = vsubq_s8(bv2, bzp);
+                bv3 = vsubq_s8(bv3, bzp);
+
+                // quantized dot product
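+                // vdotq_s32 adds four int8*int8 products into each of its four int32 lanes,
+                // so the two chained calls per accumulator cover one full 32-element sub-block.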
+                int32x4_t dot0{}, dot1{};
+                dot0 = vdotq_s32(vdotq_s32(dot0, av0, bv0), av1, bv1);
+                dot1 = vdotq_s32(vdotq_s32(dot1, av2, bv2), av3, bv3);
+
+                // convert to float
+                const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0);
+                const float32x4_t dot_f32_1 = vcvtq_f32_s32(dot1);
+
+                // multiply by scale and update accumulator
+                acc0 = vfmaq_f32(acc0, dot_f32_0, scale);
+                acc1 = vfmaq_f32(acc1, dot_f32_1, scale);
+
+                // increment block data pointers to next sub-block
+                QuantADataPtr += 16 * 4;
+                QuantBDataPtr += 16 * 2;
+            }
 
-    // left over columns less than `NCols`?
-    nblk += NCols;
-    for (int64_t n = 0; n < nblk; ++n) {
-        ComputeDotProducts_BlkBitWidth4_CompInt8<1, SubBlkLen, HasZeroPoint>(
-            BlkLen,
-            QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK,
-            StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint,
-            BiasPtr
-        );
+            // increment other block pointers
+
+            QuantAPtr += Q8BlkSize(BlkLen);
+            QuantBScalePtr += 1;
+
+            if constexpr (HasZeroPoint) {
+                QuantBZeroPointPtr += ((k_blk_idx & 1) == 0) ? 0 : 1;
+            }
+        }
+
+        *SumPtr = vaddvq_f32(acc0) + vaddvq_f32(acc1);
+        if (BiasPtr) {
+            *SumPtr += *BiasPtr;
+        }
 
         // move to next column
 
@@ -940,26 +1134,34 @@ SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(
     const std::byte* QuantBZeroPoint,
     float* C,
     size_t CountN,
-    size_t CountK,
     size_t BlockStrideQuantB,
     const float* Bias
 )
 {
     if (BlkLen == 16) {
-        SQ4BitGemmM1Kernel_CompInt8_Impl<4, 16, HasZeroPoint>(
-            BlkLen,
+        SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen16<HasZeroPoint>(
+            QuantA,
+            QuantBData,
+            QuantBScale,
+            QuantBZeroPoint,
+            C,
+            CountN,
+            BlockStrideQuantB,
+            Bias
+        );
+    } else if (BlkLen == 32) {
+        SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLen32<HasZeroPoint>(
             QuantA,
             QuantBData,
             QuantBScale,
             QuantBZeroPoint,
             C,
             CountN,
-            CountK,
             BlockStrideQuantB,
             Bias
         );
     } else {
-        SQ4BitGemmM1Kernel_CompInt8_Impl<4, 32, HasZeroPoint>(
+        SQ4BitGemmM1Kernel_CompInt8_Impl_BlkLenGreaterThan32<HasZeroPoint>(
             BlkLen,
             QuantA,
             QuantBData,
@@ -967,7 +1169,6 @@ SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(
             QuantBZeroPoint,
             C,
             CountN,
-            CountK,
             BlockStrideQuantB,
             Bias
         );
@@ -984,7 +1185,7 @@ SQ4BitGemmM1Kernel_CompInt8(
     const std::byte* QuantBZeroPoint,
     float* C,
     size_t CountN,
-    size_t CountK,
+    size_t /*CountK*/,
     size_t BlockStrideQuantB,
     const float* Bias
 )
@@ -998,7 +1199,6 @@ SQ4BitGemmM1Kernel_CompInt8(
             QuantBZeroPoint,
             C,
             CountN,
-            CountK,
             BlockStrideQuantB,
             Bias
         );
@@ -1011,7 +1211,6 @@ SQ4BitGemmM1Kernel_CompInt8(
             QuantBZeroPoint,
             C,
             CountN,
-            CountK,
             BlockStrideQuantB,
             Bias
         );
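
For orientation, the kernel added above computes, for each output column, a sum over the K blocks of a combined scale (A block scale times B block scale) applied to an int8 dot product against zero-point-adjusted 4-bit weights. The scalar sketch below shows that per-column computation with plain unpacked arrays; it is illustrative only, does not use the MLAS packed layouts, omits the bias add, and the name QuantizedDotColumn is invented for this example.

#include <cstddef>
#include <cstdint>

// Illustrative scalar reference for one output column: K is split into blocks of
// blk_len, A is quantized to int8 with a per-block scale, B is stored as unpacked
// 4-bit values (0..15) with a per-block scale and zero point.
float QuantizedDotColumn(const int8_t* a_data, const float* a_scales,
                         const uint8_t* b_data, const float* b_scales,
                         const uint8_t* b_zero_points, size_t block_count,
                         size_t blk_len) {
  float acc = 0.0f;
  for (size_t blk = 0; blk < block_count; ++blk) {
    int32_t dot = 0;
    for (size_t i = 0; i < blk_len; ++i) {
      const int32_t a = a_data[blk * blk_len + i];
      // subtract the B zero point (8 when no zero point is provided)
      const int32_t b = static_cast<int32_t>(b_data[blk * blk_len + i]) -
                        static_cast<int32_t>(b_zero_points[blk]);
      dot += a * b;
    }
    // combined scale = A block scale * B block scale, applied once per block
    acc += static_cast<float>(dot) * a_scales[blk] * b_scales[blk];
  }
  return acc;
}
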
diff --git a/onnxruntime/core/optimizer/common_subexpression_elimination.cc b/onnxruntime/core/optimizer/common_subexpression_elimination.cc
index b2e7ef0b4f558..48df511d0c672 100644
--- a/onnxruntime/core/optimizer/common_subexpression_elimination.cc
+++ b/onnxruntime/core/optimizer/common_subexpression_elimination.cc
@@ -4,6 +4,7 @@
 #include "common_subexpression_elimination.h"
 #include "core/optimizer/utils.h"
 #include "core/graph/graph_utils.h"
+#include "core/framework/tensorprotoutils.h"
 
 #include <memory>
 #include <type_traits>
@@ -170,6 +171,32 @@ bool AreRangesEqual(const Range& lhs, const Range& rhs) {
          std::equal(lhs.begin(), lhs.end(), rhs.begin());
 }
 
+// Check if two tensor attributes are equal scalar tensors, mainly to support the ConstantOfShape op.
+// Currently supports float, float16 and int64 data types, and requires the data to be stored as raw data in the TensorProto.
+bool AreScalarTensorAttributeEqual(const ONNX_NAMESPACE::TensorProto& lhs_t, const ONNX_NAMESPACE::TensorProto& rhs_t) {
+  if (!(utils::HasDataType(lhs_t) && utils::HasDataType(rhs_t) && lhs_t.data_type() == rhs_t.data_type() &&
+        (lhs_t.data_type() == onnx::TensorProto_DataType_FLOAT ||
+         lhs_t.data_type() == onnx::TensorProto_DataType_FLOAT16 ||
+         lhs_t.data_type() == onnx::TensorProto_DataType_INT64) &&
+        lhs_t.dims_size() == 1 && rhs_t.dims_size() == 1 && lhs_t.dims()[0] == 1 && rhs_t.dims()[0] == 1 &&
+        utils::HasRawData(lhs_t) && utils::HasRawData(rhs_t))) {
+    return false;
+  }
+  const void* lhs_value = lhs_t.raw_data().data();
+  const void* rhs_value = rhs_t.raw_data().data();
+  switch (lhs_t.data_type()) {
+    case onnx::TensorProto_DataType_FLOAT:
+      return *reinterpret_cast<const float*>(lhs_value) == *reinterpret_cast<const float*>(rhs_value);
+    case onnx::TensorProto_DataType_FLOAT16:
+      return *reinterpret_cast<const MLFloat16*>(lhs_value) == *reinterpret_cast<const MLFloat16*>(rhs_value);
+    case onnx::TensorProto_DataType_INT64:
+      return *reinterpret_cast<const int64_t*>(lhs_value) == *reinterpret_cast<const int64_t*>(rhs_value);
+    default:
+      break;
+  }
+  return false;
+}
+
 bool AreEqual(const ONNX_NAMESPACE::AttributeProto& lhs, const ONNX_NAMESPACE::AttributeProto& rhs) {
   if (&lhs == &rhs) {
     return true;
@@ -193,6 +220,7 @@ bool AreEqual(const ONNX_NAMESPACE::AttributeProto& lhs, const ONNX_NAMESPACE::A
     case onnx::AttributeProto_AttributeType_STRINGS:
       return AreRangesEqual(lhs.strings(), rhs.strings());
     case onnx::AttributeProto_AttributeType_TENSOR:
+      return AreScalarTensorAttributeEqual(lhs.t(), rhs.t());
     case onnx::AttributeProto_AttributeType_GRAPH:
     case onnx::AttributeProto_AttributeType_SPARSE_TENSOR:
     case onnx::AttributeProto_AttributeType_TYPE_PROTO:
@@ -207,6 +235,31 @@ bool AreEqual(const ONNX_NAMESPACE::AttributeProto& lhs, const ONNX_NAMESPACE::A
   return false;
 }
 
+// Supports scalar float/int64/fp16 tensor attributes only for now, and requires the data to be stored as raw data in the TensorProto.
+std::size_t GetTensorAttributeHash(const ONNX_NAMESPACE::TensorProto& attr_t) {
+  std::size_t hash = 0;
+  if (utils::HasDataType(attr_t) && attr_t.dims_size() == 1 && attr_t.dims()[0] == 1 && utils::HasRawData(attr_t)) {
+    int data_type = attr_t.data_type();
+    switch (data_type) {
+      case onnx::TensorProto_DataType_FLOAT:
+        UpdateHash(data_type, hash);
+        UpdateHash(*reinterpret_cast<const float*>(attr_t.raw_data().data()), hash);
+        break;
+      case onnx::TensorProto_DataType_FLOAT16:
+        UpdateHash(data_type, hash);
+        UpdateHash(static_cast<float>(*reinterpret_cast<const MLFloat16*>(attr_t.raw_data().data())), hash);
+        break;
+      case onnx::TensorProto_DataType_INT64:
+        UpdateHash(data_type, hash);
+        UpdateHash(*reinterpret_cast<const int64_t*>(attr_t.raw_data().data()), hash);
+        break;
+      default:
+        break;
+    }
+  }
+  return hash;
+}
+
 std::size_t GetAttributeHash(const ONNX_NAMESPACE::AttributeProto& attr) {
   std::size_t hash = 0;
   UpdateHash(
@@ -233,6 +286,8 @@ std::size_t GetAttributeHash(const ONNX_NAMESPACE::AttributeProto& attr) {
       UpdateHashWithContainer(attr.strings(), hash);
       break;
     case onnx::AttributeProto_AttributeType_TENSOR:
+      UpdateHash(attr.t(), &GetTensorAttributeHash, hash);
+      break;
     case onnx::AttributeProto_AttributeType_GRAPH:
     case onnx::AttributeProto_AttributeType_SPARSE_TENSOR:
     case onnx::AttributeProto_AttributeType_TYPE_PROTO:
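
The equality and hash helpers added above only consider single-element tensor attributes whose payload lives in raw_data. As a rough standalone illustration of the comparison, without the ONNX protobuf types (ScalarType, RawEqual and ScalarRawDataEqual are made-up names for this sketch):

#include <cstdint>
#include <cstring>

enum class ScalarType { kFloat, kFloat16, kInt64 };

// Read a typed scalar out of a raw little-endian buffer and compare; memcpy avoids
// alignment and aliasing pitfalls of a direct reinterpret_cast dereference.
template <typename T>
static bool RawEqual(const void* lhs, const void* rhs) {
  T a, b;
  std::memcpy(&a, lhs, sizeof(T));
  std::memcpy(&b, rhs, sizeof(T));
  return a == b;
}

bool ScalarRawDataEqual(ScalarType type, const void* lhs, const void* rhs) {
  switch (type) {
    case ScalarType::kFloat:
      return RawEqual<float>(lhs, rhs);
    case ScalarType::kFloat16:
      return RawEqual<uint16_t>(lhs, rhs);  // bitwise fp16 compare in this sketch; the real code compares MLFloat16 values
    case ScalarType::kInt64:
      return RawEqual<int64_t>(lhs, rhs);
  }
  return false;
}
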
diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc
index 9c98ed6d3e114..1516fb37a7e9f 100644
--- a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc
+++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc
@@ -4,6 +4,7 @@
 #ifdef ENABLE_TRAINING
 
 #include <onnx/defs/attr_proto_util.h>
+#include "core/common/string_utils.h"
 #include "core/graph/graph_utils.h"
 #include "core/optimizer/initializer.h"
 #include "core/optimizer/utils.h"
@@ -26,38 +27,38 @@ UpStreamGatherGraphTransformer::UpStreamGatherGraphTransformer(
       // 2. Whether the outputs have the same dim changes if the Gather node moves before that operator.
       // 3. Should all inputs be allowed when tracking back further (bottom-up);
       //    if not, add the input index restriction as MatMul did.
-      {GetFullQualifiedOpName("Add", kOnnxDomain),
+      {utils::GetFullQualifiedOpName("Add", kOnnxDomain),
        OpPassThroughConfig<UpStreamGatherOperatorActorBase>(std::make_shared<SimplePointwiseGatherActor<true>>(),
                                                             opset_14_13_7_6_1)},
-      {GetFullQualifiedOpName("BiasGelu", kMSDomain),
+      {utils::GetFullQualifiedOpName("BiasGelu", kMSDomain),
        OpPassThroughConfig<UpStreamGatherOperatorActorBase>(std::make_shared<SimplePointwiseGatherActor<true>>(), opset_1)},
 
-      {GetFullQualifiedOpName("Cast", kOnnxDomain),
+      {utils::GetFullQualifiedOpName("Cast", kOnnxDomain),
        OpPassThroughConfig<UpStreamGatherOperatorActorBase>(std::make_shared<SimplePointwiseGatherActor<true>>(),
                                                             opset_19_13_9_6_1)},
-      {GetFullQualifiedOpName("Div", kOnnxDomain),
+      {utils::GetFullQualifiedOpName("Div", kOnnxDomain),
        OpPassThroughConfig<UpStreamGatherOperatorActorBase>(std::make_shared<SimplePointwiseGatherActor<true>>(),
                                                             opset_14_13_7_6_1)},
-      {GetFullQualifiedOpName("Dropout", kOnnxDomain),
+      {utils::GetFullQualifiedOpName("Dropout", kOnnxDomain),
        OpPassThroughConfig<UpStreamGatherOperatorActorBase>(std::make_shared<SimplePointwiseGatherActor<true>>(),
                                                             opset_13_12_10_7_6_1)},
-      {GetFullQualifiedOpName("Gelu", kMSDomain),
+      {utils::GetFullQualifiedOpName("Gelu", kMSDomain),
        OpPassThroughConfig<UpStreamGatherOperatorActorBase>(std::make_shared<SimplePointwiseGatherActor<true>>(),
                                                             opset_1)},
       {// Be noted, this is our own implementation of ONNX domain op.
-       GetFullQualifiedOpName("LayerNormalization", kOnnxDomain),
+       utils::GetFullQualifiedOpName("LayerNormalization", kOnnxDomain),
        OpPassThroughConfig<UpStreamGatherOperatorActorBase>(std::make_shared<LayerNormalizationGatherActor>(),
                                                             opset_1)},
-      {GetFullQualifiedOpName("MatMul", kOnnxDomain),
+      {utils::GetFullQualifiedOpName("MatMul", kOnnxDomain),
        OpPassThroughConfig<UpStreamGatherOperatorActorBase>(std::make_shared<MatMulGatherActor>(),
                                                             opset_13_9_1)},
-      {GetFullQualifiedOpName("Reshape", kOnnxDomain),
+      {utils::GetFullQualifiedOpName("Reshape", kOnnxDomain),
        OpPassThroughConfig<UpStreamGatherOperatorActorBase>(std::make_shared<ReshapeGatherActor>(),
                                                             opset_19_14_13_5_1)},
-      {GetFullQualifiedOpName("Softmax", kOnnxDomain),
+      {utils::GetFullQualifiedOpName("Softmax", kOnnxDomain),
        OpPassThroughConfig<UpStreamGatherOperatorActorBase>(std::make_shared<SoftmaxGatherActor>(),
                                                             opset_13_11_1)},
-      {GetFullQualifiedOpName("Transpose", kOnnxDomain),
+      {utils::GetFullQualifiedOpName("Transpose", kOnnxDomain),
        OpPassThroughConfig<UpStreamGatherOperatorActorBase>(std::make_shared<TransposeGatherActor>(),
                                                             opset_13_1)},
   });
@@ -69,7 +70,7 @@ bool UpStreamGatherGraphTransformer::UpStreamInternal(
     const OpPassThroughConfig<UpStreamGatherOperatorActorBase>& pass_through_config,
     const logging::Logger& logger) const {
   Node& slice_node = *info.node_ptr;
-  const std::string op_type = GetFullQualifiedOpName(current_node.OpType(), current_node.Domain());
+  const std::string op_type = utils::GetFullQualifiedOpName(current_node.OpType(), current_node.Domain());
 
   std::unordered_map<int, int> propagate_input_indices;
   std::unordered_map<int, std::vector<DimCompare>> all_input_cmp_rets;
diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc
index f7b48de2caaf5..716988e93312c 100644
--- a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc
+++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc
@@ -4,6 +4,7 @@
 #ifdef ENABLE_TRAINING
 
 #include "core/framework/tensorprotoutils.h"
+#include "core/common/string_utils.h"
 #include "core/graph/graph_utils.h"
 #include "core/optimizer/utils.h"
 #include "core/optimizer/compute_optimizer/upstream_reshape_actors.h"
@@ -21,23 +22,23 @@ UpStreamReshapeGraphTransformer::UpStreamReshapeGraphTransformer(
       //    If optype is not enough to guarantee the equivalence, we need to add a customized pre-check function.
       // 2. Should all inputs be allowed when tracking back further (bottom-up);
       //    if not, add the input index restriction.
-      {GetFullQualifiedOpName("Add", kOnnxDomain),
+      {utils::GetFullQualifiedOpName("Add", kOnnxDomain),
        OpPassThroughConfig<UpStreamReshapeOperatorActorBase>(
            std::make_shared<SimplePointwiseReshapeActor<true>>(), opset_14_13_7_6_1)},
-      {GetFullQualifiedOpName("BiasGelu", kMSDomain),
+      {utils::GetFullQualifiedOpName("BiasGelu", kMSDomain),
        OpPassThroughConfig<UpStreamReshapeOperatorActorBase>(
            std::make_shared<SimplePointwiseReshapeActor<true>>(), opset_1)},
-      {GetFullQualifiedOpName("Cast", kOnnxDomain),
+      {utils::GetFullQualifiedOpName("Cast", kOnnxDomain),
        OpPassThroughConfig<UpStreamReshapeOperatorActorBase>(
            std::make_shared<SimplePointwiseReshapeActor<true>>(), opset_19_13_9_6_1)},
-      {GetFullQualifiedOpName("Dropout", kOnnxDomain),
+      {utils::GetFullQualifiedOpName("Dropout", kOnnxDomain),
        OpPassThroughConfig<UpStreamReshapeOperatorActorBase>(
            std::make_shared<SimplePointwiseReshapeActor<true>>(), opset_13_12_10_7_6_1)},
       {// Be noted, this is our own implementation of ONNX domain op.
-       GetFullQualifiedOpName("LayerNormalization", kOnnxDomain),
+       utils::GetFullQualifiedOpName("LayerNormalization", kOnnxDomain),
        OpPassThroughConfig<UpStreamReshapeOperatorActorBase>(
            std::make_shared<LayerNormalizationReshapeActor>(), opset_1)},
-      {GetFullQualifiedOpName("MatMul", kOnnxDomain),
+      {utils::GetFullQualifiedOpName("MatMul", kOnnxDomain),
        OpPassThroughConfig<UpStreamReshapeOperatorActorBase>(
            std::make_shared<MatMulReshapeActor>(), opset_13_9_1)},
   });
@@ -47,7 +48,7 @@ bool UpStreamReshapeGraphTransformer::UpStreamInternal(
     Graph& graph, std::deque<ReshapeInfo>& queue, Node& current_node, ReshapeInfo& info,
     const OpPassThroughConfig<UpStreamReshapeOperatorActorBase>& pass_through_config,
     const logging::Logger& logger) const {
-  const std::string op_type = GetFullQualifiedOpName(current_node.OpType(), current_node.Domain());
+  const std::string op_type = utils::GetFullQualifiedOpName(current_node.OpType(), current_node.Domain());
 
   std::vector<int> propagate_input_indices;
   std::unordered_map<int, std::vector<DimCompare>> all_input_cmp_rets;
diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.cc
index f08e37296d259..4582f26a7dc68 100644
--- a/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.cc
+++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.cc
@@ -5,6 +5,7 @@
 
 #include <onnx/defs/attr_proto_util.h>
 #include "core/common/safeint.h"
+#include "core/common/string_utils.h"
 #include "core/graph/graph_utils.h"
 #include "core/optimizer/initializer.h"
 #include "core/optimizer/utils.h"
@@ -130,7 +131,7 @@ template <typename T1, typename T2>
 bool UpStreamGraphTransformerBase<T1, T2>::Upstream(Graph& graph, std::deque<T1>& queue,
                                                     Node& current_node, T1& info,
                                                     const logging::Logger& logger) const {
-  const std::string op_type = GetFullQualifiedOpName(current_node.OpType(), current_node.Domain());
+  const std::string op_type = utils::GetFullQualifiedOpName(current_node.OpType(), current_node.Domain());
   if (allowed_passthrough_ops_.count(op_type)) {
     auto& pass_through_config = allowed_passthrough_ops_.at(op_type);
     LOG_DEBUG_INFO(logger, "Enter reorder handle for node " + current_node.Name() + "(" + op_type + ")");
diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.h b/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.h
index 6e22fc791ade3..d848a03c555bb 100644
--- a/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.h
+++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.h
@@ -72,13 +72,6 @@ class UpStreamGraphTransformerBase : public GraphTransformer {
                                 const OpPassThroughConfig<T2>& pass_through_config,
                                 const logging::Logger& logger) const = 0;
 
-  /**
-   * @brief A consistent way to construct the full qualified op name.
-   */
-  std::string GetFullQualifiedOpName(const std::string& op_type, const std::string& domain) const {
-    return domain + "::" + op_type;
-  }
-
   std::unordered_map<std::string, OpPassThroughConfig<T2>> allowed_passthrough_ops_;
 
  private:
diff --git a/onnxruntime/core/optimizer/gather_fusion.cc b/onnxruntime/core/optimizer/gather_fusion.cc
index 4903bc1d6b961..90cabff88122c 100644
--- a/onnxruntime/core/optimizer/gather_fusion.cc
+++ b/onnxruntime/core/optimizer/gather_fusion.cc
@@ -9,55 +9,144 @@
 
 namespace onnxruntime {
 
-bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis,
-                                            int64_t& indices_n_dims) const {
-  if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) ||
+namespace {
+static int64_t GetGatherAxis(const Node& node, int64_t rank) {
+  int64_t axis = 0;
+  auto& attrs = node.GetAttributes();
+  if (attrs.find("axis") != attrs.end()) {
+    auto& axis_attr = attrs.at("axis");
+    if (utils::HasInt(axis_attr)) {
+      axis = axis_attr.i();
+      if (axis < 0) axis += rank;
+    }
+  }
+  return axis;
+}
+
+static bool GetScalarInt64Initializer(const Graph& graph, const NodeArg& node_arg, int64_t& value, int64_t& rank) {
+  if (!optimizer_utils::IsScalar(node_arg)) return false;
+  const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node_arg.Name());
+  if (!tensor_proto || tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto::INT64) return false;
+  Initializer init_const{*tensor_proto, graph.ModelPath()};
+  value = *(init_const.data<int64_t>());
+  rank = tensor_proto->dims_size();
+  return true;
+}
+
+static bool GetSliceAxis(const Graph& graph, const Node& node, int64_t rank, int64_t& axis) {
+  if (node.InputDefs().size() < 4) return false;
+  int64_t unused = 0;
+  if (!GetScalarInt64Initializer(graph, *node.InputDefs()[3], axis, unused)) return false;
+  if (axis < 0) axis += rank;
+  return true;
+}
+
+static bool GetAxis(const Graph& graph, const Node& node, int64_t rank, int64_t& axis) {
+  if (node.OpType() == "Gather") {
+    axis = GetGatherAxis(node, rank);
+    return true;
+  }
+  if (node.OpType() == "Slice") {
+    return GetSliceAxis(graph, node, rank, axis);
+  }
+  return false;
+}
+
+}  // namespace
+
+bool GatherSliceToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t rank,
+                                                 int64_t target_axis, int64_t dim_size, InlinedVector<bool>& consumed,
+                                                 int64_t& start, bool& need_squeeze) const {
+  if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {13}) ||
       !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
     return false;
   }
 
-  const NodeArg& input_arg = *(node.InputDefs()[1]);
-  if (!optimizer_utils::IsScalar(input_arg)) return false;
-  const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name());
-  if (!tensor_proto) return false;
-  if (tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT64) return false;
-  Initializer init_const{*tensor_proto, graph.ModelPath()};
-  index = *(init_const.data<int64_t>());
-  axis = 0;  // Default value.
-  auto& attrs = node.GetAttributes();
-  if (attrs.find("axis") != attrs.end()) {
-    auto& axis_attr = attrs.at("axis");
-    if (utils::HasInt(axis_attr)) axis = axis_attr.i();
+  if (GetGatherAxis(node, rank) != target_axis) return false;
+  // Require the indices input to be a scalar tensor for now; if it is not, the exporter will normally emit a Slice instead.
+  // We can relax this later if needed.
+  int64_t indices_n_dims = 0;
+  if (!GetScalarInt64Initializer(graph, *(node.InputDefs()[1]), start, indices_n_dims)) return false;
+  if (start < 0) start += dim_size;
+  if (start < 0 || start >= dim_size || consumed[static_cast<size_t>(start)]) return false;
+  consumed[static_cast<size_t>(start)] = true;
+  need_squeeze = indices_n_dims == 0;
+  return true;
+}
+
+bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node& node, int64_t rank, int64_t target_axis,
+                                                int64_t dim_size, InlinedVector<bool>& consumed, int64_t& start,
+                                                int64_t& end) const {
+  if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {13}) ||
+      !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
+    return false;
+  }
+
+  int64_t axis = 0;
+  if (!GetSliceAxis(graph, node, rank, axis) || axis != target_axis) return false;
+  int64_t unused = 0;
+  if (!GetScalarInt64Initializer(graph, *node.InputDefs()[1], start, unused) ||
+      !GetScalarInt64Initializer(graph, *node.InputDefs()[2], end, unused)) {
+    return false;
+  }
+  // Handling start and end according to schema definition.
+  if (start < 0) start += dim_size;
+  if (end < 0) end += dim_size;
+  if (start < 0)
+    start = 0;
+  else if (start > dim_size)
+    start = dim_size;
+  if (end < 0)
+    end = 0;
+  else if (end > dim_size)
+    end = dim_size;
+  if (start >= end) return false;
+  if (node.InputDefs().size() >= 5) {
+    int64_t step = 0;
+    if (!GetScalarInt64Initializer(graph, *node.InputDefs()[4], step, unused) || step != 1) return false;
+  }
+  for (int64_t i = start; i < end; ++i) {
+    if (consumed[static_cast<size_t>(i)]) return false;
+    consumed[static_cast<size_t>(i)] = true;
   }
-  indices_n_dims = tensor_proto->dims_size();
   return true;
 }
 
 /*
-GatherToSplitFusion is to fuse:
-Node -> Gather(index=0, axis=axis)
-    |-> Gather(index=1, axis=axis)
-    |-> Gather(index=2, axis=axis)
+GatherSliceToSplitFusion is to fuse:
+Node -> Gather(indices=0, axis=axis)
+    |-> Gather(indices=[1], axis=axis)
+    |-> Slice(start=2, end=3, axes=[axis])
     |...
 
 To
 
 Node -> Split -> Squeeze(axis=axis)
-             |-> Squeeze(axis=axis)
-             |-> Squeeze(axis=axis)
+             |->
+             |->
              |...
 
 So that we can use one kernel to finish the job.
+The fusion requires that the indices of the Gather nodes and the start/end ranges of the Slice nodes do not overlap
+and together cover all elements along the target axis. The step of each Slice node must be 1.
 */
-Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
-                                      const logging::Logger& logger) const {
+Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
+                                           const logging::Logger& logger) const {
+  // Squeeze, Gather, Slice and Split have different schemas before and after OpSet 13.
+  // To keep the code simple, only OpSet >= 13 is supported.
+  int onnx_opset_version = -1;
+  if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) {
+    onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain);
+  }
+  if (onnx_opset_version < 13) return Status::OK();
+
   GraphViewer graph_viewer(graph);
   const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
 
-  InlinedVector<const NodeArg*> node_args;
+  InlinedVector<const NodeArg*> candidate_args;
   for (auto node_arg : graph.GetInputs()) {
     if (node_arg && graph.GetConsumerNodes(node_arg->Name()).size() > 1) {
-      node_args.push_back(node_arg);
+      candidate_args.push_back(node_arg);
     }
   }
 
@@ -65,7 +154,7 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
     if (graph.GetConsumerNodes(entry.first).size() > 1) {
       auto node_arg = graph.GetNodeArg(entry.first);
       if (node_arg) {
-        node_args.push_back(node_arg);
+        candidate_args.push_back(node_arg);
       }
     }
   }
@@ -90,129 +179,108 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
     size_t output_count = node.GetOutputEdgesCount();
     if (output_count <= 1) continue;
 
-    node_args.push_back(node.OutputDefs()[0]);
+    candidate_args.push_back(node.OutputDefs()[0]);
   }
 
-  for (const NodeArg* node_arg : node_args) {
+  for (const NodeArg* node_arg : candidate_args) {
     auto shape = node_arg->Shape();
     if (!shape) continue;
     int64_t rank = static_cast<int64_t>(shape->dim_size());
-
-    bool can_fuse = true;
-    bool first_edge = true;
-    int64_t split_axis = 0;
-    int64_t indices_n_dims = -1;
     auto consumers = graph.GetConsumerNodes(node_arg->Name());
-    size_t consumer_count = consumers.size();
-    InlinedVector<NodeArg*> gather_outputs(consumer_count, nullptr);
-    InlinedVector<std::reference_wrapper<Node>> nodes_to_fuse;
+    InlinedVector<const Node*> candidate_consumers;
     for (auto consumer : consumers) {
-      int64_t index, axis, dims;
-      if (!consumer || consumer->InputDefs()[0] != node_arg ||
-          !IsSupportedGather(graph, *consumer, index, axis, dims)) {
-        can_fuse = false;
-        break;
-      }
-      if (indices_n_dims == -1) {
-        indices_n_dims = dims;
-      } else if (indices_n_dims != dims) {
-        // Not the same number of dimensions (0 or 1) for all scalar indices.
-        can_fuse = false;
-        break;
+      if (consumer && consumer->InputDefs()[0] == node_arg &&
+          (consumer->OpType() == "Gather" || consumer->OpType() == "Slice")) {
+        candidate_consumers.emplace_back(consumer);
       }
-      if (axis < 0) axis += rank;
-      if (first_edge) {
-        auto dim = shape->dim(static_cast<int>(axis));
-        if (!utils::HasDimValue(dim) || dim.dim_value() != static_cast<int64_t>(consumer_count)) {
-          can_fuse = false;
-          break;
-        }
-        split_axis = axis;
-        first_edge = false;
-      } else if (axis != split_axis) {
+    }
+    if (candidate_consumers.size() < 2) continue;
+    int64_t axis = 0;
+    if (!GetAxis(graph, *candidate_consumers[0], rank, axis)) continue;
+    auto dim = shape->dim(static_cast<int>(axis));
+    if (!utils::HasDimValue(dim)) continue;
+    int64_t dim_size = dim.dim_value();
+    InlinedVector<bool> consumed(static_cast<size_t>(dim_size), false);
+    bool can_fuse = true;
+    InlinedVector<std::reference_wrapper<Node>> nodes_to_fuse;
+    InlinedVector<int64_t> starts;
+    InlinedHashMap<int64_t, std::tuple<NodeArg*, int64_t, bool>> output_info_map;
+    for (auto consumer : candidate_consumers) {
+      if (!consumer || consumer->InputDefs()[0] != node_arg) {
         can_fuse = false;
         break;
       }
-      if (index < 0) index += static_cast<int64_t>(consumer_count);
-      if (index < 0 || index >= static_cast<int64_t>(consumer_count) || gather_outputs[static_cast<size_t>(index)]) {
+      int64_t start = 0, end = 0;
+      bool need_squeeze = false;
+      if (IsSupportedGather(graph, *consumer, rank, axis, dim_size, consumed, start, need_squeeze)) {
+        Node& gather_node = *graph.GetNode(consumer->Index());
+        nodes_to_fuse.emplace_back(gather_node);
+        starts.emplace_back(start);
+        output_info_map[start] = std::make_tuple(gather_node.MutableOutputDefs()[0], 1, need_squeeze);
+      } else if (IsSupportedSlice(graph, *consumer, rank, axis, dim_size, consumed, start, end)) {
+        Node& slice_node = *graph.GetNode(consumer->Index());
+        nodes_to_fuse.emplace_back(slice_node);
+        starts.emplace_back(start);
+        output_info_map[start] = std::make_tuple(slice_node.MutableOutputDefs()[0], end - start, false);
+      } else {
         can_fuse = false;
         break;
       }
-      Node& gather_node = *graph.GetNode(consumer->Index());
-      nodes_to_fuse.emplace_back(gather_node);
-      gather_outputs[static_cast<size_t>(index)] = gather_node.MutableOutputDefs()[0];
-    }
-
-    if (!can_fuse) continue;
-
-    ONNX_NAMESPACE::TypeProto split_output_type;
-    const ONNX_NAMESPACE::TensorProto_DataType element_type =
-        static_cast<ONNX_NAMESPACE::TensorProto_DataType>(node_arg->TypeAsProto()->tensor_type().elem_type());
-    split_output_type.mutable_tensor_type()->set_elem_type(element_type);
-    for (int64_t i = 0; i < rank; ++i) {
-      if (i == split_axis) {
-        split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1LL);
-      } else {
-        *(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast<int>(i));
-      }
     }
 
+    if (!can_fuse || std::find(consumed.begin(), consumed.end(), false) != consumed.end()) continue;
+    std::sort(starts.begin(), starts.end());
     InlinedVector<NodeArg*> split_outputs;
-    bool add_squeeze_node = indices_n_dims == 0;
-    if (add_squeeze_node) {
-      for (size_t i = 0; i < consumer_count; ++i) {
-        split_outputs.emplace_back(
-            &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split" + std::to_string(i)), &split_output_type));
-      }
-    }
-
-    Node& split_node =
-        graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes",
-                      {graph.GetNodeArg(node_arg->Name())}, add_squeeze_node ? split_outputs : gather_outputs);
-    split_node.AddAttribute("axis", split_axis);
-    split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
-
-    // Squeeze-11, Squeee-13, Split-13, Split-18 have different schemas.
-    int onnx_opset_version = -1;
-    if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) {
-      onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain);
-    }
-
-    if (onnx_opset_version < 13) {
-      if (add_squeeze_node) {
-        for (size_t i = 0; i < consumer_count; ++i) {
-          Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze",
-                                             "Squeeze for Fused Gather nodes", {split_outputs[i]}, {gather_outputs[i]});
-          squeeze_node.AddAttribute("axes", std::vector<int64_t>{split_axis});
-          squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
+    InlinedVector<int64_t> split_values;
+    for (int64_t start : starts) {
+      auto& output_info = output_info_map[start];
+      NodeArg* original_output_arg = std::get<0>(output_info);
+      int64_t split_value = std::get<1>(output_info);
+      split_values.emplace_back(split_value);
+      if (std::get<2>(output_info)) {
+        ONNX_NAMESPACE::TypeProto split_output_type;
+        const ONNX_NAMESPACE::TensorProto_DataType element_type =
+            static_cast<ONNX_NAMESPACE::TensorProto_DataType>(node_arg->TypeAsProto()->tensor_type().elem_type());
+        split_output_type.mutable_tensor_type()->set_elem_type(element_type);
+        for (int64_t i = 0; i < rank; ++i) {
+          if (i == axis) {
+            split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(split_value);
+          } else {
+            *(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast<int>(i));
+          }
         }
-      }
-    } else {
-      if (onnx_opset_version >= 18) {
-        split_node.AddAttribute("num_outputs", static_cast<int64_t>(consumer_count));
-      }
-
-      if (add_squeeze_node) {
+        NodeArg* split_output_arg =
+            &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split_output"), &split_output_type);
         ONNX_NAMESPACE::TensorProto axes_initializer_proto;
-        axes_initializer_proto.set_name(graph.GenerateNodeName("SqueezeAxesInitializer"));
+        axes_initializer_proto.set_name(graph.GenerateNodeName("squeeze_axes"));
         axes_initializer_proto.add_dims(static_cast<int64_t>(1));
         axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
-        InlinedVector<int64_t> axes_value{split_axis};
-        axes_initializer_proto.set_raw_data(axes_value.data(), axes_value.size() * sizeof(int64_t));
+        axes_initializer_proto.add_int64_data(axis);
         NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto);
-
-        for (size_t i = 0; i < consumer_count; ++i) {
-          Node& squeeze_node =
-              graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze",
-                            "Squeeze for Fused Gather nodes", {split_outputs[i], axes_arg}, {gather_outputs[i]});
-          squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
-        }
+        Node& squeeze_node =
+            graph.AddNode(graph.GenerateNodeName("Squeeze"), "Squeeze", "Squeeze for Fused Gather nodes",
+                          {split_output_arg, axes_arg}, {original_output_arg});
+        squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
+        split_outputs.emplace_back(split_output_arg);
+      } else {
+        split_outputs.emplace_back(original_output_arg);
       }
     }
 
-    for (Node& n : nodes_to_fuse) {
-      graph_utils::RemoveNodeOutputEdges(graph, n);
-      graph.RemoveNode(n.Index());
+    ONNX_NAMESPACE::TensorProto split_initializer_proto;
+    split_initializer_proto.set_name(graph.GenerateNodeName("splits"));
+    split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
+    split_initializer_proto.add_dims(static_cast<int64_t>(split_values.size()));
+    split_initializer_proto.mutable_int64_data()->Add(split_values.begin(), split_values.end());
+    NodeArg* split_initializer_arg = &graph_utils::AddInitializer(graph, split_initializer_proto);
+    Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes",
+                                     {graph.GetNodeArg(node_arg->Name()), split_initializer_arg}, split_outputs);
+    split_node.AddAttribute("axis", axis);
+    split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
+
+    for (Node& node : nodes_to_fuse) {
+      graph_utils::RemoveNodeOutputEdges(graph, node);
+      graph.RemoveNode(node.Index());
     }
 
     modified = true;
diff --git a/onnxruntime/core/optimizer/gather_fusion.h b/onnxruntime/core/optimizer/gather_fusion.h
index 44c235915b6cc..098278a77dafe 100644
--- a/onnxruntime/core/optimizer/gather_fusion.h
+++ b/onnxruntime/core/optimizer/gather_fusion.h
@@ -8,19 +8,23 @@
 namespace onnxruntime {
 
 /**
-@Class GatherToSplitFusion
+@Class GatherSliceToSplitFusion
 
-Fuse multiple Gather nodes that comsuming one output to one Split node.
+Fuse multiple Gather/Slice nodes that consume the same output into one Split node.
 */
-class GatherToSplitFusion : public GraphTransformer {
+class GatherSliceToSplitFusion : public GraphTransformer {
  public:
-  GatherToSplitFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
-      : GraphTransformer("GatherToSplitFusion", compatible_execution_providers) {}
+  GatherSliceToSplitFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
+      : GraphTransformer("GatherSliceToSplitFusion", compatible_execution_providers) {}
 
   Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
 
  private:
-  bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const;
+  bool IsSupportedGather(const Graph& graph, const Node& node, int64_t rank, int64_t target_axis, int64_t dim_size,
+                         InlinedVector<bool>& consumed, int64_t& start, bool& need_squeeze) const;
+
+  bool IsSupportedSlice(const Graph& graph, const Node& node, int64_t rank, int64_t target_axis, int64_t dim_size,
+                        InlinedVector<bool>& consumed, int64_t& start, int64_t& end) const;
 };
 
 /**
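
Conceptually, the fusion above marks which positions along the split axis each Gather/Slice consumer covers and, once the axis is covered exactly once, emits the Split sizes ordered by start offset. A standalone sketch of that bookkeeping follows; Consumer and ComputeSplitSizes are illustrative names, and graph-node details are deliberately omitted.

#include <algorithm>
#include <cstdint>
#include <vector>

struct Consumer {
  int64_t start;   // first element taken along the split axis
  int64_t length;  // 1 for Gather, end - start for Slice
};

// Returns the Split sizes in start order if the consumers cover [0, dim_size)
// exactly once; returns an empty vector on any gap, overlap, or out-of-range index.
std::vector<int64_t> ComputeSplitSizes(std::vector<Consumer> consumers, int64_t dim_size) {
  std::vector<bool> consumed(static_cast<size_t>(dim_size), false);
  for (const Consumer& c : consumers) {
    for (int64_t i = c.start; i < c.start + c.length; ++i) {
      if (i < 0 || i >= dim_size || consumed[static_cast<size_t>(i)]) return {};  // out of range or overlap
      consumed[static_cast<size_t>(i)] = true;
    }
  }
  if (std::find(consumed.begin(), consumed.end(), false) != consumed.end()) return {};  // gap: axis not fully covered
  std::sort(consumers.begin(), consumers.end(),
            [](const Consumer& a, const Consumer& b) { return a.start < b.start; });
  std::vector<int64_t> splits;
  for (const Consumer& c : consumers) splits.push_back(c.length);
  return splits;
}

// Example: Gather(indices=0), Slice(1:3), Gather(indices=3) on an axis of size 4
// -> Split sizes {1, 2, 1}.
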
diff --git a/onnxruntime/core/optimizer/gemm_activation_fusion.cc b/onnxruntime/core/optimizer/gemm_activation_fusion.cc
index c62887da09fdc..50be2cbd48f7b 100644
--- a/onnxruntime/core/optimizer/gemm_activation_fusion.cc
+++ b/onnxruntime/core/optimizer/gemm_activation_fusion.cc
@@ -56,6 +56,13 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
       continue;
     }
 
+    NodeArg* node_output = node.MutableOutputDefs()[0];
+    auto data_type = node_output->TypeAsProto()->tensor_type().elem_type();
+    if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
+      // FusedGemm is only registered for float data type in fused_gemm.cc!
+      continue;
+    }
+
     const Node& next_node = *(node.OutputNodesBegin());
     if (!IsFusableActivation(next_node) || next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
       continue;
diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index cd3c49be15aa4..63612c47f9c56 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -69,6 +69,7 @@
 #include "core/optimizer/reshape_fusion.h"
 #include "core/optimizer/rocm_blas_alt_impl.h"
 #include "core/optimizer/rule_based_graph_transformer.h"
+#include "core/optimizer/shape_input_merge.h"
 #include "core/optimizer/skip_layer_norm_fusion.h"
 #include "core/optimizer/slice_elimination.h"
 #include "core/optimizer/transpose_optimizer.h"
@@ -211,9 +212,9 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
         transformers.emplace_back(std::make_unique<DoubleQDQPairsRemover>());
       }
 
-      // Put ConstantSharing before CommonSubexpressionElimination by intention as it can create more opportunities for
-      // CSE. For example, if A and B nodes both do Add operation with a same value but different initializers, by
-      // default, CSE will not merge them, because the different initializers are represented by different NodeArg.
+      // Put ConstantSharing and ShapeInputMerge before CommonSubexpressionElimination intentionally, as they can create
+      // more opportunities for CSE. For example, if nodes A and B consume different args that produce the same output,
+      // or consume different initializers with the same value, CSE will not merge them by default.
       InlinedHashSet<std::string> excluded_initializers;
       excluded_initializers.reserve(session_options.initializers_to_share_map.size());
       for (const auto& p : session_options.initializers_to_share_map) {
@@ -221,7 +222,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
       }
       const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
       transformers.emplace_back(std::make_unique<ConstantSharing>(no_limit_empty_ep_list, excluded_initializers));
-
+      transformers.emplace_back(std::make_unique<ShapeInputMerge>());
       transformers.emplace_back(std::make_unique<CommonSubexpressionElimination>());
       transformers.emplace_back(std::make_unique<ConstantFolding>(cpu_execution_provider, !disable_quant_qdq,
                                                                   session_options.config_options));
@@ -278,7 +279,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
                                                                                onnxruntime::kAclExecutionProvider,
                                                                                onnxruntime::kArmNNExecutionProvider,
                                                                                onnxruntime::kJsExecutionProvider};
-
+      const InlinedHashSet<std::string_view> cpu_dml_eps = {onnxruntime::kCpuExecutionProvider,
+                                                            onnxruntime::kDmlExecutionProvider};
 #ifdef MLAS_TARGET_AMD64_IX86
       const bool avx2_precision_mode =
           session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsAvx2PrecisionMode, "0") == "1" && MlasPlatformU8S8Overflow();
@@ -296,7 +298,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
       }
 
       transformers.emplace_back(std::make_unique<GemmActivationFusion>(cpu_ep));
-      transformers.emplace_back(std::make_unique<MatMulIntegerToFloatFusion>(cpu_ep));
+      transformers.emplace_back(std::make_unique<MatMulIntegerToFloatFusion>(cpu_dml_eps));
       transformers.emplace_back(std::make_unique<DynamicQuantizeMatMulFusion>(cpu_ep));
 
       transformers.emplace_back(std::make_unique<ConvActivationFusion>(cpu_cuda_rocm_acl_armnn_js_eps));
@@ -306,7 +308,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
       transformers.emplace_back(std::make_unique<SimplifiedLayerNormFusion>(cpu_cuda_rocm_eps));
       transformers.emplace_back(std::make_unique<AttentionFusion>(cpu_cuda_dml_rocm_eps));
       transformers.emplace_back(std::make_unique<EmbedLayerNormFusion>(cpu_cuda_dml_rocm_eps));
-      transformers.emplace_back(std::make_unique<GatherToSplitFusion>(cpu_cuda_rocm_eps));
+      transformers.emplace_back(std::make_unique<GatherSliceToSplitFusion>(cpu_cuda_rocm_eps));
       transformers.emplace_back(std::make_unique<GatherToSliceFusion>(cpu_cuda_rocm_eps));
 
       transformers.emplace_back(std::make_unique<MatmulTransposeFusion>(cpu_cuda_dml_rocm_eps));
diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc
index 159e3b23d1ab0..ce696154adb6d 100644
--- a/onnxruntime/core/optimizer/layer_norm_fusion.cc
+++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc
@@ -13,7 +13,7 @@ using namespace onnxruntime::common;
 namespace onnxruntime {
 
 // LayerNorm supports limited data types.
-static constexpr std::array<std::string_view, 3> supported_data_types{"tensor(float16)", "tensor(float)", "tensor(double)"};
+static constexpr std::array<std::string_view, 4> supported_data_types{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"};
 // Default epsilon
 static constexpr float DEFAULT_LAYERNORM_EPSILON = 1e-5f;
 
@@ -447,6 +447,13 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
 
     NodeArg* x_input = has_leading_cast ? graph.GetNode(p_reduce_mean_input_node->Index())->MutableInputDefs()[0]
                                         : reduce_mean_node.MutableInputDefs()[0];
+
+    // CPU doesn't support fp16
+    if (reduce_mean_node.GetExecutionProviderType() == kCpuExecutionProvider &&
+        x_input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
+      continue;
+    }
+
     InlinedVector<NodeArg*> layer_norm_input_defs{x_input, scale, bias};
     Node& layer_norm_node = graph.AddNode(graph.GenerateNodeName("LayerNormalization"),
                                           "LayerNormalization",
@@ -689,6 +696,13 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
 
     NodeArg* x_input = has_leading_cast ? graph.GetNode(p_pow_input_node->Index())->MutableInputDefs()[0]
                                         : pow_node.MutableInputDefs()[0];
+
+    // CPU doesn't support fp16
+    if (reduce_mean_node.GetExecutionProviderType() == kCpuExecutionProvider &&
+        x_input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
+      continue;
+    }
+
     InlinedVector<NodeArg*> layer_norm_input_defs{x_input, scale};
     Node& layer_norm_node =
         graph.AddNode(graph.GenerateNodeName("SimplifiedLayerNormalization"), "SimplifiedLayerNormalization",
diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc
index 4505d4afdf1e0..7953cde6686c0 100644
--- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc
+++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc
@@ -31,6 +31,7 @@ CostCheckResult PostLayoutTransformCostCheck(const api::GraphRef& graph, const a
 }
 
 #if defined(USE_CUDA) && ENABLE_CUDA_NHWC_OPS
+// TODO(mtavenrath) generate list from registered kernels using nhwc domain
 const std::unordered_set<std::string_view>& GetCUDALayoutSensitiveOps() {
   static std::unordered_set<std::string_view> cuda_nhwc_ops = []() {
     return std::unordered_set<std::string_view>{
@@ -41,7 +42,10 @@ const std::unordered_set<std::string_view>& GetCUDALayoutSensitiveOps() {
         "MaxPool",
         "GlobalAveragePool",
         "AveragePool",
-    };
+        "GridSample",
+        "DepthToSpace",
+        "SpaceToDepth",
+        "LRN"};
   }();
   return cuda_nhwc_ops;
 }
diff --git a/onnxruntime/core/optimizer/matmul_integer_to_float.cc b/onnxruntime/core/optimizer/matmul_integer_to_float.cc
index 56e51cb787931..4fee1a6ce224e 100644
--- a/onnxruntime/core/optimizer/matmul_integer_to_float.cc
+++ b/onnxruntime/core/optimizer/matmul_integer_to_float.cc
@@ -31,6 +31,24 @@ static bool CheckBiasShape(const TensorShapeProto* bias_shape) {
   return bias_last_dim > 1;
 }
 
+bool HasElementDataType(const NodeArg& node_arg, int32_t data_type) {
+  if (!node_arg.Exists()) {
+    return false;
+  }
+
+  const auto* type_proto = node_arg.TypeAsProto();
+  if (!type_proto) {
+    return false;
+  }
+
+  int32_t actual_data_type;
+  if (!utils::TryGetElementDataType(*type_proto, actual_data_type)) {
+    return false;
+  }
+
+  return data_type == actual_data_type;
+}
+
 /**
 MatMulIntegerToFloatFusion will fuse subgraph like below into MatMulIntegerToFloat:
 
@@ -63,9 +81,10 @@ Status MatMulIntegerToFloatFusion::ApplyImpl(Graph& graph, bool& modified, int g
     auto& mul_node = *node_ptr;
 
     ORT_RETURN_IF_ERROR(Recurse(mul_node, modified, graph_level, logger));
-
+    const bool is_dml_ep = node_ptr->GetExecutionProviderType() == kDmlExecutionProvider;
     if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13, 14}) ||
-        !graph_utils::IsSupportedProvider(mul_node, GetCompatibleExecutionProviders())) {
+        !graph_utils::IsSupportedProvider(mul_node, GetCompatibleExecutionProviders()) ||
+        (!is_dml_ep && HasElementDataType(*mul_node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT16))) {
       continue;
     }
 
diff --git a/onnxruntime/core/optimizer/noop_elimination.cc b/onnxruntime/core/optimizer/noop_elimination.cc
index b3c2991d54b28..bba39b698a27a 100644
--- a/onnxruntime/core/optimizer/noop_elimination.cc
+++ b/onnxruntime/core/optimizer/noop_elimination.cc
@@ -42,49 +42,62 @@ bool NoopElimination::SatisfyCondition(const Graph& graph, const Node& node, con
 
   // if initializer_rank is bigger, the output is expected to be initializer_rank per broadcasting rule,
   // but it won't happen if the case is accepted, thus reject it
-  auto initializer_rank = initializer->dims().size();
+  const auto& dims = initializer->dims();
+  auto initializer_rank = dims.size();
   const auto* other_input_shape = node.InputDefs()[input0_is_initializer ? 1 : 0]->Shape();
   if (other_input_shape == nullptr || initializer_rank > other_input_shape->dim_size()) {
     return false;
   }
 
-  int32_t data_type = initializer->data_type();
-  Initializer add_init(*initializer, graph.ModelPath());
-  if (add_init.size() > 1) {
+  int64_t tensor_size = 1;
+  for (auto i : dims) {
+    tensor_size *= i;
+  }
+
+  if (tensor_size > 1) {
     return false;
   }
+
   // handle edge case where the total size of the initializer is 0
-  if (add_init.size() == 0) {
+  if (tensor_size == 0) {
     return true;
   }
 
-  float value = 0.0f;
-  switch (data_type) {
-    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
-      value = *add_init.data<float>();
-      break;
-    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
-      value = math::halfToFloat(add_init.data<MLFloat16>()->val);
-      break;
-    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
-      value = static_cast<float>(*add_init.data<double>());
-      break;
-    case ONNX_NAMESPACE::TensorProto_DataType_INT32:
-      value = static_cast<float>(*add_init.data<int32_t>());
-      break;
-    case ONNX_NAMESPACE::TensorProto_DataType_INT64:
-      value = static_cast<float>(*add_init.data<int64_t>());
-      break;
-    default:
+  if (op_type == "Add" ||
+      op_type == "Sub" ||
+      op_type == "Mul" ||
+      op_type == "Div") {
+    int32_t data_type = initializer->data_type();
+    Initializer add_init(*initializer, graph.ModelPath());
+
+    float value = 0.0f;
+    switch (data_type) {
+      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
+        value = *add_init.data<float>();
+        break;
+      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
+        value = math::halfToFloat(add_init.data<MLFloat16>()->val);
+        break;
+      case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
+        value = static_cast<float>(*add_init.data<double>());
+        break;
+      case ONNX_NAMESPACE::TensorProto_DataType_INT32:
+        value = static_cast<float>(*add_init.data<int32_t>());
+        break;
+      case ONNX_NAMESPACE::TensorProto_DataType_INT64:
+        value = static_cast<float>(*add_init.data<int64_t>());
+        break;
+      default:
+        return false;
+    }
+
+    if (value != 0.0f && (op_type == "Add" || op_type == "Sub")) {
       return false;
-  }
+    }
 
-  if ((op_type == "Add" || op_type == "Sub") && value != 0.0f) {
-    return false;
-  }
-
-  if ((op_type == "Mul" || op_type == "Div") && value != 1.0f) {
-    return false;
+    if (value != 1.0f && (op_type == "Mul" || op_type == "Div")) {
+      return false;
+    }
   }
 
   // reject node output is graph output for now
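
The rewritten check above reduces to an identity-element test on the single-element initializer: 0 for Add/Sub, 1 for Mul/Div. A condensed sketch of just that decision (IsIdentityElement is an illustrative name; the operand-position and graph-output checks done elsewhere in the pass are omitted):

#include <string>

// Returns true if applying `op_type` with a single-element constant of the given
// value leaves the other input unchanged (i.e. the constant is the identity element).
bool IsIdentityElement(const std::string& op_type, float value) {
  if (op_type == "Add" || op_type == "Sub") return value == 0.0f;
  if (op_type == "Mul" || op_type == "Div") return value == 1.0f;
  return false;
}
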
diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc
index b1ab641a23256..4e3dff705bd41 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc
@@ -76,6 +76,49 @@ bool IsQDQPairSupported(
   }
 }
 
+bool IsDQQConversion(
+    const Node& dq_node, const Node& q_node,
+    const GetConstantInitializerFn& get_const_initializer,
+    const Path& model_path) {
+  ConstPointerContainer<std::vector<NodeArg*>> dq_input_defs = dq_node.InputDefs();
+  ConstPointerContainer<std::vector<NodeArg*>> q_input_defs = q_node.InputDefs();
+
+  // Q/DQ with optional inputs is not supported.
+  // Non-scalar Q/DQ scale and zero point are not supported.
+  if (dq_input_defs.size() != InputIndex::TOTAL_COUNT ||
+      q_input_defs.size() != InputIndex::TOTAL_COUNT ||
+      !optimizer_utils::IsScalar(*q_input_defs[InputIndex::SCALE_ID]) ||
+      !optimizer_utils::IsScalar(*q_input_defs[InputIndex::ZERO_POINT_ID]) ||
+      !optimizer_utils::IsScalar(*dq_input_defs[InputIndex::SCALE_ID]) ||
+      !optimizer_utils::IsScalar(*dq_input_defs[InputIndex::ZERO_POINT_ID])) {
+    return false;
+  }
+
+  // if Q/DQ scale and zero point are not constant, return false
+  const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto =
+      get_const_initializer(dq_input_defs[InputIndex::SCALE_ID]->Name());
+  const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto =
+      get_const_initializer(q_input_defs[InputIndex::SCALE_ID]->Name());
+  const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto =
+      get_const_initializer(dq_input_defs[InputIndex::ZERO_POINT_ID]->Name());
+  const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto =
+      get_const_initializer(q_input_defs[InputIndex::ZERO_POINT_ID]->Name());
+  if (nullptr == q_zp_tensor_proto ||
+      nullptr == dq_zp_tensor_proto ||
+      nullptr == q_scale_tensor_proto ||
+      nullptr == dq_scale_tensor_proto) {
+    return false;
+  }
+
+  // check that Q and DQ have the same scale type and different zero-point types
+  Initializer q_zp(*q_zp_tensor_proto, model_path);
+  Initializer q_scale(*q_scale_tensor_proto, model_path);
+  Initializer dq_zp(*dq_zp_tensor_proto, model_path);
+  Initializer dq_scale(*dq_scale_tensor_proto, model_path);
+
+  return (dq_zp.data_type() != q_zp.data_type()) && (dq_scale.data_type() == q_scale.data_type());
+}
+
 bool IsDQSupported(const Node& dq_node, const GetConstantInitializerFn& get_const_initializer) {
   bool zero_point_exists = false;
   if (!QOrDQNodeHasConstantScalarScaleAndZeroPoint(dq_node, get_const_initializer, zero_point_exists)) {
diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h
index bb0bf9438cfcb..8333168b0093f 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h
+++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h
@@ -38,6 +38,18 @@ bool IsQDQPairSupported(
     const GetConstantInitializerFn& get_const_initializer,
     const Path& model_path);
 
+// Check if a DQ -> Q sequence represents a conversion in quantization data type.
+// Example of uint8 to uint16:
+//     Dequantize (uint8 to float) -> Quantize (float to uint16)
+// Requires:
+// 1. Q/DQ doesn't have optional input.
+// 2. scale and zero-point are constant scalars.
+// 3. Q and DQ have the same scale *type* and different zero-point *types*.
+bool IsDQQConversion(
+    const Node& dq_node, const Node& q_node,
+    const GetConstantInitializerFn& get_const_initializer,
+    const Path& model_path);
+
 // Check if DQ is supported in extended level QDQ transformers. It requires:
 // 1. DQ doesn't have optional input.
 // 2. scale and zero point is constant scalar
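
For intuition, the DQ -> Q pair that IsDQQConversion detects re-expresses a quantized value in a new integer type, e.g. uint8 to uint16. A hedged sketch of that arithmetic, not an ORT API (RequantizeU8ToU16 is an invented name):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Dequantize (uint8 -> float) followed by Quantize (float -> uint16):
// value = (q8 - zp8) * scale8, then q16 = clamp(round(value / scale16) + zp16).
uint16_t RequantizeU8ToU16(uint8_t q8, float scale8, uint8_t zp8,
                           float scale16, uint16_t zp16) {
  const float value = (static_cast<int32_t>(q8) - static_cast<int32_t>(zp8)) * scale8;
  const float q = std::nearbyint(value / scale16) + static_cast<float>(zp16);
  return static_cast<uint16_t>(std::clamp(q, 0.0f, 65535.0f));
}
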
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
index 8535b8c9a944a..6b4f62ae1343d 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
@@ -58,8 +58,8 @@ bool NodeGroupSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Nod
     return false;
   }
 
-  if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes);
-      !dq_validation_status.IsOK()) {
+  if (const auto qdq_validation_status = NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes);
+      !qdq_validation_status.IsOK()) {
     return false;
   }
 
@@ -153,8 +153,8 @@ bool DropDQNodeGroupSelector::Check(const GraphViewer& graph_viewer,
     return false;
   }
 
-  if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes);
-      !dq_validation_status.IsOK()) {
+  if (const auto qdq_validation_status = NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes);
+      !qdq_validation_status.IsOK()) {
     return false;
   }
 
@@ -544,8 +544,8 @@ bool TopKNodeGroupSelector::Check(const GraphViewer& graph_viewer,
     return false;
   }
 
-  if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes);
-      !dq_validation_status.IsOK()) {
+  if (const auto qdq_validation_status = QDQ::NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes);
+      !qdq_validation_status.IsOK()) {
     return false;
   }
 
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
index deee6e7f25f1a..c90a42a36483d 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
@@ -5,6 +5,7 @@
 
 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
 
+#include "core/framework/node_unit.h"
 #include "core/optimizer/selectors_actions/selector_action_transformer.h"
 
 namespace onnxruntime {
@@ -13,13 +14,6 @@ class Node;
 
 namespace QDQ {
 
-// Struct to represent a DQ->Op->Q node group
-struct NodeGroup {
-  std::vector<NodeIndex> dq_nodes;
-  std::vector<NodeIndex> q_nodes;
-  NodeIndex target_node;
-};
-
 class NodeGroupSelector {
  public:
   // This is a QDQ Selectors only function, will return QDQ::NodeGroup instead of NodesToOptimizeIndices
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc
index 544fe82a268c8..1876f7826c968 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc
@@ -13,6 +13,7 @@
 #include <core/providers/common.h>
 
 #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
+#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
 
 namespace onnxruntime {
 namespace QDQ {
@@ -43,6 +44,7 @@ static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() {
           {"Tile", {}}};
 }
 
+// These produce int64 indices output, which can't be quantized, so there's no downstream Q node.
 static const OpVersionsAndSelector::OpVersionsMap GetDropDQOpVersionsMap() {
   return {{"ArgMax", {}},
           {"ArgMin", {}}};
@@ -324,28 +326,48 @@ std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& grap
   return qdq_selections;
 }
 
-Status ValidateNodeGroupDQNodes(const GraphViewer& graph_viewer,
-                                const Node& target_node,
-                                gsl::span<const Node* const> dq_nodes) {
-  // Within a QDQ node group, a target node input is the only consumer of each DQ.
-  // This should have been ensured by the EnsureUniqueDQForNodeUnit graph transformer, but other graph modifications
-  // may have happened since. Verify that this is still true.
-  for (const auto* dq_node : dq_nodes) {
-    const bool dq_produces_graph_output = graph_viewer.NodeProducesGraphOutput(*dq_node);
-    ORT_RETURN_IF(dq_produces_graph_output,
-                  "QDQ node group cannot have DQ node that produces a graph output. DQ node: ", dq_node->Name(),
-                  ", target node: ", target_node.Name());
-
-    const bool dq_has_single_output_edge_to_target =
-        dq_node->GetOutputEdgesCount() == 1 &&
-        dq_node->OutputEdgesBegin()->GetNode().Index() == target_node.Index();
-    ORT_RETURN_IF_NOT(dq_has_single_output_edge_to_target,
-                      "QDQ node group cannot have DQ that doesn't have a single output edge to the target node. "
-                      "DQ node: ",
-                      dq_node->Name(), ", target node: ", target_node.Name());
+std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
+GetAllNodeUnits(const GraphViewer& graph_viewer) {
+  std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
+  std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
+
+  const auto add_node_unit_to_map = [&](const std::vector<NodeIndex>& node_indices, const NodeUnit* node_unit) {
+    for (const auto& node_idx : node_indices) {
+      const auto* node = graph_viewer.GetNode(node_idx);
+      node_unit_map.insert({node, node_unit});
+    }
+  };
+
+  // Get QDQ NodeUnits first
+  QDQ::SelectorManager selector_mgr;
+  const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer);
+
+  for (const auto& qdq_selection : qdq_selections) {
+    auto qdq_unit = std::make_unique<NodeUnit>(graph_viewer, qdq_selection);
+
+    // Fill the node to node_unit map for all nodes in the QDQ Group
+    add_node_unit_to_map(qdq_selection.dq_nodes, qdq_unit.get());
+    add_node_unit_to_map(qdq_selection.q_nodes, qdq_unit.get());
+    add_node_unit_to_map({qdq_selection.target_node}, qdq_unit.get());
+
+    node_unit_holder.push_back(std::move(qdq_unit));
+  }
+
+  // Get the left over SingleNode NodeUnits
+  const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
+  for (const auto node_idx : node_indices) {
+    const auto* node(graph_viewer.GetNode(node_idx));
+
+    // This is already part of a QDQ NodeUnit
+    if (node_unit_map.find(node) != node_unit_map.cend())
+      continue;
+
+    auto node_unit = std::make_unique<NodeUnit>(*node);
+    node_unit_map[node] = node_unit.get();
+    node_unit_holder.push_back(std::move(node_unit));
   }
 
-  return Status::OK();
+  return std::make_pair(std::move(node_unit_holder), std::move(node_unit_map));
 }
 
 }  // namespace QDQ
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h
index 246f26c1760ec..de36202afff29 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h
@@ -7,6 +7,7 @@
 #include "core/common/common.h"
 #include "core/common/gsl.h"
 #include "core/common/inlined_containers.h"
+#include "core/framework/node_unit.h"
 #include "core/graph/basic_types.h"
 
 #if !defined(ORT_MINIMAL_BUILD)
@@ -78,11 +79,16 @@ class SelectorManager {
   ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SelectorManager);
 };
 
-// Checks whether the provided DQ nodes are valid for forming a QDQ node group with the provided target node.
-// Returns successful status if so, failed status with reason otherwise.
-Status ValidateNodeGroupDQNodes(const GraphViewer& graph_viewer,
-                                const Node& target_node,
-                                gsl::span<const Node* const> dq_nodes);
+// Gets all the nodes in the given graph_viewer as NodeUnits (SingleNode or QDQGroup),
+// and returns a map for quickly looking up the NodeUnit that contains a given Node.
+// Note: the NodeUnit pointers in the map are owned by the returned vector of std::unique_ptr<NodeUnit>.
+//
+// TODO: The overall QDQ setup needs refactoring to separate out generic functionality from optimizer specific
+// functionality.
+// We currently have a bit of a mess: generic helpers like this one, which gather all the node units, live in the
+// optimizer library, whereas they should be usable by an EP with no dependency on the optimizers.
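+//
+// Illustrative usage (sketch only, not part of the API contract):
+//   auto [node_unit_holder, node_unit_map] = GetAllNodeUnits(graph_viewer);
+//   for (const auto& node_unit : node_unit_holder) { /* inspect each SingleNode or QDQGroup unit */ }
+//   const NodeUnit* unit = node_unit_map.at(&some_node);  // 'some_node' is any Node from graph_viewer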
+std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
+GetAllNodeUnits(const GraphViewer& graph_viewer);
 
 }  // namespace QDQ
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/shape_input_merge.cc b/onnxruntime/core/optimizer/shape_input_merge.cc
new file mode 100644
index 0000000000000..9f20520e3e3f4
--- /dev/null
+++ b/onnxruntime/core/optimizer/shape_input_merge.cc
@@ -0,0 +1,78 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/optimizer/shape_input_merge.h"
+
+#include "core/graph/graph_utils.h"
+
+namespace onnxruntime {
+
+namespace {
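+// Serializes a NodeArg's shape to a string such as "[2,'batch',4]", quoting symbolic dimensions.
+// Returns an empty string if the shape is missing or any dimension has neither a value nor a param,
+// in which case the caller skips the node.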
+std::string GetShapeString(const NodeArg* input_arg) {
+  auto shape = input_arg->Shape();
+  if (!shape) return "";
+  std::stringstream ss;
+  ss << "[";
+  for (int i = 0; i < shape->dim_size(); ++i) {
+    if (i != 0) ss << ",";
+    auto dim = shape->dim(i);
+    if (dim.has_dim_value()) {
+      ss << std::to_string(dim.dim_value());
+    } else if (dim.has_dim_param()) {
+      ss << "'" << dim.dim_param() << "'";
+    } else {
+      return "";
+    }
+  }
+  ss << "]";
+  return ss.str();
+}
+
+}  // namespace
+
+Status ShapeInputMerge::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
+  GraphViewer graph_viewer(graph);
+  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
+  InlinedHashMap<std::string, InlinedVector<Node*>> input_hash_to_nodes;
+  for (auto node_index : node_topology_list) {
+    auto* p_node = graph.GetNode(node_index);
+    if (!p_node) continue;  // we removed the node as part of an earlier fusion
+    ORT_RETURN_IF_ERROR(Recurse(*p_node, modified, graph_level, logger));
+    if (!graph_utils::IsSupportedOptypeVersionAndDomain(*p_node, "Shape", {1, 13, 15, 19, 21}) ||
+        !graph_utils::IsSupportedProvider(*p_node, GetCompatibleExecutionProviders())) {
+      continue;
+    }
+    std::string shape_str = GetShapeString(p_node->InputDefs()[0]);
+    if (shape_str.empty()) continue;
+    if (input_hash_to_nodes.find(shape_str) == input_hash_to_nodes.end()) {
+      input_hash_to_nodes[shape_str] = InlinedVector<Node*>();
+    }
+    input_hash_to_nodes[shape_str].emplace_back(p_node);
+  }
+
+  // Shape nodes were collected in topological order, so each group's nodes can safely be redirected to the first node's input.
+  for (auto& kv : input_hash_to_nodes) {
+    if (kv.second.size() < 2) continue;
+    NodeArg* first_input_arg = kv.second[0]->MutableInputDefs()[0];
+    bool is_first_input_arg_graph_input = graph.IsInputsIncludingInitializers(first_input_arg);
+    for (size_t i = 1; i < kv.second.size(); ++i) {
+      Node* p_node = kv.second[i];
+      const NodeArg* input_arg = p_node->InputDefs()[0];
+      if (p_node->InputDefs()[0]->Name() == first_input_arg->Name()) continue;
+      if (!graph.IsInputsIncludingInitializers(input_arg)) {
+        const Node::EdgeEnd& input_edge = *p_node->InputEdgesBegin();
+        graph.RemoveEdge(input_edge.GetNode().Index(), p_node->Index(), input_edge.GetSrcArgIndex(), 0);
+      }
+      graph_utils::ReplaceNodeInput(*p_node, 0, *first_input_arg);
+      if (!is_first_input_arg_graph_input) {
+        const Node::EdgeEnd& first_input_edge = *kv.second[0]->InputEdgesBegin();
+        graph.AddEdge(first_input_edge.GetNode().Index(), p_node->Index(), first_input_edge.GetSrcArgIndex(), 0);
+      }
+      modified = true;
+    }
+  }
+
+  return Status::OK();
+}
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/shape_input_merge.h b/onnxruntime/core/optimizer/shape_input_merge.h
new file mode 100644
index 0000000000000..5cb943998487b
--- /dev/null
+++ b/onnxruntime/core/optimizer/shape_input_merge.h
@@ -0,0 +1,23 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/optimizer/graph_transformer.h"
+
+namespace onnxruntime {
+
+/**
+@Class ShapeInputMerge
+Merges the inputs of all Shape nodes that have the same shape value into a single input.
+This change does not affect performance by itself, but it opens up opportunities for CSE to merge the Shape nodes.
+*/
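+//
+// Illustrative example (hypothetical graph): if Shape_1 consumes tensor A with shape ['N', 128] and
+// Shape_2 consumes tensor B with the same shape ['N', 128], both Shape nodes are rewired to consume A,
+// after which CSE can fold them into a single Shape node.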
+class ShapeInputMerge : public GraphTransformer {
+ public:
+  ShapeInputMerge(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
+      : GraphTransformer("ShapeInputMerge", compatible_execution_providers) {}
+
+  Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
+};
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/stft_decomposition.cc b/onnxruntime/core/optimizer/stft_decomposition.cc
new file mode 100644
index 0000000000000..a54904ff15e1e
--- /dev/null
+++ b/onnxruntime/core/optimizer/stft_decomposition.cc
@@ -0,0 +1,381 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <limits>
+
+#include "core/optimizer/stft_decomposition.h"
+#include "core/optimizer/initializer.h"
+#include "core/optimizer/utils.h"
+#include "core/graph/graph_utils.h"
+#include "core/optimizer/optimizer_execution_frame.h"
+#include "core/optimizer/utils.h"
+#include "core/framework/op_kernel.h"
+#include "core/framework/tensorprotoutils.h"
+
+using namespace onnxruntime::common;
+
+namespace onnxruntime {
+
+STFTDecomposition::STFTDecomposition(const InlinedHashSet<std::string_view>& compatible_execution_providers) noexcept
+    : GraphTransformer("STFTDecomposition", compatible_execution_providers) {
+}
+
+template <typename T>
+constexpr static ONNX_NAMESPACE::TensorProto_DataType GetDataType() {
+  if constexpr (std::is_same<T, float>::value) {
+    return ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
+  } else if constexpr (std::is_same<T, MLFloat16>::value) {
+    return ONNX_NAMESPACE::TensorProto_DataType_FLOAT16;
+  } else if constexpr (std::is_same<T, double>::value) {
+    return ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
+  } else if constexpr (std::is_same<T, int64_t>::value) {
+    return ONNX_NAMESPACE::TensorProto_DataType_INT64;
+  } else {
+    throw std::logic_error("Invalid data type requested for STFT decomposition");
+  }
+}
+
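+// Adds a constant initializer with the given base name, shape, and raw data to the graph and returns its NodeArg.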
+template <typename TDataType, size_t TDims>
+NodeArg* AddInitializer(Graph& graph, const char* name, const int64_t (&shape)[TDims], const TDataType* begin) {
+  ONNX_NAMESPACE::TensorProto proto;
+  proto.set_name(graph.GenerateNodeArgName(name));
+  proto.set_data_type(GetDataType<TDataType>());
+  int64_t element_count = 1;
+  for (size_t i = 0; i < TDims; i++) {
+    element_count *= shape[i];
+    proto.add_dims(shape[i]);
+  }
+  proto.set_raw_data(begin, element_count * sizeof(TDataType));
+  return &graph_utils::AddInitializer(graph, proto);
+}
+
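+// Convenience wrapper that stores a shape itself as a 1-D int64 initializer (e.g. for a Reshape shape input).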
+template <size_t TDims>
+NodeArg* AddShapeInitializer(Graph& graph, const char* name, const int64_t (&shape)[TDims]) {
+  int64_t shape_shape[] = {TDims};
+  return AddInitializer<int64_t>(graph, name, shape_shape, shape);
+}
+
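+// Adds a node of the given op type with a single generated output NodeArg and returns both.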
+std::pair<Node*, NodeArg*> AddNode(Graph& graph,
+                                   const char* op_type,
+                                   ProviderType execution_provider_type,
+                                   gsl::span<NodeArg*> inputs) {
+  auto def_name = graph.GenerateNodeArgName(op_type);
+  auto node_arg = &graph.GetOrCreateNodeArg(def_name, nullptr);
+  Node& node = graph.AddNode(graph.GenerateNodeName(op_type),
+                             op_type,
+                             "",
+                             inputs,
+                             {node_arg});
+  node.SetExecutionProviderType(execution_provider_type);
+  return std::make_pair(&node, node_arg);
+}
+
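+// Adds a Cast node (assigned to the CPU EP) that converts 'in' to the requested data type.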
+std::pair<Node*, NodeArg*> AddNodeCast(Graph& graph, NodeArg* in,
+                                       ONNX_NAMESPACE::TensorProto_DataType data_type) {
+  auto def_name = graph.GenerateNodeArgName("Cast");
+  auto node_arg = &graph.GetOrCreateNodeArg(def_name, nullptr);
+  Node& node = graph.AddNode(graph.GenerateNodeName("Cast"),
+                             "Cast",
+                             "",
+                             {in},
+                             {node_arg});
+  node.AddAttribute("to", static_cast<int64_t>(data_type));
+  node.SetExecutionProviderType(kCpuExecutionProvider);
+  return std::make_pair(&node, node_arg);
+}
+
+#define CONTINUE_IF_NO_DIM_VALUE(dim) \
+  if (!dim.has_dim_value()) {         \
+    continue;                         \
+  }
+#define CONTINUE_IF_NULL(x) \
+  if (x == nullptr) {       \
+    continue;               \
+  }
+
+/*
+    This function decomposes an STFT node into a subgraph.
+    The decomposition requires that:
+      1) The signal input is real valued, not complex valued.
+      2) The frame_step input *and* either the window or frame_length input are constant.
+    Otherwise the transform is not applied.
+
+    Subgraph pattern 1: STFT with optional Window parameter set
+              [root]--(signal)--------------------+
+              [root]--(frame_step)---------------+|
+              [root]--(window)------------------+||
+              [root]--(frame_length) ----------+|||
+                                               ||||
+                                               vvvv
+                                              [STFT]--(output)-->
+    After Fusion:
+              [root]--(signal)-------------------------+
+              [root]                                   |
+              [root]--(window)--+                      |
+              [root]            |                      |
+                                v                      v
+         (only for non-fp32) [Cast]             +--[Reshape]
+                                |               |      |
+                                v               |      v
+                            [Reshape]-->[Mul]---|-->[Conv]-------+
+                                |               |                |
+                                |               +-----|          |
+                                |                     v          v
+                                +------>[Mul]------>[Conv]-->[Concat]-->[Reshape]-->[Transpose]--(output)-->
+
+
+    Subgraph pattern 2: STFT without optional Window parameter set
+              [root]--(signal)-------------------+
+              [root]--(frame_step)--------------+|
+              [root]                             |
+              [root]--(frame_length) ----------+||
+                                               |||
+                                               vvv
+                                              [STFT]--(output)-->
+    After Fusion:
+              [root]--(signal)-->[Reshape]-->[Conv]
+              [root]                 |         |
+              [root]                 |         v
+              [root]                 +------>[Conv]-->[Concat]-->[Reshape]-->[Transpose]--(output)-->
+*/
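+/*
+    Background for the convolution weights built below (informal sketch): for a real frame x[n] of
+    length dft_size, DFT bin k is
+        X[k] = sum_n x[n] * (cos(2*pi*k*n/dft_size) - i*sin(2*pi*k*n/dft_size)).
+    Each bin's real part is a dot product of the frame with a cosine row, and its imaginary part a dot
+    product with a negated sine row. Sliding the frame by frame_step turns these dot products into two
+    1-D convolutions with stride frame_step, which is what the Conv nodes created here compute.
+*/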
+Status STFTDecomposition::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
+  GraphViewer graph_viewer(graph);
+  auto& order = graph_viewer.GetNodesInTopologicalOrder();
+
+  for (NodeIndex i : order) {
+    auto node = graph.GetNode(i);
+    CONTINUE_IF_NULL(node);
+    ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger));
+
+    if (node->OpType() != "STFT") {
+      continue;
+    }
+
+    Node& stft = *node;
+    auto signal = stft.MutableInputDefs()[0];
+    auto frame_step = stft.MutableInputDefs()[1];
+    auto window = stft.MutableInputDefs()[2];
+    auto frame_length = stft.MutableInputDefs()[3];
+
+    // If the signal has free dimensions, do not transform...
+    auto batch_size_dim = signal->Shape()->dim(0);
+    auto signal_length_dim = signal->Shape()->dim(1);
+    auto signal_components_dim = signal->Shape()->dim(2);
+    CONTINUE_IF_NO_DIM_VALUE(signal_length_dim);
+    CONTINUE_IF_NO_DIM_VALUE(signal_components_dim);
+
+    auto batch_size = batch_size_dim.has_dim_value() ? batch_size_dim.dim_value() : static_cast<int64_t>(-1);
+    auto signal_length = signal_length_dim.dim_value();
+    auto is_real = signal_components_dim.dim_value() == 1;
+    auto data_type = static_cast<ONNX_NAMESPACE::TensorProto_DataType>(signal->TypeAsProto()->tensor_type().elem_type());
+
+    auto frame_step_initializer = graph_utils::GetConstantInitializer(graph, frame_step->Name());
+    auto window_initializer = graph_utils::GetConstantInitializer(graph, window->Name());
+    auto frame_length_initializer = graph_utils::GetConstantInitializer(graph, frame_length->Name());
+    CONTINUE_IF_NULL(frame_step_initializer);
+    if (!frame_length_initializer && !window_initializer) {
+      continue;
+    }
+
+    auto read_int64_initializer = [](Graph& graph, const ONNX_NAMESPACE::TensorProto* initializer) {
+      return *Initializer(*initializer, graph.ModelPath()).data<int64_t>();
+    };
+    auto frame_step_value = read_int64_initializer(graph, frame_step_initializer);
+
+    // Get DFT Size
+    int64_t dft_size = 0;
+    if (frame_length_initializer) {
+      dft_size = read_int64_initializer(graph, frame_length_initializer);
+    }
+    if (dft_size == 0 && window_initializer) {
+      auto window_length_dim = window->Shape()->dim(0);
+      CONTINUE_IF_NO_DIM_VALUE(window_length_dim);
+      dft_size = window_length_dim.dim_value();
+    }
+
+    bool is_onesided = true;
+    auto& attrs = stft.GetAttributes();
+    if (attrs.find("onesided") != attrs.end()) {
+      auto& onesided_attr = attrs.at("onesided");
+      if (utils::HasInt(onesided_attr)) {
+        is_onesided = static_cast<bool>(onesided_attr.i());
+      }
+    }
+
+    auto dft_unique_bins = is_onesided ? ((dft_size >> 1) + 1) : dft_size;
+
+    Node* signal_recipient = nullptr;
+    Node* window_recipient = nullptr;
+    Node* stft_producer = nullptr;
+    if (is_real) {
+      auto output_num_frames = stft.MutableOutputDefs()[0]->Shape()->dim(1).dim_value();
+      auto output_frame_length = stft.MutableOutputDefs()[0]->Shape()->dim(2).dim_value();
+      auto weight_size = static_cast<size_t>(dft_unique_bins * dft_size);
+      auto real_weights_data = std::vector<float>(weight_size);
+      auto imag_weights_data = std::vector<float>(weight_size);
+
+      // Populate weights
+      for (size_t k = 0; k < static_cast<size_t>(dft_unique_bins); k++) {
+        for (size_t n = 0; n < static_cast<size_t>(dft_size); n++) {
+          auto index = static_cast<size_t>(k * dft_size + n);
+          auto theta = -2 * M_PI * k * n / static_cast<float>(dft_size);
+          real_weights_data[index] = static_cast<float>(cos(theta));
+          imag_weights_data[index] = static_cast<float>(sin(theta));
+        }
+      }
+
+      const int64_t weight_shape[] = {dft_unique_bins, 1, 1, dft_size};
+      auto real_weights = AddInitializer<float>(graph, "stft_real_conv_weights", weight_shape, real_weights_data.data());
+      auto imaginary_weights = AddInitializer<float>(graph, "stft_imaginary_conv_weights", weight_shape, imag_weights_data.data());
+
+      const int64_t signal_reshaped[] = {batch_size, 1, 1, signal_length};
+      auto signal_shape = AddShapeInitializer(graph, "stft_signal_shape", signal_reshaped);
+
+      const int64_t unsqueezed_output_shape[] = {2, batch_size, output_frame_length, output_num_frames};
+      auto unsqueezed_shape = AddShapeInitializer(graph, "stft_output_reshaped", unsqueezed_output_shape);
+
+      NodeArg* signal_reshaped_inputs[] = {signal, signal_shape};
+      Node* reshape_signal_node = nullptr;
+      NodeArg* reshape_output = nullptr;
+      std::tie(reshape_signal_node, reshape_output) =
+          AddNode(graph, "Reshape", stft.GetExecutionProviderType(), signal_reshaped_inputs);
+
+      NodeArg* real_weights_final = real_weights;
+      NodeArg* imag_weights_final = imaginary_weights;
+      if (!window->Exists()) {
+        // When we are missing a window function
+        if (real_weights_final->TypeAsProto()->tensor_type().elem_type() != data_type) {
+          std::tie(std::ignore, real_weights_final) =
+              AddNodeCast(graph, real_weights_final, data_type);
+        }
+        if (imag_weights_final->TypeAsProto()->tensor_type().elem_type() != data_type) {
+          std::tie(std::ignore, imag_weights_final) =
+              AddNodeCast(graph, imag_weights_final, data_type);
+        }
+      } else {
+        // When we have a window function
+        const int64_t window_reshaped_shape[] = {1, 1, 1, dft_size};
+        auto window_shape = AddShapeInitializer(graph, "stft_window_shape", window_reshaped_shape);
+
+        auto window_final = window;
+        if (window->TypeAsProto()->tensor_type().elem_type() != GetDataType<float>()) {
+          Node* window_cast_node = nullptr;
+          std::tie(window_cast_node, window_final) =
+              AddNodeCast(graph, window, GetDataType<float>());
+          window_recipient = window_cast_node;
+        }
+
+        NodeArg* window_reshaped_inputs[] = {window_final, window_shape};
+        Node* window_reshape_node;
+        NodeArg* window_reshaped = nullptr;
+        std::tie(window_reshape_node, window_reshaped) =
+            AddNode(graph, "Reshape", kCpuExecutionProvider, window_reshaped_inputs);
+        if (!window_recipient) {
+          window_recipient = window_reshape_node;
+        }
+
+        NodeArg* scale_real_weights_inputs[] = {real_weights, window_reshaped};
+        NodeArg* windowed_real_weights_output = nullptr;
+        std::tie(std::ignore, windowed_real_weights_output) =
+            AddNode(graph, "Mul", kCpuExecutionProvider, scale_real_weights_inputs);
+
+        NodeArg* scale_imag_weights_inputs[] = {imaginary_weights, window_reshaped};
+        NodeArg* windowed_imag_weights_output = nullptr;
+        std::tie(std::ignore, windowed_imag_weights_output) =
+            AddNode(graph, "Mul", kCpuExecutionProvider, scale_imag_weights_inputs);
+
+        std::tie(std::ignore, real_weights_final) =
+            AddNodeCast(graph, windowed_real_weights_output, data_type);
+        std::tie(std::ignore, imag_weights_final) =
+            AddNodeCast(graph, windowed_imag_weights_output, data_type);
+      }
+
+      // Add Convolution (reals)
+      NodeArg* conv_real_inputs[] = {reshape_output, real_weights_final};
+      Node* real_conv_node = nullptr;
+      NodeArg* real_conv_output = nullptr;
+      std::tie(real_conv_node, real_conv_output) =
+          AddNode(graph, "Conv", stft.GetExecutionProviderType(), conv_real_inputs);
+      real_conv_node->AddAttribute("strides", std::vector<int64_t>{1, frame_step_value});
+
+      // Add Convolution (imaginary)
+      NodeArg* conv_imag_inputs[] = {reshape_output, imag_weights_final};
+      Node* imag_conv_node = nullptr;
+      NodeArg* imag_conv_output = nullptr;
+      std::tie(imag_conv_node, imag_conv_output) =
+          AddNode(graph, "Conv", stft.GetExecutionProviderType(), conv_imag_inputs);
+      imag_conv_node->AddAttribute("strides", std::vector<int64_t>{1, frame_step_value});
+
+      // Concatenate
+      NodeArg* concatenate_inputs[] = {real_conv_output, imag_conv_output};
+      Node* concat_node = nullptr;
+      NodeArg* concatenated_conv_output = nullptr;
+      std::tie(concat_node, concatenated_conv_output) =
+          AddNode(graph, "Concat", stft.GetExecutionProviderType(), concatenate_inputs);
+      concat_node->AddAttribute("axis", static_cast<int64_t>(0));
+
+      // Unsqueeze Reshape
+      NodeArg* unsqueeze_reshape_inputs[] = {concatenated_conv_output, unsqueezed_shape};
+      NodeArg* unsqueezed_output = nullptr;
+      std::tie(std::ignore, unsqueezed_output) =
+          AddNode(graph, "Reshape", stft.GetExecutionProviderType(), unsqueeze_reshape_inputs);
+
+      // Transpose
+      NodeArg* transpose_inputs[] = {unsqueezed_output};
+      Node* transpose_node = nullptr;
+      NodeArg* transpose_output = nullptr;
+      std::tie(transpose_node, transpose_output) =
+          AddNode(graph, "Transpose", stft.GetExecutionProviderType(), transpose_inputs);
+      transpose_node->AddAttribute("perm", std::vector<int64_t>{1, 3, 2, 0});
+
+      signal_recipient = reshape_signal_node;
+      stft_producer = transpose_node;
+    } else {
+      continue;
+    }
+
+    auto input_edges = graph_utils::GraphEdge::GetNodeInputEdges(stft);
+    auto output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(stft);
+
+    // Copy inputs
+    auto signal_target_idx = signal_recipient->Index();
+    // window_recipient is null when the optional window input is absent; guard the Index() call.
+    auto window_target_idx = window_recipient ? window_recipient->Index() : 0;
+    for (auto cur = input_edges.cbegin(), end = input_edges.cend(); cur != end; ++cur) {
+      const graph_utils::GraphEdge& edge = *cur;
+      NodeIndex target_idx = 0;
+      Node* recipient = nullptr;
+      switch (cur->dst_arg_index) {
+        case 0:
+          target_idx = signal_target_idx;
+          recipient = signal_recipient;
+          break;
+        case 2:
+          target_idx = window_target_idx;
+          recipient = window_recipient;
+          break;
+      }
+
+      if (!recipient) {
+        continue;
+      }
+
+      auto arg_index = graph_utils::GetNodeInputIndexFromInputName(*recipient, edge.arg_name);
+      graph.AddEdge(edge.src_node, target_idx, edge.src_arg_index, arg_index);
+    }
+
+    // Copy STFT outputs to stft_producer
+    stft_producer->MutableOutputDefs() = stft.MutableOutputDefs();
+    auto stft_producer_target_idx = stft_producer->Index();
+    for (auto cur = output_edges.cbegin(), end = output_edges.cend(); cur != end; ++cur) {
+      graph.AddEdge(stft_producer_target_idx, cur->dst_node, cur->src_arg_index, cur->dst_arg_index);
+    }
+
+    graph_utils::GraphEdge::RemoveGraphEdges(graph, input_edges);
+    graph_utils::GraphEdge::RemoveGraphEdges(graph, output_edges);
+    graph.RemoveNode(stft.Index());
+
+    modified = true;
+  }
+  return Status::OK();
+}
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/stft_decomposition.h b/onnxruntime/core/optimizer/stft_decomposition.h
new file mode 100644
index 0000000000000..cac058474375e
--- /dev/null
+++ b/onnxruntime/core/optimizer/stft_decomposition.h
@@ -0,0 +1,30 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/optimizer/graph_transformer.h"
+#include "core/framework/ort_value.h"
+#include <memory>
+#include "core/framework/execution_provider.h"
+
+namespace onnxruntime {
+
+/**
+@class STFTDecomposition
+
+Transformer that traverses the graph top-down and decomposes
+STFT into convolution.
+*/
+class STFTDecomposition : public GraphTransformer {
+ public:
+  /*! STFT decomposition.
+      \param compatible_execution_providers Execution providers for which the decomposition is applied.
+  */
+  STFTDecomposition(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept;
+
+ private:
+  Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
+};
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc
index d9f08ffe1171e..c532f56b3d3d9 100644
--- a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc
+++ b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc
@@ -115,7 +115,7 @@ class ApiGraph final : public api::GraphRef {
     const auto& graph_outputs = graph_.GetOutputs();
     graph_outputs_.reserve(graph_outputs.size());
     for (const auto* output : graph_outputs) {
-      graph_outputs_.insert(output->Name());
+      graph_outputs_.emplace(output->Name());
     }
   }
 
diff --git a/onnxruntime/core/optimizer/utils.cc b/onnxruntime/core/optimizer/utils.cc
index 7c3599a08ec7a..7055882961e17 100644
--- a/onnxruntime/core/optimizer/utils.cc
+++ b/onnxruntime/core/optimizer/utils.cc
@@ -272,7 +272,7 @@ int32_t IndexOfNodeOutput(const Node& node, const NodeArg& node_arg) {
 // We could also allow other known domains (kMSDomain, kMSNchwcDomain, kMSFeaturizersDomain),
 // as long as we verify which of their operations are non-deterministic and add them in the map below.
 constexpr std::array kOnnxDomainNonDeterministicOps{"RandomUniform", "RandomNormal", "RandomUniformLike",
-                                                    "RandomNormalLike", "Multinomial"};
+                                                    "RandomNormalLike", "Multinomial", "Dropout"};
 
 // List of deterministic MS domain operators. Currently used for constant folding and common subexpression elimination.
 //
@@ -280,7 +280,8 @@ constexpr std::array kOnnxDomainNonDeterministicOps{"RandomUniform", "RandomNorm
 // with the above ONNX list. With the current approach, only MS domain Q/DQ operators
 // (plus ShrunkenGather for training) are considered deterministic.
 #ifdef ENABLE_TRAINING_OPS
-constexpr std::array kMSDomainDeterministicOps{"ShrunkenGather", "QuantizeLinear", "DequantizeLinear"};
+constexpr std::array kMSDomainDeterministicOps{"ShrunkenGather", "QuantizeLinear", "DequantizeLinear",
+                                               "ConcatTraining"};
 #else
 constexpr std::array kMSDomainDeterministicOps{"QuantizeLinear", "DequantizeLinear"};
 #endif
diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc
index 1a0713db43db8..983cc6089bb4c 100644
--- a/onnxruntime/core/platform/windows/env.cc
+++ b/onnxruntime/core/platform/windows/env.cc
@@ -32,6 +32,9 @@ limitations under the License.
 #include "core/common/span_utils.h"
 #include "core/platform/env.h"
 #include "core/platform/scoped_resource.h"
+#if defined(_M_X64) && !defined(_M_ARM64EC) && defined(ONNXRUNTIME_ENABLE_INTEL_METEOR_LAKE_MOBILE_PLATFORM_PERF_PATCH)
+#include "core/platform/windows/hardware_core_enumerator.h"
+#endif
 #include <unsupported/Eigen/CXX11/ThreadPool>
 #include <wil/Resource.h>
 
@@ -248,12 +251,53 @@ void WindowsEnv::SleepForMicroseconds(int64_t micros) const {
   Sleep(static_cast<DWORD>(micros) / 1000);
 }
 
+// EIGEN_NO_CPUID is not defined in any C/C++ source code. It is a compile option.
+#if defined(_M_X64) && !defined(_M_ARM64EC) && !defined(EIGEN_NO_CPUID) && defined(ONNXRUNTIME_ENABLE_INTEL_METEOR_LAKE_MOBILE_PLATFORM_PERF_PATCH)
+static constexpr std::array<int, 3> kVendorID_Intel = {0x756e6547, 0x6c65746e, 0x49656e69};  // "GenuntelineI"
+#endif
 int WindowsEnv::DefaultNumCores() {
   return std::max(1, static_cast<int>(std::thread::hardware_concurrency() / 2));
 }
 
 int WindowsEnv::GetNumPhysicalCpuCores() const {
-  return cores_.empty() ? DefaultNumCores() : static_cast<int>(cores_.size());
+// EIGEN_NO_CPUID is not defined in any C/C++ source code. It is a compile option.
+#if defined(_M_X64) && !defined(_M_ARM64EC) && !defined(EIGEN_NO_CPUID) && defined(ONNXRUNTIME_ENABLE_INTEL_METEOR_LAKE_MOBILE_PLATFORM_PERF_PATCH)
+  // The following code is a temporary fix for a perf problem on Intel's Meteor Lake CPUs. That compute platform has
+  // a hybrid architecture in which some CPU cores run significantly slower than the others. If we distribute our
+  // compute work evenly across all CPU cores, the slowest core drags the overall performance down. So, instead, we
+  // reduce the total number of threads to exclude the slowest cores.
+  // The following code is based on the assumptions that:
+  // 1. All Intel hybrid CPUs have 3 levels of cache.
+  // 2. If a CPU core is only associated with two levels of cache, it is a low-performance core and should
+  //    not be used.
+  // Since we don't know what the next Intel hybrid CPU will look like, we may need to rework this code later.
+  // However, no matter what, the code should not cause any crash. In the worst case it returns 1, in which case
+  // thread pools will not be created; that is only a perf issue and does not impact usability.
+  // TODO: detect if CPUID instruction is available per instructions at https://wiki.osdev.org/CPUID#Checking_CPUID_availability
+  int regs[4];
+  __cpuid(regs, 0);
+  bool bIsIntel =
+      (kVendorID_Intel[0] == regs[1]) &&
+      (kVendorID_Intel[1] == regs[2]) &&
+      (kVendorID_Intel[2] == regs[3]);
+  if (bIsIntel && regs[0] >= 7) {
+    // Query Structured Extended Feature Flags Enumeration Leaf
+    __cpuid(regs, 0x7);
+    // The bit 15 of EDX indicates if the processor is identified as a hybrid part.
+    bool ishybrid = regs[3] & (1 << 15);
+    if (ishybrid) {
+      // NOTE: even if ishybrid is true, it doesn't mean the processor must have both P-cores and E-cores.
+      // On Intel CPUs we assume HardwareCoreEnumerator::DefaultIntraOpNumThreads will not fail.
+      // NOTE: due to resource restrictions, we cannot test this branch in our CI build pipelines.
+      return std::max(static_cast<uint32_t>(1), HardwareCoreEnumerator::DefaultIntraOpNumThreads());
+    } else {
+      return cores_.empty() ? DefaultNumCores() : static_cast<int>(cores_.size());
+    }
+  } else
+#endif
+  {
+    return cores_.empty() ? DefaultNumCores() : static_cast<int>(cores_.size());
+  }
 }
 
 std::vector<LogicalProcessors> WindowsEnv::GetDefaultThreadAffinities() const {
@@ -415,8 +459,8 @@ Status WindowsEnv::MapFileIntoMemory(_In_z_ const ORTCHAR_T* file_path,
 
   void* const mapped_base = MapViewOfFile(file_mapping_handle.get(),
                                           FILE_MAP_READ,
-                                          0,
-                                          static_cast<DWORD>(mapped_offset),
+                                          static_cast<DWORD>((mapped_offset >> 32) & 0xFFFFFFFF),
+                                          static_cast<DWORD>(mapped_offset & 0xFFFFFFFF),
                                           mapped_length);
   GSL_SUPPRESS(r.11)
   mapped_memory =
diff --git a/onnxruntime/core/platform/windows/hardware_core_enumerator.cc b/onnxruntime/core/platform/windows/hardware_core_enumerator.cc
new file mode 100644
index 0000000000000..121c59808ae59
--- /dev/null
+++ b/onnxruntime/core/platform/windows/hardware_core_enumerator.cc
@@ -0,0 +1,89 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "hardware_core_enumerator.h"
+#include <memory>
+#include <Windows.h>
+#include <assert.h>
+
+namespace onnxruntime {
+
+struct LogicalProcessorInformation {
+  std::unique_ptr<char[]> Buffer;
+  size_t Length;
+};
+
+struct CoreCounter {
+  uint32_t PhysicalCores = 0;
+  uint32_t SocDieCores = 0;
+};
+
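+// Calls GetLogicalProcessorInformationEx twice: first with a null buffer to obtain the required
+// length, then again to fill an allocated buffer with the processor relationship records.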
+static LogicalProcessorInformation GetLogicalProcessorInfos(LOGICAL_PROCESSOR_RELATIONSHIP relationship) {
+  DWORD length = 0;
+  DWORD rc = GetLogicalProcessorInformationEx(relationship, nullptr, &length);
+
+  assert(rc == FALSE);
+
+  auto processorInformationBytes = std::make_unique<char[]>(length);
+
+  rc = GetLogicalProcessorInformationEx(
+      relationship, reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(processorInformationBytes.get()), &length);
+
+  assert(rc == TRUE);
+
+  return {std::move(processorInformationBytes), length};
+}
+
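+// Counts the set bits in a mask by repeatedly clearing the lowest set bit (Kernighan's method).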
+uint32_t CountSetBits(DWORD input) {
+  uint32_t c;
+  for (c = 0; input; c++) {
+    input &= input - 1;
+  }
+  return c;
+}
+
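+// Walks the SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX records, counting physical cores from
+// RelationProcessorCore entries and accumulating L2/L3 cache group masks from RelationCache entries.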
+static CoreCounter GetNumberOPhysicalAndEngineeringCores() {
+  auto logicalProcessorInformation = GetLogicalProcessorInfos(RelationAll);
+
+  CoreCounter cores;
+  DWORD dwLevel2GroupMask = 0;
+  DWORD dwLevel3GroupMask = 0;
+  size_t read = 0;
+  PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX currentProcessorInfo = NULL;
+
+  while ((read + FIELD_OFFSET(SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX, Processor)) < logicalProcessorInformation.Length) {
+    currentProcessorInfo =
+        reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(logicalProcessorInformation.Buffer.get() + read);
+    if ((read + currentProcessorInfo->Size) > logicalProcessorInformation.Length) {
+      break;
+    }
+
+    switch (currentProcessorInfo->Relationship) {
+      case RelationProcessorCore:
+        cores.PhysicalCores++;
+        break;
+      case RelationCache:
+        if (currentProcessorInfo->Cache.Level == 2) {
+          dwLevel2GroupMask |= currentProcessorInfo->Cache.GroupMask.Mask;
+        } else if (currentProcessorInfo->Cache.Level == 3) {
+          dwLevel3GroupMask |= currentProcessorInfo->Cache.GroupMask.Mask;
+        }
+        break;
+    }
+
+    read += currentProcessorInfo->Size;
+  }
+
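+  // Logical processors that appear in an L2 cache group mask but in no L3 cache group mask are
+  // treated as low-power SoC-die cores; the remaining physical cores are P- or E-cores.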
+  cores.SocDieCores = CountSetBits(dwLevel2GroupMask & ~dwLevel3GroupMask);
+  return cores;
+}
+
+uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() {
+  // # of physical cores = # of P cores + # of E cores + # of SoC cores.
+  // # of logical cores = # of P cores x 2 (if hyper-threading is enabled) + # of E cores + # of SoC cores.
+  auto cores = GetNumberOPhysicalAndEngineeringCores();
+  // We want to use the number of physical cores, but exclude the low-power SoC cores.
+  return cores.PhysicalCores - cores.SocDieCores;
+}
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/platform/windows/hardware_core_enumerator.h b/onnxruntime/core/platform/windows/hardware_core_enumerator.h
new file mode 100644
index 0000000000000..93b50f452afcd
--- /dev/null
+++ b/onnxruntime/core/platform/windows/hardware_core_enumerator.h
@@ -0,0 +1,12 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include <stdint.h>
+
+namespace onnxruntime {
+struct HardwareCoreEnumerator {
+  HardwareCoreEnumerator() = delete;
+  static uint32_t DefaultIntraOpNumThreads();
+};
+}  // namespace onnxruntime
\ No newline at end of file
diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc
index 752b742805a7c..9a242919665bb 100644
--- a/onnxruntime/core/providers/cann/cann_execution_provider.cc
+++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc
@@ -1045,7 +1045,7 @@ CANNExecutionProvider::~CANNExecutionProvider() {
 }
 
 // All threads share the same context and stream
-Status CANNExecutionProvider::OnRunStart() {
+Status CANNExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) {
   CANN_RETURN_IF_ERROR(aclrtSetDevice(info_.device_id));
 
   return Status::OK();
diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.h b/onnxruntime/core/providers/cann/cann_execution_provider.h
index 63ae980869c65..d83bd88d6958f 100644
--- a/onnxruntime/core/providers/cann/cann_execution_provider.h
+++ b/onnxruntime/core/providers/cann/cann_execution_provider.h
@@ -33,7 +33,7 @@ class CANNExecutionProvider : public IExecutionProvider {
   explicit CANNExecutionProvider(const CANNExecutionProviderInfo& info);
   virtual ~CANNExecutionProvider();
 
-  Status OnRunStart() override;
+  Status OnRunStart(const onnxruntime::RunOptions& run_options) override;
 
   template <typename T>
   Status Fill(Tensor* y, void* addr, aclrtStream stream) const {
diff --git a/onnxruntime/core/providers/cann/cann_stream_handle.h b/onnxruntime/core/providers/cann/cann_stream_handle.h
index 4d03fe5201209..5d822d23f966f 100644
--- a/onnxruntime/core/providers/cann/cann_stream_handle.h
+++ b/onnxruntime/core/providers/cann/cann_stream_handle.h
@@ -12,6 +12,7 @@
 #include "core/providers/cann/cann_call.h"
 
 namespace onnxruntime {
+void WaitCannNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
 
 struct CannStream : Stream {
   CannStream(aclrtStream stream, const OrtDevice& device, bool own_flag);
@@ -23,10 +24,11 @@ struct CannStream : Stream {
   void Flush() override;
 
   bool own_stream_{true};
+
+  WaitNotificationFn GetWaitNotificationFn() const override { return WaitCannNotificationOnDevice; }
 };
 
 void RegisterCannStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry,
                                const OrtDevice::DeviceType device_type);
 
-void WaitCannNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/coreml/builders/coreml_spec.h b/onnxruntime/core/providers/coreml/builders/coreml_spec.h
index e9cd4af94e5fd..9448f1167990e 100644
--- a/onnxruntime/core/providers/coreml/builders/coreml_spec.h
+++ b/onnxruntime/core/providers/coreml/builders/coreml_spec.h
@@ -3,12 +3,33 @@
 
 #pragma once
 
-// TODO come up with a more intuitive way of limiting this to Apple platform builds
-// E.g., putting CoreML EP files that should be enabled iff `defined(__APPLE__)` in a separate directory.
-#if !defined(__APPLE__)
-#error "This file should only be included when building on Apple platforms."
+#include "onnxruntime_config.h"
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic push
+
+// Disable warning from protobuf code.
+//
+// In file included from coreml_proto/Model.pb.h:30:
+// In file included from _deps/protobuf-src/src/google/protobuf/extension_set.h:53:
+// _deps/protobuf-src/src/google/protobuf/parse_context.h:328:47:
+//     error: implicit conversion loses integer precision: 'long' to 'int' [-Werror,-Wshorten-64-to-32]
+#ifdef HAS_SHORTEN_64_TO_32
+#pragma GCC diagnostic ignored "-Wshorten-64-to-32"
+#endif
+#elif defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4244)  // conversion from long to int
 #endif
 
+// Model.pb.h is generated in the build output directory from the CoreML protobuf files in
+// <build output directory>/_deps/coremltools-src/mlmodel/format
 #include "coreml_proto/Model.pb.h"
 
+#if defined(__GNUC__)
+#pragma GCC diagnostic pop
+#elif defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+
 namespace COREML_SPEC = CoreML::Specification;
diff --git a/onnxruntime/core/providers/coreml/builders/helper.cc b/onnxruntime/core/providers/coreml/builders/helper.cc
index 897856256cc79..b8ebbd05a2a20 100644
--- a/onnxruntime/core/providers/coreml/builders/helper.cc
+++ b/onnxruntime/core/providers/coreml/builders/helper.cc
@@ -22,22 +22,35 @@
 namespace onnxruntime {
 namespace coreml {
 
-OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer, uint32_t coreml_flags) {
+OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer,
+                                         int32_t coreml_version,
+                                         uint32_t coreml_flags) {
   return OpBuilderInputParams{graph_viewer,
-                              (coreml_flags & COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES) != 0};
+                              coreml_version,
+                              (coreml_flags & COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES) != 0,
+                              (coreml_flags & COREML_FLAG_CREATE_MLPROGRAM) != 0};
 }
 
-bool IsNodeSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) {
+const IOpBuilder* GetOpBuilder(const Node& node) {
   const auto& op_builders = GetOpBuilders();
-  if (Contains(op_builders, node.OpType())) {
-    const auto* op_builder = op_builders.at(node.OpType());
+  const auto it = op_builders.find(node.OpType());
+  if (it != op_builders.cend()) {
+    return it->second;
+  }
+
+  return nullptr;
+}
+
+bool IsNodeSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) {
+  const auto* op_builder = GetOpBuilder(node);
+  if (op_builder) {
     return op_builder->IsOpSupported(node, input_params, logger);
   } else {
     return false;
   }
 }
 
-bool IsInputSupported(const NodeArg& input, const std::string& parent_name,
+bool IsInputSupported(const Node& node, const NodeArg& input,
                       const OpBuilderInputParams& input_params, const logging::Logger& logger) {
   if (!input.Exists()) {
     // optional input that is not provided
@@ -48,8 +61,8 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name,
   std::vector<int64_t> shape;
   // We do not support input with no shape
   if (!GetShape(input, shape, logger)) {
-    LOGS(logger, VERBOSE) << "Input [" << input_name << "] of [" << parent_name
-                          << "] has no shape";
+    LOGS(logger, VERBOSE) << MakeString("Input [", input_name, "] of Node [", node.Name(), "] type [", node.OpType(),
+                                        "] has no shape");
     return false;
   }
 
@@ -63,11 +76,25 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name,
     // For some undocumented reason, Apple CoreML framework will fail loading the model if the model
     // input has dimension > 16384
     // See this issue, https://github.com/apple/coremltools/issues/1003
+    // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf has maximum texture widths which may be the
+    // root cause.
     if (dim > 16384) {
       LOGS(logger, WARNING) << "CoreML does not support input dim > 16384. Input:" << input_name
                             << ", shape: " << Shape2String(shape);
       return false;
     }
+
+    if (dim == 0) {
+      if (node.OpType() == "Resize" && &input == node.InputDefs()[1]) {
+        // one special case. Resize 'roi' input was originally a required input but is rarely used.
+        // ROI is not supported in the CoreML implementation so we will ignore the value, but is often added
+        // (at least in the unit tests) as an initializer with shape {0}.
+      } else {
+        LOGS(logger, WARNING) << "CoreML does not support shapes with dimension values of 0. Input:" << input_name
+                              << ", shape: " << Shape2String(shape);
+        return false;
+      }
+    }
   }
 
   // Limit input shape rank to 5.
@@ -87,13 +114,6 @@ std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewe
                                                   const logging::Logger& logger) {
   std::unordered_set<const Node*> supported_nodes{};
 
-#ifdef __APPLE__
-  if (!util::HasRequiredBaseOS()) {
-    LOGS(logger, WARNING) << "All ops will fallback to CPU EP, because we do not have supported OS";
-    return supported_nodes;
-  }
-#endif
-
   for (const auto& node : graph_viewer.Nodes()) {
     const bool supported = IsNodeSupported(node, input_params, logger);
     LOGS(logger, VERBOSE) << "Operator type: [" << node.OpType()
@@ -111,7 +131,7 @@ std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewe
 
 bool CheckIsConstantInitializer(const NodeArg& node_arg, const GraphViewer& graph_viewer,
                                 const logging::Logger& logger, std::string_view input_description) {
-  if (graph_viewer.GetConstantInitializer(node_arg.Name(), true) == nullptr) {
+  if (graph_viewer.GetConstantInitializer(node_arg.Name()) == nullptr) {
     LOGS(logger, VERBOSE) << input_description << " (NodeArg name: '" << node_arg.Name()
                           << "') is not a constant initializer tensor";
     return false;
@@ -149,7 +169,9 @@ bool HasNeuralEngine(const logging::Logger& logger) {
 #else
   // In this case, we are running the EP on non-apple platform, which means we are running the model
   // conversion with CoreML EP enabled, for this we always assume the target system has Neural Engine
-  LOGS(logger, VERBOSE) << "HasNeuralEngine running on non-Apple hardware for model conversion only";
+  LOGS(logger, INFO) << "HasNeuralEngine running on non-Apple hardware. "
+                        "Returning true to enable model conversion and local testing of CoreML EP implementation. "
+                        "No CoreML model will be compiled or run.";
   has_neural_engine = true;
 #endif  // #ifdef __APPLE__
 
diff --git a/onnxruntime/core/providers/coreml/builders/helper.h b/onnxruntime/core/providers/coreml/builders/helper.h
index d8b27ac76ae73..300de2dedd122 100644
--- a/onnxruntime/core/providers/coreml/builders/helper.h
+++ b/onnxruntime/core/providers/coreml/builders/helper.h
@@ -23,10 +23,14 @@ class Logger;
 
 namespace coreml {
 
-OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer, uint32_t coreml_flags);
+OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer,
+                                         int32_t coreml_version,
+                                         uint32_t coreml_flags);
 
-bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name,
-                      const OpBuilderInputParams& input_params, const logging::Logger& logger);
+const IOpBuilder* GetOpBuilder(const Node& node);
+
+bool IsInputSupported(const Node& node, const NodeArg& node_arg, const OpBuilderInputParams& input_params,
+                      const logging::Logger& logger);
 
 bool IsNodeSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger);
 
diff --git a/onnxruntime/core/providers/coreml/builders/impl/LRN_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/LRN_op_builder.cc
index 53f18b205880c..e9e520156576e 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/LRN_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/LRN_op_builder.cc
@@ -3,39 +3,26 @@
 
 #include "core/providers/coreml/builders/helper.h"
 #include "core/providers/coreml/builders/impl/base_op_builder.h"
+#include "core/providers/coreml/builders/model_builder.h"
 #include "core/providers/coreml/builders/op_builder_factory.h"
 #include "core/providers/coreml/shape_utils.h"
 #include "core/providers/shared/utils/utils.h"
 
-#ifdef __APPLE__
-#include "core/providers/coreml/builders/model_builder.h"
-#endif
-
 namespace onnxruntime {
 namespace coreml {
 
 class LRNOpBuilder : public BaseOpBuilder {
-  // Add operator related
-#ifdef __APPLE__
- private:
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override;
-#endif
 
-  // Operator support related
- private:
   bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                          const logging::Logger& logger) const override;
 };
 
-// Add operator related
-
-#ifdef __APPLE__
-
 Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
                                            const Node& node,
-                                           const logging::Logger& logger) const {
-  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(model_builder, node);
+                                           const logging::Logger& /*logger*/) const {
+  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
 
   auto* coreml_lrn = layer->mutable_lrn();
 
@@ -56,9 +43,6 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   model_builder.AddLayer(std::move(layer));
   return Status::OK();
 }
-#endif
-
-// Operator support related
 
 bool LRNOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/,
                                      const logging::Logger& logger) const {
diff --git a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc
index 88d6616b4e097..dee87ce3632a8 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc
@@ -2,44 +2,32 @@
 // Licensed under the MIT License.
 
 #include "core/common/narrow.h"
+#include "core/framework/tensorprotoutils.h"
 #include "core/optimizer/initializer.h"
 #include "core/providers/common.h"
 #include "core/providers/coreml/builders/helper.h"
 #include "core/providers/coreml/builders/impl/base_op_builder.h"
+#include "core/providers/coreml/builders/impl/builder_utils.h"
+#include "core/providers/coreml/builders/model_builder.h"
 #include "core/providers/coreml/builders/op_builder_factory.h"
 #include "core/providers/coreml/shape_utils.h"
 #include "core/providers/shared/utils/utils.h"
 
-#ifdef __APPLE__
-#include "core/framework/tensorprotoutils.h"
-#include "core/providers/coreml/builders/impl/builder_utils.h"
-#include "core/providers/coreml/builders/model_builder.h"
-#endif
-
 namespace onnxruntime {
 namespace coreml {
 
 class ActivationOpBuilder : public BaseOpBuilder {
-  // Add operator related
-#ifdef __APPLE__
- public:
   void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
 
- private:
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override;
-#endif
 
-  // Operator support related
- private:
   bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                          const logging::Logger& logger) const override;
+
   int GetMinSupportedOpSet(const Node& node) const override;
 };
 
-// Add operator related
-
-#ifdef __APPLE__
 void ActivationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
   const auto& op_type = node.OpType();
   const auto& input_defs = node.InputDefs();
@@ -86,7 +74,7 @@ Status AddPReluWeight(ModelBuilder& model_builder, const Node& node,
 Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
                                                   const Node& node,
                                                   const logging::Logger& logger) const {
-  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(model_builder, node);
+  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
 
   const auto& op_type(node.OpType());
   if (op_type == "Sigmoid") {
@@ -115,14 +103,10 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   model_builder.AddLayer(std::move(layer));
   return Status::OK();
 }
-#endif
-
-// Operator support related
 
 namespace {
 // assumes that node.OpType() == "PRelu"
-bool IsPReluOpSupported(const Node& node, const OpBuilderInputParams& input_params,
-                        const logging::Logger& logger) {
+bool IsPReluOpSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) {
   const auto& input_defs = node.InputDefs();
 
   // X input rank must be 3 or 4
diff --git a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc
index 7a5d4a5af673b..e9a8176c8349b 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc
@@ -1,37 +1,26 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#include "core/providers/shared/utils/utils.h"
-#ifdef __APPLE__
+#include "core/providers/coreml/builders/impl/base_op_builder.h"
 #include "core/providers/coreml/builders/model_builder.h"
-#endif
 #include "core/providers/coreml/builders/op_builder_factory.h"
-
-#include "base_op_builder.h"
+#include "core/providers/shared/utils/utils.h"
 
 namespace onnxruntime {
 namespace coreml {
 
 class ArgMaxOpBuilder : public BaseOpBuilder {
-  // Add operator related
- private:
-#ifdef __APPLE__
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override;
-#endif
 
-  // Operator support related
   bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                          const logging::Logger& logger) const override;
 };
 
-// Add operator related
-
-#ifdef __APPLE__
 Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
                                               const Node& node,
                                               const logging::Logger& /* logger */) const {
-  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(model_builder, node);
+  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
   const auto& graph_viewer = model_builder.GetGraphViewer();
 
   NodeAttrHelper helper(node);
@@ -67,9 +56,6 @@ Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   model_builder.AddLayer(std::move(layer));
   return Status::OK();
 }
-#endif
-
-// Operator support related
 
 bool ArgMaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/,
                                         const logging::Logger& logger) const {
diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc
index 25d5bad14ceb6..83a572f4b60fa 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc
@@ -1,21 +1,18 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#include "core/providers/coreml/builders/impl/base_op_builder.h"
-
 #include "core/providers/common.h"
 #include "core/providers/coreml/builders/helper.h"
+#include "core/providers/coreml/builders/impl/base_op_builder.h"
+#include "core/providers/coreml/builders/model_builder.h"
 #include "core/providers/shared/utils/utils.h"
 
-#ifdef __APPLE__
-#include "core/providers/coreml/builders/model_builder.h"
-#endif
+using namespace CoreML::Specification;
 
 namespace onnxruntime {
 namespace coreml {
 
-// Shared functions
-
+namespace {
 // TODO, move this to shared_library
 bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node& node,
                             const logging::Logger& logger) {
@@ -37,93 +34,83 @@ bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node
 
   return false;
 }
+}  // namespace
 
-// Add operator related
-#ifdef __APPLE__
 Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node,
-                                        const OpBuilderInputParams& input_params,
                                         const logging::Logger& logger) const {
-  ORT_RETURN_IF_NOT(
-      IsOpSupported(node, input_params, logger),
-      "Unsupported operator ",
-      node.OpType());
-
-  ORT_RETURN_IF_ERROR(AddToModelBuilderImpl(model_builder, node, logger));
-  LOGS(logger, VERBOSE) << "Operator name: [" << node.Name()
-                        << "] type: [" << node.OpType() << "] was added";
-  return Status::OK();
-}
+  Status status = AddToModelBuilderImpl(model_builder, node, logger);
 
-/* static */ std::unique_ptr<COREML_SPEC::NeuralNetworkLayer>
-BaseOpBuilder::CreateNNLayer(ModelBuilder& model_builder, const Node& node) {
-  auto layer_name = node.Name();
-  if (layer_name.empty()) {
-    // CoreML requires layer has a name, while the node name is optional in ONNX
-    // In this case, create a unique name for the layer
-    layer_name = model_builder.GetUniqueName(MakeString("Node_", node.Index(), "_type_", node.OpType()));
+  if (status.IsOK()) {
+    LOGS(logger, VERBOSE) << "Operator name: [" << node.Name() << "] type: [" << node.OpType() << "] was added";
   }
-  return CreateNNLayer(layer_name);
-}
 
-/* static */ std::unique_ptr<COREML_SPEC::NeuralNetworkLayer>
-BaseOpBuilder::CreateNNLayer(const std::string& layer_name) {
-  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = std::make_unique<COREML_SPEC::NeuralNetworkLayer>();
-  layer->set_name(layer_name);
-  return layer;
+  return status;
 }
-#endif
-
-// Operator support related
 
 bool BaseOpBuilder::IsOpSupported(const Node& node, const OpBuilderInputParams& input_params,
                                   const logging::Logger& logger) const {
-  if (!HasSupportedInputs(node, input_params, logger))
+  if (input_params.create_mlprogram && !SupportsMLProgram()) {
+    LOGS(logger, VERBOSE) << "Operator [" << node.OpType() << "] does not support MLProgram";
     return false;
+  }
 
-  // We do not support external initializers for now
-  const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors();
-  if (HasExternalInitializer(initializers, node, logger))
+  if (!HasSupportedOpSet(node, logger)) {
     return false;
+  }
 
-  if (!HasSupportedOpSet(node, logger))
+  if (!HasSupportedInputs(node, input_params, logger)) {
     return false;
+  }
+
+  // We do not support external initializers for now
+  const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors();
+  if (HasExternalInitializer(initializers, node, logger)) {
+    return false;
+  }
 
   return IsOpSupportedImpl(node, input_params, logger);
 }
 
 bool BaseOpBuilder::HasSupportedInputs(const Node& node, const OpBuilderInputParams& input_params,
                                        const logging::Logger& logger) const {
-  const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
   for (const auto* input : node.InputDefs()) {
-    if (!IsInputSupported(*input, node_name, input_params, logger)) {
+    if (!IsInputSupported(node, *input, input_params, logger)) {
       return false;
     }
   }
 
-  return HasSupportedInputsImpl(node, logger);
+  return HasSupportedInputsImpl(node, input_params, logger);
 }
 
-bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const {
-  // We only check the type of input 0 by default
-  // specific op builder can override this
-  const auto& input = *node.InputDefs()[0];
-
-  int32_t input_type;
-  if (!GetType(input, input_type, logger))
+/* static */
+bool BaseOpBuilder::IsInputFloat(const Node& node, size_t idx, const OpBuilderInputParams& /*input_params*/,
+                                 const logging::Logger& logger) {
+  if (idx >= node.InputDefs().size()) {
+    LOGS(logger, VERBOSE) << "Input index [" << idx << "] is out of range";
     return false;
+  }
+
+  const auto& input = *node.InputDefs()[idx];
 
-  if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
-    LOGS(logger, VERBOSE) << "[" << node.OpType()
-                          << "] Input type: [" << input_type
-                          << "] is not supported for now";
+  int32_t input_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
+
+  // currently only float is supported
+  if (!GetType(input, input_type, logger) || input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
+    LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not currently supported";
     return false;
   }
 
   return true;
 }
 
-bool BaseOpBuilder::HasSupportedOpSet(const Node& node,
-                                      const logging::Logger& logger) const {
+bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params,
+                                           const logging::Logger& logger) const {
+  // We only check the type of input 0 by default
+  // specific op builder can override this
+  return IsInputFloat(node, 0, input_params, logger);
+}
+
+bool BaseOpBuilder::HasSupportedOpSet(const Node& node, const logging::Logger& logger) const {
   auto since_version = node.SinceVersion();
   if (since_version < GetMinSupportedOpSet(node) || since_version > GetMaxSupportedOpSet(node)) {
     LOGS(logger, VERBOSE) << node.OpType() << "is only supported for opset ["
diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h
index b4132d3b770ec..63f0b813d654c 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h
+++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h
@@ -3,11 +3,9 @@
 
 #pragma once
 
-#include "core/providers/coreml/builders/op_builder.h"
-
-#ifdef __APPLE__
+#include "core/common/span_utils.h"
 #include "core/providers/coreml/builders/coreml_spec.h"
-#endif
+#include "core/providers/coreml/builders/op_builder.h"
 
 namespace onnxruntime {
 namespace coreml {
@@ -18,45 +16,40 @@ class BaseOpBuilder : public IOpBuilder {
  public:
   virtual ~BaseOpBuilder() = default;
 
-  // Add operator related
+  // does the operator implementation support creating an ML Program
+  bool SupportsMLProgram() const override { return false; }
+
+  bool IsOpSupported(const Node& node, const OpBuilderInputParams& input_params,
+                     const logging::Logger& logger) const override final;
 
-#ifdef __APPLE__
- public:
-  virtual void AddInitializersToSkip(ModelBuilder& /* model_builder */, const Node& /* node */) const override {}
   Status AddToModelBuilder(ModelBuilder& model_builder, const Node& node,
-                           const OpBuilderInputParams& input_params,
                            const logging::Logger& logger) const override final;
 
- protected:
-  virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
-                                       const logging::Logger& logger) const = 0;
-
-  static std::unique_ptr<COREML_SPEC::NeuralNetworkLayer>
-  CreateNNLayer(ModelBuilder& model_builder, const Node& node);
-
-  static std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> CreateNNLayer(const std::string& layer_name);
-#endif
-
-  // Operator support related
- public:
-  bool IsOpSupported(const Node& node, const OpBuilderInputParams& input_params,
-                     const logging::Logger& logger) const override final;
+  void AddInitializersToSkip(ModelBuilder& /*model_builder*/, const Node& /*node*/) const override {}
 
  protected:
-  virtual bool IsOpSupportedImpl(const Node& /* node */, const OpBuilderInputParams& /* input_params */,
-                                 const logging::Logger& /* logger */) const {
+  // currently we only support float
+  static bool IsInputFloat(const Node& node, size_t idx, const OpBuilderInputParams& input_params,
+                           const logging::Logger& logger);
+
+ private:
+  virtual bool IsOpSupportedImpl(const Node& /*node*/, const OpBuilderInputParams& /*input_params*/,
+                                 const logging::Logger& /*logger*/) const {
     return true;
   }
 
-  virtual bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const;
+  virtual bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params,
+                                      const logging::Logger& logger) const;
 
-  virtual int GetMinSupportedOpSet(const Node& /* node */) const { return 1; }
-  virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 20; }
+  virtual int GetMinSupportedOpSet(const Node& /*node*/) const { return 1; }
+  virtual int GetMaxSupportedOpSet(const Node& /*node*/) const { return 20; }
 
- private:
   bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const;
   bool HasSupportedInputs(const Node& node, const OpBuilderInputParams& input_params,
                           const logging::Logger& logger) const;
+
+  virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
+                                       const logging::Logger& logger) const = 0;
 };
 
 }  // namespace coreml
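The refactored BaseOpBuilder::IsOpSupported short-circuits through a fixed sequence of checks: ML Program support (only when an ML Program is requested), opset range, input types, external initializers, then the op-specific IsOpSupportedImpl. A minimal standalone sketch of that ordering, using hypothetical std::function predicates in place of the real Node/OpBuilderInputParams plumbing:

#include <functional>
#include <iostream>

// Sketch only: mirrors the short-circuit order in BaseOpBuilder::IsOpSupported.
// The predicate names are hypothetical stand-ins for the real checks.
struct SupportChecks {
  std::function<bool()> supports_mlprogram;
  std::function<bool()> has_supported_opset;
  std::function<bool()> has_supported_inputs;
  std::function<bool()> has_external_initializer;
  std::function<bool()> op_specific_check;
};

bool IsOpSupportedSketch(bool create_mlprogram, const SupportChecks& c) {
  if (create_mlprogram && !c.supports_mlprogram()) return false;  // ML Program requested but not implemented
  if (!c.has_supported_opset()) return false;
  if (!c.has_supported_inputs()) return false;
  if (c.has_external_initializer()) return false;  // external initializers are not supported
  return c.op_specific_check();
}

int main() {
  SupportChecks checks{
      [] { return false; },  // builder does not override SupportsMLProgram()
      [] { return true; },
      [] { return true; },
      [] { return false; },
      [] { return true; }};
  std::cout << IsOpSupportedSketch(/*create_mlprogram*/ true, checks) << "\n";   // 0
  std::cout << IsOpSupportedSketch(/*create_mlprogram*/ false, checks) << "\n";  // 1
}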
diff --git a/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc
index 391b02eaec497..8da58f659acf1 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc
@@ -5,30 +5,20 @@
 #include "core/providers/coreml/builders/helper.h"
 #include "core/providers/coreml/builders/impl/base_op_builder.h"
 #include "core/providers/coreml/builders/impl/builder_utils.h"
+#include "core/providers/coreml/builders/model_builder.h"
 #include "core/providers/coreml/builders/op_builder_factory.h"
 #include "core/providers/coreml/shape_utils.h"
 #include "core/providers/shared/utils/utils.h"
 
-#ifdef __APPLE__
-#include "core/providers/coreml/builders/model_builder.h"
-#endif
-
 namespace onnxruntime {
 namespace coreml {
 
 class BatchNormalizationOpBuilder : public BaseOpBuilder {
-  // Add operator related
-#ifdef __APPLE__
- public:
   void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
 
- private:
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override;
-#endif
 
-  // Operator support related
- private:
   bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                          const logging::Logger& logger) const override;
 
@@ -36,9 +26,6 @@ class BatchNormalizationOpBuilder : public BaseOpBuilder {
   int GetMinSupportedOpSet(const Node& /* node */) const override { return 7; }
 };
 
-// Add operator related
-
-#ifdef __APPLE__
 void BatchNormalizationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
   // skip everything except input0 for BatchNormalization
   const auto& input_defs = node.InputDefs();
@@ -48,10 +35,9 @@ void BatchNormalizationOpBuilder::AddInitializersToSkip(ModelBuilder& model_buil
   model_builder.AddInitializerToSkip(input_defs[4]->Name());  // var
 }
 
-Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
-                                                          const Node& node,
+Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                                           const logging::Logger& /* logger */) const {
-  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(model_builder, node);
+  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
 
   const auto& input_defs = node.InputDefs();
   const auto& initializers(model_builder.GetInitializerTensors());
@@ -81,9 +67,6 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu
   model_builder.AddLayer(std::move(layer));
   return Status::OK();
 }
-#endif
-
-// Operator support related
 
 bool BatchNormalizationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                                                     const logging::Logger& logger) const {
diff --git a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc
index 10c9b32d03f37..fb8e07633621f 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc
@@ -1,35 +1,31 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#include "core/framework/tensorprotoutils.h"
 #include "core/providers/common.h"
 #include "core/providers/coreml/builders/helper.h"
+#include "core/providers/coreml/builders/impl/base_op_builder.h"
+#include "core/providers/coreml/builders/impl/builder_utils.h"
+#include "core/providers/coreml/builders/model_builder.h"
 #include "core/providers/coreml/builders/op_builder_factory.h"
 #include "core/providers/shared/utils/utils.h"
-#ifdef __APPLE__
-#include "core/framework/tensorprotoutils.h"
-#include "core/providers/coreml/builders/model_builder.h"
-#endif
-
-#include "base_op_builder.h"
 
 namespace onnxruntime {
 namespace coreml {
-
 class BinaryOpBuilder : public BaseOpBuilder {
-  // Add operator related
- private:
-#ifdef __APPLE__
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override;
-#endif
-  // Operator support related
+
   int GetMinSupportedOpSet(const Node& node) const override;
 
-  bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const override;
+  bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params,
+                              const logging::Logger& logger) const override;
+
+  bool SupportsMLProgram() const override { return true; }
 };
 
-#ifdef __APPLE__
-static bool CheckIfBothInputShapesMatch(const Node& node, const logging::Logger& logger) {
+namespace {
+bool CheckIfBothInputShapesMatch(const Node& node, const logging::Logger& logger) {
   const auto& input_defs = node.InputDefs();
 
   const auto* x_shape_proto = input_defs[0]->Shape();
@@ -57,78 +53,94 @@ static bool CheckIfBothInputShapesMatch(const Node& node, const logging::Logger&
                     y_shape_proto->dim().begin(), y_shape_proto->dim().end(),
                     dim_eq);
 }
-
-// Add operator related
+}  // namespace
 
 Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                               const logging::Logger& logger) const {
   const auto& op_type(node.OpType());
   const auto& input_defs(node.InputDefs());
 
-  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(model_builder, node);
-
-  if (op_type == "Add") {
-    // original mutable_add() has limited broadcasting support
-    // updated to use CoreML::AddBroadcastableLayerParams which has more general broadcasting support
-    if (CheckIfBothInputShapesMatch(node, logger)) {
-      layer->mutable_add();
+#if defined(COREML_ENABLE_MLPROGRAM)
+  if (model_builder.CreateMLProgram()) {
+    using namespace CoreML::Specification::MILSpec;
+
+    // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.elementwise_binary
+    std::string_view coreml_op_type;
+    if (op_type == "Add") {
+      coreml_op_type = "add";
+    } else if (op_type == "Mul") {
+      coreml_op_type = "mul";
+    } else if (op_type == "Sub") {
+      coreml_op_type = "sub";
+    } else if (op_type == "Div") {
+      // we only support fp32 currently. when we add support for integers we need to check the type and use
+      // "floor_div" or "real_div" accordingly
+      coreml_op_type = "real_div";
+    } else if (op_type == "Pow") {
+      coreml_op_type = "pow";
     } else {
-      layer->mutable_addbroadcastable();
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "BinaryOpBuilder::AddToModelBuilderImpl, unexpected op: ", op_type);
     }
-  } else if (op_type == "Mul") {
-    if (CheckIfBothInputShapesMatch(node, logger)) {
-      layer->mutable_multiply();
+
+    std::unique_ptr<Operation> op = model_builder.CreateOperation(node, coreml_op_type);
+    AddOperationInput(*op, "x", input_defs[0]->Name());
+    AddOperationInput(*op, "y", input_defs[1]->Name());
+    AddOperationOutput(*op, *node.OutputDefs()[0]);
+
+    model_builder.AddOperation(std::move(op));
+  } else
+#endif  // defined (COREML_ENABLE_MLPROGRAM)
+  {
+    std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
+
+    if (op_type == "Add") {
+      // original mutable_add() has limited broadcasting support
+      // updated to use CoreML::AddBroadcastableLayerParams which has more general broadcasting support
+      if (CheckIfBothInputShapesMatch(node, logger)) {
+        layer->mutable_add();
+      } else {
+        layer->mutable_addbroadcastable();
+      }
+    } else if (op_type == "Mul") {
+      if (CheckIfBothInputShapesMatch(node, logger)) {
+        layer->mutable_multiply();
+      } else {
+        layer->mutable_multiplybroadcastable();
+      }
+    } else if (op_type == "Sub") {
+      layer->mutable_subtractbroadcastable();
+    } else if (op_type == "Div") {
+      layer->mutable_dividebroadcastable();
+    } else if (op_type == "Pow") {
+      layer->mutable_powbroadcastable();
     } else {
-      layer->mutable_multiplybroadcastable();
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "BinaryOpBuilder::AddToModelBuilderImpl, unexpected op: ", op_type);
     }
-  } else if (op_type == "Sub") {
-    layer->mutable_subtractbroadcastable();
-  } else if (op_type == "Div") {
-    layer->mutable_dividebroadcastable();
-  } else if (op_type == "Pow") {
-    layer->mutable_powbroadcastable();
-  } else {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "BinaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
-  }
 
-  *layer->mutable_input()->Add() = input_defs[0]->Name();
-  *layer->mutable_input()->Add() = input_defs[1]->Name();
-  *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();
+    *layer->mutable_input()->Add() = input_defs[0]->Name();
+    *layer->mutable_input()->Add() = input_defs[1]->Name();
+    *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();
+
+    model_builder.AddLayer(std::move(layer));
+  }
 
-  model_builder.AddLayer(std::move(layer));
   return Status::OK();
 }
-#endif
-
-// Operator support related
 
 int BinaryOpBuilder::GetMinSupportedOpSet(const Node& /* node */) const {
   // Add/Sub/Mul/Div opset 6- has broadcast attributes we do not support now
   return 7;
 }
 
-bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const {
-  bool is_pow = node.OpType() == "Pow";
-  if (!is_pow) {
-    return BaseOpBuilder::HasSupportedInputsImpl(node, logger);
-  }
-
-  const auto& input_1 = *node.InputDefs()[0];
-  const auto& input_2 = *node.InputDefs()[1];
-  // Pow we only support both inputs as fp32 for now
-  int32_t input_type_1;
-  if (!GetType(input_1, input_type_1, logger))
-    return false;
-
-  int32_t input_type_2;
-  if (!GetType(input_2, input_type_2, logger))
-    return false;
-
-  if (input_type_1 != ONNX_NAMESPACE::TensorProto_DataType_FLOAT || input_type_1 != input_type_2) {
-    LOGS(logger, VERBOSE) << "Pow only supports fp32 inputs, actual input type"
-                          << ", Input type 1: " << input_type_1
-                          << ", Input type 2: " << input_type_2;
+bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params,
+                                             const logging::Logger& logger) const {
+  // Add/Sub/Mul/Div spec says inputs must be of the same type.
+  // Pow spec says inputs can be different types.
+  // We only support float for all of these inputs.
+  if (!IsInputFloat(node, 0, input_params, logger) ||
+      ((node.OpType() == "Pow") && !IsInputFloat(node, 1, input_params, logger))) {
     return false;
   }
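When an ML Program is requested, BinaryOpBuilder maps each ONNX op type to a MIL elementwise op name, with Div lowered to "real_div" because only fp32 is currently supported. A standalone sketch of that mapping as a lookup table (the table form is illustrative; the builder itself uses an if/else chain):

#include <iostream>
#include <optional>
#include <string>
#include <unordered_map>

// Sketch of the ONNX -> MIL op-name mapping used by the ML Program path.
std::optional<std::string> ToMILBinaryOp(const std::string& onnx_op) {
  static const std::unordered_map<std::string, std::string> kMap = {
      {"Add", "add"},
      {"Mul", "mul"},
      {"Sub", "sub"},
      {"Div", "real_div"},  // fp32 only today; integer Div would need "floor_div"
      {"Pow", "pow"},
  };
  auto it = kMap.find(onnx_op);
  if (it == kMap.end()) return std::nullopt;  // mirrors the INVALID_ARGUMENT error path
  return it->second;
}

int main() {
  std::cout << ToMILBinaryOp("Div").value_or("<unsupported>") << "\n";  // real_div
  std::cout << ToMILBinaryOp("Mod").value_or("<unsupported>") << "\n";  // <unsupported>
}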
 
diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc
index ef66e6b877a1f..cbea969904ed5 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc
@@ -1,17 +1,17 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#ifdef __APPLE__
-
 #include "core/providers/coreml/builders/impl/builder_utils.h"
 
 #include "core/common/narrow.h"
 #include "core/framework/tensorprotoutils.h"
+#include "core/providers/coreml/builders/coreml_spec.h"
 #include "core/providers/coreml/builders/helper.h"
+#include "core/providers/coreml/builders/model_builder.h"
 #include "core/providers/shared/utils/utils.h"
 #include "core/optimizer/initializer.h"
 
-#include "coreml_proto/NeuralNetwork.pb.h"
+using namespace COREML_SPEC;
 
 namespace onnxruntime {
 namespace coreml {
@@ -133,7 +133,249 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span<c
   CreateCoreMLWeightConvertingDataToFloats(weight, data);
 }
 
+#if defined(COREML_ENABLE_MLPROGRAM)
+//
+// ML Program Utils
+//
+
+namespace {
+void SetTensorTypeInfo(MILSpec::TensorType& tensor_type, MILSpec::DataType data_type,
+                       std::optional<gsl::span<const int64_t>> shape) {
+  tensor_type.set_datatype(data_type);
+  if (shape) {
+    tensor_type.set_rank(shape->size());
+    for (const auto& dim : *shape) {
+      if (dim >= 0) {
+        tensor_type.add_dimensions()->mutable_constant()->set_size(narrow<int32_t>(dim));
+      } else {
+        tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false);
+      }
+    }
+  }
+}
+
+void SetTensorTypeInfo(MILSpec::TensorType& tensor_type, MILSpec::DataType data_type,
+                       const ONNX_NAMESPACE::TensorShapeProto* shape) {
+  tensor_type.set_datatype(data_type);
+  if (shape) {
+    tensor_type.set_rank(shape->dim_size());
+    for (const auto& dim : shape->dim()) {
+      if (dim.has_dim_value()) {
+        tensor_type.add_dimensions()->mutable_constant()->set_size(narrow<int32_t>(dim.dim_value()));
+      } else {
+        tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false);
+      }
+    }
+  }
+}
+
+template <typename T1, typename T2 = T1>
+void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span<const T1> data) {
+  // need a 'false' that is dependent on the template types to make gcc happy and give a meaningful error message.
+  static_assert(false_for_T<T1> && false_for_T<T2>, "Unsupported data type");  // add specializations below as needed
+}
+
+template <>
+void CopyDataToTensorValue<float>(MILSpec::TensorValue& tensor_value, gsl::span<const float> data) {
+  tensor_value.mutable_floats()->mutable_values()->Add(data.begin(), data.end());
+}
+
+template <>
+void CopyDataToTensorValue<int32_t>(MILSpec::TensorValue& tensor_value, gsl::span<const int32_t> data) {
+  tensor_value.mutable_ints()->mutable_values()->Add(data.begin(), data.end());
+}
+
+template <>
+void CopyDataToTensorValue<std::string>(MILSpec::TensorValue& tensor_value, gsl::span<const std::string> data) {
+  tensor_value.mutable_strings()->mutable_values()->Add(data.begin(), data.end());
+}
+
+// copy int64_t (used by ONNX for strides/indexes/etc.) to int32_t (used by CoreML)
+template <>
+void CopyDataToTensorValue<int64_t, int32_t>(MILSpec::TensorValue& tensor_value, gsl::span<const int64_t> data) {
+  auto& int32_out = *tensor_value.mutable_ints()->mutable_values();
+  int32_out.Reserve(narrow<int32_t>(data.size()));
+  for (const int64_t v : data) {
+    int32_out.AddAlreadyReserved(narrow<int32_t>(v));
+  }
+}
+
+template <>
+void CopyDataToTensorValue<bool>(MILSpec::TensorValue& tensor_value, gsl::span<const bool> data) {
+  tensor_value.mutable_bools()->mutable_values()->Add(data.begin(), data.end());
+}
+
+}  // namespace
+
+MILSpec::DataType OnnxDataTypeToMILSpec(int onnx_type) {
+  switch (static_cast<ONNX_NAMESPACE::TensorProto_DataType>(onnx_type)) {
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
+      return MILSpec::DataType::FLOAT32;
+    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
+      return MILSpec::DataType::FLOAT64;
+    case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
+      return MILSpec::DataType::BFLOAT16;
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
+      return MILSpec::DataType::FLOAT16;
+
+    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
+      return MILSpec::DataType::INT8;
+    case ONNX_NAMESPACE::TensorProto_DataType_INT16:
+      return MILSpec::DataType::INT16;
+    case ONNX_NAMESPACE::TensorProto_DataType_INT32:
+      return MILSpec::DataType::INT32;
+    case ONNX_NAMESPACE::TensorProto_DataType_INT64:
+      return MILSpec::DataType::INT64;
+
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
+      return MILSpec::DataType::UINT8;
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT16:
+      return MILSpec::DataType::UINT16;
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
+      return MILSpec::DataType::UINT32;
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
+      return MILSpec::DataType::UINT64;
+
+    case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
+      return MILSpec::DataType::BOOL;
+    case ONNX_NAMESPACE::TensorProto_DataType_STRING:
+      return MILSpec::DataType::STRING;
+    default:
+      ORT_THROW("Unsupported data type: ", onnx_type);
+  }
+}
+
+template <typename T1, typename T2>
+MILSpec::Value CreateTensorValue(const gsl::span<const T1> data,
+                                 std::optional<gsl::span<const int64_t>> shape) {
+  MILSpec::Value value;
+  MILSpec::TensorType& tensor_type = *value.mutable_type()->mutable_tensortype();
+
+  if (shape) {
+    SetTensorTypeInfo(tensor_type, DataTypeToMILSpec<T2>(), *shape);
+  } else {
+    // infer as 1D shape
+    std::vector<int64_t> coreml_shape{narrow<int64_t>(data.size())};
+    SetTensorTypeInfo(tensor_type, DataTypeToMILSpec<T2>(), coreml_shape);
+  }
+
+  MILSpec::TensorValue& tensor_value = *value.mutable_immediatevalue()->mutable_tensor();
+  CopyDataToTensorValue<T1, T2>(tensor_value, data);
+
+  return value;
+}
+
+template <typename T>
+MILSpec::Value CreateScalarTensorValue(const T& data) {
+  gsl::span<const T> data_span{&data, 1};
+  std::vector<int64_t> shape = {};  // empty for scalar
+  return CreateTensorValue<T>(data_span, shape);
+}
+
+// explicit specializations for types we handle so the implementation can be in the .cc file
+template MILSpec::Value CreateTensorValue<int64_t, int32_t>(gsl::span<const int64_t> data,
+                                                            std::optional<gsl::span<const int64_t>> shape);
+
+template MILSpec::Value CreateScalarTensorValue(const float& data);
+template MILSpec::Value CreateScalarTensorValue(const int32_t& data);
+template MILSpec::Value CreateScalarTensorValue(const std::string& data);
+template MILSpec::Value CreateScalarTensorValue(const bool& data);
+
+COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg) {
+  MILSpec::NamedValueType nvt;
+  nvt.set_name(node_arg.Name());
+  MILSpec::TensorType& tensor_type = *nvt.mutable_type()->mutable_tensortype();
+
+  SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(node_arg.TypeAsProto()->tensor_type().elem_type()),
+                    node_arg.Shape());
+
+  return nvt;
+}
+
+void AddOperationInput(MILSpec::Operation& op, std::string_view input_name, std::string_view value_name) {
+  MILSpec::Argument arg;
+  arg.mutable_arguments()->Add()->set_name(std::string(value_name));
+
+  (*op.mutable_inputs())[input_name] = std::move(arg);
+}
+
+void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& output) {
+  auto& outputs = *op.mutable_outputs();
+  auto& output_arg = *outputs.Add();
+  output_arg.set_name(output.Name());
+
+  MILSpec::ValueType& value = *output_arg.mutable_type();
+  MILSpec::TensorType& tensor_type = *value.mutable_tensortype();
+
+  SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(output.TypeAsProto()->tensor_type().elem_type()),
+                    output.Shape());
+}
+
+void AddPadTypeAndPads(COREML_SPEC::MILSpec::Operation& op, ModelBuilder& model_builder, std::string_view op_type,
+                       const NodeAttrHelper& helper, int num_spatial_dims) {
+  AutoPadType auto_pad_type = StringToAutoPadType(helper.Get("auto_pad", "NOTSET"));
+
+  // pad type (string)
+  //   valid - no pads  (ONNX auto_pad VALID)
+  //   custom - pads input  (ONNX NOTSET)
+  //   same - inferred to be `d_out[i] = ceil(d_in[i] / strides[i])`  (assuming == ONNX SAME_UPPER)
+  //   same_lower - as per same but any extra rows/cols are added at top/left if padding is odd (ONNX SAME_LOWER)
+  //
+  // TODO: See if we want to update HandleAutoPad to support 1D (and 3D) so we can infer if an autopad value
+  //       can be used. TBD if that provides any performance benefit with ML Program though as CoreML could
+  //       potentially do that same optimization internally.
+  switch (auto_pad_type) {
+    case AutoPadType::NOTSET: {
+      // use `pads` attribute.
+      auto onnx_pads = helper.GetInt64s("pads");  // 'pads' are used if auto_pad is NOTSET
+      if (onnx_pads) {
+        AddOperationInput(op, "pad_type",
+                          model_builder.AddScalarConstant(op_type, "pad_type", std::string("custom")));
+
+        // need to re-order from x1_start, x2_start..., x1_end, x2_end... to
+        // x1_start, x1_end, x2_start, x2_end,...
+        size_t num_pads = onnx_pads->size();
+        size_t num_dims = num_pads / 2;
+        std::vector<int64_t> reordered_pads(num_pads, 0);
+        for (size_t i = 0; i < num_pads; ++i) {
+          auto cur_dim = i % num_dims;
+          if (i < num_dims) {  // start values
+            reordered_pads[cur_dim * 2] = (*onnx_pads)[i];
+          } else {  // end values
+            reordered_pads[cur_dim * 2 + 1] = (*onnx_pads)[i];
+          }
+        }
+
+        AddOperationInput(op, "pad", model_builder.AddConstant(op_type, "pad", reordered_pads));
+
+        break;
+      }
+
+      // fall through if explicit pads were not provided as the default value for `pads` is all zeros,
+      // which is the same as 'valid' padding.
+      [[fallthrough]];
+    }
+    case AutoPadType::VALID:
+      AddOperationInput(op, "pad_type",
+                        model_builder.AddScalarConstant(op_type, "pad_type", std::string("valid")));
+
+      break;
+    case AutoPadType::SAME_UPPER:
+    case AutoPadType::SAME_LOWER: {
+      const auto pad_type = (auto_pad_type == AutoPadType::SAME_UPPER ? "same" : "same_lower");
+      AddOperationInput(op, "pad_type",
+                        model_builder.AddScalarConstant(op_type, "pad_type", std::string(pad_type)));
+
+      // despite what the spec says, a 'pad' input seems to be required.
+      // https://github.com/apple/coremltools/issues/2127
+      // Provide the default value as that's what coremltools does for conv/avg_pool/max_pool.
+      std::vector<int64_t> ignored_pads(num_spatial_dims * 2, 0);
+      AddOperationInput(op, "pad", model_builder.AddConstant(op_type, "pad", ignored_pads));
+
+      break;
+    }
+  }
+}
+#endif  // defined(COREML_ENABLE_MLPROGRAM)
 }  // namespace coreml
 }  // namespace onnxruntime
-
-#endif
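AddPadTypeAndPads reorders the ONNX `pads` attribute, which lists all begin values followed by all end values, into CoreML's per-dimension (begin, end) pairs. A self-contained sketch of just that reordering loop, with a 2-D usage check:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Reorder ONNX pads (all begins, then all ends) into CoreML's per-dimension
// (begin, end) pairs. Mirrors the loop in AddPadTypeAndPads.
std::vector<int64_t> ReorderPads(const std::vector<int64_t>& onnx_pads) {
  const size_t num_pads = onnx_pads.size();
  const size_t num_dims = num_pads / 2;
  std::vector<int64_t> reordered(num_pads, 0);
  for (size_t i = 0; i < num_pads; ++i) {
    const size_t cur_dim = i % num_dims;
    if (i < num_dims) {
      reordered[cur_dim * 2] = onnx_pads[i];      // begin value
    } else {
      reordered[cur_dim * 2 + 1] = onnx_pads[i];  // end value
    }
  }
  return reordered;
}

int main() {
  // ONNX: {h_begin, w_begin, h_end, w_end} -> CoreML: {h_begin, h_end, w_begin, w_end}
  assert((ReorderPads({1, 2, 3, 4}) == std::vector<int64_t>{1, 3, 2, 4}));
  return 0;
}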
diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h
index 23b11928f7dc2..2804589065631 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h
+++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h
@@ -5,21 +5,20 @@
 
 #pragma once
 
-#ifdef __APPLE__
+#include <optional>
 
 #include "core/common/gsl.h"
 #include "core/common/status.h"
 #include "core/graph/basic_types.h"
 #include "core/providers/common.h"
-
-namespace CoreML {
-namespace Specification {
-class WeightParams;
-}
-}  // namespace CoreML
+#include "core/providers/coreml/builders/coreml_spec.h"
+#include "core/providers/shared/utils/utils.h"
 
 namespace onnxruntime {
+class NodeArg;
+
 namespace coreml {
+class ModelBuilder;
 
 // Try to see if we can map explicit padding to auto padding for Conv/Pool
 // Since usually use auto padding is more efficient
@@ -32,6 +31,10 @@ Status HandleAutoPad(const std::vector<int64_t> input_shape,
                      AutoPadType auto_pad_type,
                      AutoPadType& auto_pad_type_out);
 
+//
+// NeuralNetwork utils
+//
+
 // Copy an onnx initializer data to a coreml weight
 Status CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, const ONNX_NAMESPACE::TensorProto& tensor);
 
@@ -44,7 +47,103 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span<c
 // Copy the int64_t array to a coreml weight
 void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span<const int64_t> data);
 
+#if defined(COREML_ENABLE_MLPROGRAM)
+//
+// MLProgram utils
+//
+
+// helper for static_assert where the value needs to be dependent on a template parameter
+template <typename>
+constexpr bool false_for_T = false;
+
+template <typename T>
+COREML_SPEC::MILSpec::DataType DataTypeToMILSpec() {
+  if constexpr (std::is_same_v<T, float>) {
+    return COREML_SPEC::MILSpec::DataType::FLOAT32;
+  } else if constexpr (std::is_same_v<T, double>) {
+    return COREML_SPEC::MILSpec::DataType::FLOAT64;
+  } else if constexpr (std::is_same_v<T, BFloat16>) {
+    return COREML_SPEC::MILSpec::DataType::BFLOAT16;
+  } else if constexpr (std::is_same_v<T, MLFloat16>) {
+    return COREML_SPEC::MILSpec::DataType::FLOAT16;
+
+  } else if constexpr (std::is_same_v<T, int8_t>) {
+    return COREML_SPEC::MILSpec::DataType::INT8;
+  } else if constexpr (std::is_same_v<T, int16_t>) {
+    return COREML_SPEC::MILSpec::DataType::INT16;
+  } else if constexpr (std::is_same_v<T, int32_t>) {
+    return COREML_SPEC::MILSpec::DataType::INT32;
+  } else if constexpr (std::is_same_v<T, int64_t>) {
+    return COREML_SPEC::MILSpec::DataType::INT64;
+
+  } else if constexpr (std::is_same_v<T, uint8_t>) {
+    return COREML_SPEC::MILSpec::DataType::UINT8;
+  } else if constexpr (std::is_same_v<T, uint16_t>) {
+    return COREML_SPEC::MILSpec::DataType::UINT16;
+  } else if constexpr (std::is_same_v<T, uint32_t>) {
+    return COREML_SPEC::MILSpec::DataType::UINT32;
+  } else if constexpr (std::is_same_v<T, uint64_t>) {
+    return COREML_SPEC::MILSpec::DataType::UINT64;
+
+  } else if constexpr (std::is_same_v<T, bool>) {
+    return COREML_SPEC::MILSpec::DataType::BOOL;
+  } else if constexpr (std::is_same_v<T, std::string>) {
+    return COREML_SPEC::MILSpec::DataType::STRING;
+  } else {
+    static_assert(false_for_T<T>, "Unsupported type.");
+  }
+}
+
+// The TensorProto.data_type field is an int, but must be a valid TensorProto_DataType value.
+// Use int for the arg so the caller can pass TensorProto.data_type() value and do the cast to enum internally
+COREML_SPEC::MILSpec::DataType OnnxDataTypeToMILSpec(int onnx_type);
+
+/// <summary>
+/// Create a CoreML MILSpec::TensorValue for the given input data.
+/// </summary>
+/// <typeparam name="T1">Original C++ data type</typeparam>
+/// <typeparam name="T2">CoreML C++ data type</typeparam>
+/// <param name="data">ONNX data</param>
+/// <param name="shape">ONNX data shape. Inferred to be a 1D shape of `{data.size()}` if not specified.</param>
+/// <returns>TensorValue containing data.</returns>
+template <typename T1, typename T2 = T1>
+COREML_SPEC::MILSpec::Value CreateTensorValue(gsl::span<const T1> data,
+                                              std::optional<gsl::span<const int64_t>> shape = std::nullopt);
+
+template <typename T>
+COREML_SPEC::MILSpec::Value CreateScalarTensorValue(const T& data);
+
+/// <summary>Create a NamedValueType from an ONNX tensor NodeArg.</summary>
+/// <remarks>Used to create inputs for the 'main' function in an ML Program.</remarks>
+COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg);
+
+/// <summary>
+/// Add an input argument to a MILSpec::Operation
+/// </summary>
+/// <param name="op">Operation to update.</param>
+/// <param name="input_name">The input name defined by the spec for the operation.</param>
+/// <param name="value_name">The name of the value that is providing the input.</param>
+/// <see>"https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html"</see>
+void AddOperationInput(COREML_SPEC::MILSpec::Operation& op,
+                       std::string_view input_name, std::string_view value_name);
+
+/// <summary>
+/// Add an output to a MILSpec::Operation. Name, data type and shape are used from the NodeArg.
+/// </summary>
+/// <param name="op">Operation to update.</param>
+/// <param name="output">NodeArg with details of output to add.</param>
+void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& output);
+
+/// <summary>
+/// Add pad_type and pad values.
+/// </summary>
+/// <param name="op">Operator to update</param>
+/// <param name="model_builder">ModelBuilder to add constants with.</param>
+/// <param name="op_type">Operator type.</param>
+/// <param name="helper">Node attribute helper.</param>
+/// <param name="num_spatial_dims">Number of spatial dims in input. Generally rank - 2 (ignore N and C dims).</param>
+void AddPadTypeAndPads(COREML_SPEC::MILSpec::Operation& op, ModelBuilder& model_builder, std::string_view op_type,
+                       const NodeAttrHelper& helper, int num_spatial_dims);
+#endif  // defined(COREML_ENABLE_MLPROGRAM)
 }  // namespace coreml
 }  // namespace onnxruntime
-
-#endif
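DataTypeToMILSpec relies on an `if constexpr` chain whose final branch is a static_assert on the dependent-false helper `false_for_T<T>`, so instantiating the template with an unmapped type fails at compile time with a readable message instead of firing unconditionally. A minimal standalone illustration of the pattern (the enum below is a stand-in, not the real MILSpec::DataType):

#include <cstdint>
#include <string>
#include <type_traits>

// Stand-in enum for illustration only.
enum class DataType { FLOAT32, INT32, STRING };

// The static_assert must depend on T, otherwise it fires even for branches
// that are never instantiated; false_for_T<T> provides that dependency.
template <typename>
constexpr bool false_for_T = false;

template <typename T>
constexpr DataType ToDataType() {
  if constexpr (std::is_same_v<T, float>) {
    return DataType::FLOAT32;
  } else if constexpr (std::is_same_v<T, int32_t>) {
    return DataType::INT32;
  } else if constexpr (std::is_same_v<T, std::string>) {
    return DataType::STRING;
  } else {
    static_assert(false_for_T<T>, "Unsupported type.");
  }
}

static_assert(ToDataType<float>() == DataType::FLOAT32);
static_assert(ToDataType<int32_t>() == DataType::INT32);
// ToDataType<double>() would fail to compile with "Unsupported type."

int main() { return 0; }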
diff --git a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc
index 15ee1f0fc7284..70053c2c606a0 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc
@@ -1,34 +1,25 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#include "core/providers/shared/utils/utils.h"
 #include "core/providers/coreml/builders/helper.h"
-#ifdef __APPLE__
+#include "core/providers/coreml/builders/impl/base_op_builder.h"
 #include "core/providers/coreml/builders/model_builder.h"
-#endif
 #include "core/providers/coreml/builders/op_builder_factory.h"
-
-#include "base_op_builder.h"
+#include "core/providers/shared/utils/utils.h"
 
 namespace onnxruntime {
 namespace coreml {
 
 class CastOpBuilder : public BaseOpBuilder {
-  // Add operator related
- private:
-#ifdef __APPLE__
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override;
-#endif
-  // Operator support related
   bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                          const logging::Logger& logger) const override;
-  bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const override;
-};
 
-// Add operator related
+  bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params,
+                              const logging::Logger& logger) const override;
+};
 
-#ifdef __APPLE__
 Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& /* model_builder */,
                                             const Node& /* node */,
                                             const logging::Logger& /* logger */) const {
@@ -37,9 +28,6 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& /* model_builder */,
   // Cast node is not provided in CoreML model, so we're skipping adding the Cast node here.
   return Status::OK();
 }
-#endif
-
-// Operator support related
 
 bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                                       const logging::Logger& logger) const {
@@ -84,7 +72,8 @@ bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara
   return true;
 }
 
-bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const {
+bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& /*input_params*/,
+                                           const logging::Logger& logger) const {
   // We only check the type of input 0
   const auto& input = *node.InputDefs()[0];
 
diff --git a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc
index a298a8d12c741..41f4041ef1181 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc
@@ -1,40 +1,48 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#ifdef __APPLE__
+#include "core/providers/coreml/builders/impl/base_op_builder.h"
+#include "core/providers/coreml/builders/impl/builder_utils.h"
 #include "core/providers/coreml/builders/model_builder.h"
-#endif
 #include "core/providers/coreml/builders/op_builder_factory.h"
 #include "core/providers/shared/utils/utils.h"
 
-#include "base_op_builder.h"
-
 namespace onnxruntime {
 namespace coreml {
 
 class ClipOpBuilder : public BaseOpBuilder {
-  // Add operator related
-#ifdef __APPLE__
- public:
   void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
 
- private:
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override;
-#endif
 
-  // Operator support related
- private:
   bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                          const logging::Logger& logger) const override;
-};
 
-// Add operator related
+  bool SupportsMLProgram() const override { return true; }
+};
 
-#ifdef __APPLE__
 void ClipOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
+  bool skip = true;
+
+  if (model_builder.CreateMLProgram()) {
+    float min, max;
+    ORT_IGNORE_RETURN_VALUE(GetClipMinMax(model_builder.GetGraphViewer(), node, min, max, model_builder.Logger()));
+
+    bool has_min = min != std::numeric_limits<float>::lowest();
+    bool has_max = max != std::numeric_limits<float>::max();
+    if (has_min && has_max && min == 0.f && max == 6.f) {
+      // relu6 - skip both
+    } else if (has_min && min == 0.f && !has_max) {
+      // relu - skip both
+    } else {
+      // clip - we will use both
+      skip = false;
+    }
+  }
+
   // Both min and max values will be injected into the layer, no need to add to the model
-  if (node.SinceVersion() >= 11) {
+  if (skip && node.SinceVersion() >= 11) {
     if (node.InputDefs().size() > 1)
       model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name());
 
@@ -48,86 +56,132 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
                                             const logging::Logger& logger) const {
   const auto& node_name = node.Name();
   const auto& input_name = node.InputDefs()[0]->Name();
-  const auto& output_name = node.OutputDefs()[0]->Name();
+  const auto& output = *node.OutputDefs()[0];
+  const auto& output_name = output.Name();
   float min, max;
   ORT_RETURN_IF_NOT(GetClipMinMax(model_builder.GetGraphViewer(), node, min, max, logger), "GetClipMinMax failed");
 
   bool has_min = min != std::numeric_limits<float>::lowest();
   bool has_max = max != std::numeric_limits<float>::max();
 
-  if (!has_min && !has_max) {
-    // Clip without min/max is an identity node
-    // In CoreML we don't have identity, use ActivationLinear instead
-    std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(model_builder, node);
-    layer->mutable_activation()->mutable_linear()->set_alpha(1.0f);
-    *layer->mutable_input()->Add() = input_name;
-    *layer->mutable_output()->Add() = output_name;
-
-    model_builder.AddLayer(std::move(layer));
-  } else {
-    // The implementation of clip(min, max) is done by
-    // 1. Clipping at min -> max(input, min) is handled by
-    //    min_output = threshold(input, min)
-    // 2. Clipping at max -> min(min_output, max) is handled by
-    //    output = -1 * (threshold(-min_output, -max))
-
-    // Now we have at least one or min or max is not default value
-    // Clipping at max will need take the output of clipping at min, or the node input, if min value is default
-    // If max value is default, the output of clipping at min will be the output of the node
-    std::string min_output_name = output_name;
-    if (has_max) {
-      min_output_name = has_min
-                            ? model_builder.GetUniqueName(node_name + "min_output")
-                            : input_name;
+#if defined(COREML_ENABLE_MLPROGRAM)
+  if (model_builder.CreateMLProgram()) {
+    using namespace CoreML::Specification::MILSpec;
+
+    std::unique_ptr<Operation> op;
+    if (!has_min && !has_max) {
+      // Clip without min/max is an identity node.
+      op = model_builder.CreateOperation(node, "identity");
+      Operation& identity_op = *op;
+      AddOperationInput(identity_op, "x", input_name);
+    } else {
+      if (has_min && has_max && min == 0.f && max == 6.f) {
+        // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.activation.relu6
+        op = model_builder.CreateOperation(node, "relu6");
+        Operation& relu6_op = *op;
+        AddOperationInput(relu6_op, "x", input_name);
+      } else if (has_min && min == 0.f && !has_max) {
+        // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.activation.relu
+        op = model_builder.CreateOperation(node, "relu");
+        Operation& relu_op = *op;
+        AddOperationInput(relu_op, "x", input_name);
+      } else {
+        // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.elementwise_unary.clip
+        op = model_builder.CreateOperation(node, "clip");
+
+        Operation& clip_op = *op;
+        AddOperationInput(clip_op, "x", input_name);
+
+        // If min and max were attributes we need to add initializers; otherwise we use the existing inputs.
+        const bool min_max_attribs = node.SinceVersion() < 11;
+        std::string_view min_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "min", min)
+                                                    : node.InputDefs()[1]->Name();
+
+        AddOperationInput(clip_op, "alpha", min_name);
+
+        if (has_max) {
+          std::string_view max_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "max", max)
+                                                      : node.InputDefs()[2]->Name();
+          AddOperationInput(clip_op, "beta", max_name);
+        }
+      }
     }
 
-    // Handle clipping at min first
-    if (has_min) {
-      const auto clip_min_layer_name = model_builder.GetUniqueName(MakeString(node_name, "_Clip_min"));
-      std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> min_layer = CreateNNLayer(clip_min_layer_name);
-      if (min == 0.0f) {  // If min is 0. then this min will be handled by relu
-        min_layer->mutable_activation()->mutable_relu();
-      } else {  // otherwise, min will be handled by unary->threshold
-        min_layer->mutable_unary()->set_alpha(min);
-        min_layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::THRESHOLD);
+    AddOperationOutput(*op, output);
+    model_builder.AddOperation(std::move(op));
+  } else
+#endif  // defined(COREML_ENABLE_MLPROGRAM)
+  {
+    // TODO: CoreML has a Clip layer for NeuralNetwork. Added in CoreML 4. We could potentially use that if available
+    // to simplify.
+    // https://apple.github.io/coremltools/mlmodel/Format/NeuralNetwork.html#cliplayerparams
+
+    if (!has_min && !has_max) {
+      // Clip without min/max is an identity node
+      // In CoreML we don't have identity, use ActivationLinear instead
+      std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
+      layer->mutable_activation()->mutable_linear()->set_alpha(1.0f);
+      *layer->mutable_input()->Add() = input_name;
+      *layer->mutable_output()->Add() = output_name;
+
+      model_builder.AddLayer(std::move(layer));
+    } else {
+      // The implementation of clip(min, max) is done by
+      // 1. Clipping at min -> max(input, min) is handled by
+      //    min_output = threshold(input, min)
+      // 2. Clipping at max -> min(min_output, max) is handled by
+      //    output = -1 * (threshold(-min_output, -max))
+
+      // At this point at least one of min or max is not the default value.
+      // Clipping at max will need to take the output of clipping at min, or the node input if min is the default value.
+      // If max is the default value, the output of clipping at min will be the output of the node.
+      std::string min_output_name = output_name;
+      if (has_max) {
+        min_output_name = has_min
+                              ? model_builder.GetUniqueName(node_name + "min_output")
+                              : input_name;
       }
 
-      *min_layer->mutable_input()->Add() = input_name;
-      *min_layer->mutable_output()->Add() = min_output_name;
-      model_builder.AddLayer(std::move(min_layer));
-    }
-
-    // Clipping at max is handled by -1 * (threshold (-min_output, -max))
-    if (has_max) {
-      const auto threshold_output_name = model_builder.GetUniqueName(MakeString(node_name, "threshold_output"));
-      {  // Add threshold layer, which is actually max( -1 * min_output, -max)
-        const auto clip_max_threshold_layer_name =
-            model_builder.GetUniqueName(MakeString(node_name, "_Clip_max_threshold"));
-        auto threshold_layer = CreateNNLayer(clip_max_threshold_layer_name);
-        threshold_layer->mutable_unary()->set_alpha(-max);
-        threshold_layer->mutable_unary()->set_scale(-1.0f);
-        threshold_layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::THRESHOLD);
-        *threshold_layer->mutable_input()->Add() = min_output_name;
-        *threshold_layer->mutable_output()->Add() = threshold_output_name;
-        model_builder.AddLayer(std::move(threshold_layer));
+      // Handle clipping at min first
+      if (has_min) {
+        std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> min_layer = model_builder.CreateNNLayer(node, "_Clip_min");
+        if (min == 0.0f) {  // If min is 0, it is handled by relu
+          min_layer->mutable_activation()->mutable_relu();
+        } else {  // otherwise, min will be handled by unary->threshold
+          min_layer->mutable_unary()->set_alpha(min);
+          min_layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::THRESHOLD);
+        }
+
+        *min_layer->mutable_input()->Add() = input_name;
+        *min_layer->mutable_output()->Add() = min_output_name;
+        model_builder.AddLayer(std::move(min_layer));
       }
-      {  // Add linear activation layer -1 * threshold_output
-        const auto clip_max_linear_layer_name =
-            model_builder.GetUniqueName(MakeString(node_name, "_Clip_max_linear"));
-        auto linear_layer = CreateNNLayer(clip_max_linear_layer_name);
-        linear_layer->mutable_activation()->mutable_linear()->set_alpha(-1.0f);
-        *linear_layer->mutable_input()->Add() = threshold_output_name;
-        *linear_layer->mutable_output()->Add() = output_name;
-        model_builder.AddLayer(std::move(linear_layer));
+
+      // Clipping at max is handled by -1 * (threshold (-min_output, -max))
+      if (has_max) {
+        const auto threshold_output_name = model_builder.GetUniqueName(MakeString(node_name, "threshold_output"));
+        {  // Add threshold layer, which is actually max( -1 * min_output, -max)
+          auto threshold_layer = model_builder.CreateNNLayer(node, "_Clip_max_threshold");
+          threshold_layer->mutable_unary()->set_alpha(-max);
+          threshold_layer->mutable_unary()->set_scale(-1.0f);
+          threshold_layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::THRESHOLD);
+          *threshold_layer->mutable_input()->Add() = min_output_name;
+          *threshold_layer->mutable_output()->Add() = threshold_output_name;
+          model_builder.AddLayer(std::move(threshold_layer));
+        }
+        {  // Add linear activation layer -1 * threshold_output
+          auto linear_layer = model_builder.CreateNNLayer(node, "_Clip_max_linear");
+          linear_layer->mutable_activation()->mutable_linear()->set_alpha(-1.0f);
+          *linear_layer->mutable_input()->Add() = threshold_output_name;
+          *linear_layer->mutable_output()->Add() = output_name;
+          model_builder.AddLayer(std::move(linear_layer));
+        }
       }
     }
   }
 
   return Status::OK();
 }
-#endif
-
-// Operator support related
 
 bool ClipOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                                       const logging::Logger& logger) const {
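Note on the NeuralNetwork Clip path above: the two-step decomposition relies on the identity min(x, max) = -max(-x, -max). Below is a minimal standalone sketch of that arithmetic, assuming the CoreML THRESHOLD unary behaves as max(scale * x, alpha) as the layers above use it; the local threshold() helper is purely illustrative and not an ONNX Runtime or CoreML API.

    #include <algorithm>
    #include <cassert>
    #include <initializer_list>

    // Illustrative stand-in for the THRESHOLD unary as used above: max(scale * x, alpha).
    static float threshold(float x, float alpha, float scale = 1.0f) {
      return std::max(scale * x, alpha);
    }

    int main() {
      const float min = -1.0f, max = 3.0f;
      for (float x : {-5.0f, 0.0f, 2.5f, 7.0f}) {
        const float min_output = threshold(x, min);                        // step 1: max(x, min)
        const float output = -1.0f * threshold(min_output, -max, -1.0f);   // step 2: -max(-min_output, -max)
        assert(output == std::clamp(x, min, max));                         // matches clip(x, min, max)
      }
      return 0;
    }

If only min is set and it equals 0, the same effect comes from a plain ReLU, which is why the builder special-cases that path.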
diff --git a/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc
index b1e761024f5c9..34193318a0264 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc
@@ -4,37 +4,26 @@
 #include "core/providers/common.h"
 #include "core/providers/coreml/builders/helper.h"
 #include "core/providers/coreml/builders/impl/base_op_builder.h"
+#include "core/providers/coreml/builders/model_builder.h"
 #include "core/providers/coreml/builders/op_builder_factory.h"
 #include "core/providers/coreml/shape_utils.h"
 #include "core/providers/shared/utils/utils.h"
 
-#ifdef __APPLE__
-#include "core/providers/coreml/builders/model_builder.h"
-#endif
-
 namespace onnxruntime {
 namespace coreml {
 
 class ConcatOpBuilder : public BaseOpBuilder {
-  // Add operator related
- private:
-#ifdef __APPLE__
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override;
-#endif
 
-  // Operator support related
   bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                          const logging::Logger& logger) const override;
 };
 
-// Add operator related
-
-#ifdef __APPLE__
 Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
                                               const Node& node,
                                               const logging::Logger& logger) const {
-  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(model_builder, node);
+  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
 
   layer->mutable_concat()->set_sequenceconcat(false);
 
@@ -48,9 +37,7 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   model_builder.AddLayer(std::move(layer));
   return Status::OK();
 }
-#endif
 
-// Operator support related
 bool ConcatOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /* input_params */,
                                         const logging::Logger& logger) const {
   const auto& input_defs = node.InputDefs();
diff --git a/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc
index ff9dcbd9f8874..38125957bf481 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc
@@ -4,39 +4,35 @@
 #include "core/providers/common.h"
 #include "core/providers/coreml/builders/helper.h"
 #include "core/providers/coreml/builders/impl/base_op_builder.h"
-#include "core/providers/coreml/builders/op_builder_factory.h"
-#include "core/providers/shared/utils/utils.h"
-
-#ifdef __APPLE__
 #include "core/providers/coreml/builders/impl/builder_utils.h"
 #include "core/providers/coreml/builders/model_builder.h"
+#include "core/providers/coreml/builders/op_builder_factory.h"
 #include "core/providers/coreml/shape_utils.h"
-#endif
+#include "core/providers/shared/utils/utils.h"
+
+using namespace CoreML::Specification;
 
 namespace onnxruntime {
 namespace coreml {
 
 class ConvOpBuilder : public BaseOpBuilder {
-  // Add operator related
-#ifdef __APPLE__
- public:
   void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
 
- private:
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override;
-#endif
 
-  // Operator support related
- private:
   bool IsOpSupportedImpl(const Node& /* node */, const OpBuilderInputParams& /* input_params */,
                          const logging::Logger& /* logger */) const override;
-};
 
-// Add operator related
+  bool SupportsMLProgram() const override { return true; }
+};
 
-#ifdef __APPLE__
 void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
+  if (model_builder.CreateMLProgram()) {
+    // we add the initializers as 'const' operations via ModelBuilder::RegisterInitializers
+    return;
+  }
+
   const auto& input_defs = node.InputDefs();
 
   // skip the weight and bias (if has it) for conv as we will directly set those as part of the NN layer
@@ -49,136 +45,177 @@ void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Nod
 
 Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                             const logging::Logger& logger) const {
-  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(model_builder, node);
-
   const auto& input_defs = node.InputDefs();
   const auto& output_defs = node.OutputDefs();
   const auto& input_name = input_defs[0]->Name();
   const auto& output_name = output_defs[0]->Name();
 
-  const auto& weight_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name());
-  std::vector<int64_t> weight_shape = {weight_tensor.dims().cbegin(), weight_tensor.dims().cend()};
+  NodeAttrHelper helper(node);
 
-  const bool is_1d_conv = (weight_shape.size() == 3);
+#if defined(COREML_ENABLE_MLPROGRAM)
+  if (model_builder.CreateMLProgram()) {
+    using namespace CoreML::Specification::MILSpec;
 
-  if (is_1d_conv) {
-    // weight_shape needs to be expanded from MXCXH->MXCXHx1
-    weight_shape.push_back(1);
-  }
+    // https://github.com/apple/coremltools/blob/7.1/coremltools/converters/mil/mil/ops/defs/iOS15/conv.py
 
-  NodeAttrHelper helper(node);
-  auto strides = helper.Get("strides", std::vector<int64_t>{1, 1});
-  auto dilations = helper.Get("dilations", std::vector<int64_t>{1, 1});
-  auto onnx_pads = helper.Get("pads", std::vector<int64_t>{0, 0, 0, 0});
-  // Strides/dilations for 1d conv is normally of length 1. Expand them by 1
-  // to meet the required length 2 (for 2d conv it's normally 2)
-  // Similarly 1d conv normally has a length 2 padding. Expand it to length 4 by adding additional zeros.
-  if (is_1d_conv) {
-    if (strides.size() < 2) {
-      ORT_RETURN_IF_NOT(strides.size() == 1, "strides size does not equal 1 for Conv 1d");
-      strides.push_back(1);
+    std::unique_ptr<Operation> conv_op = model_builder.CreateOperation(node, "conv");
+
+    AddOperationInput(*conv_op, "x", input_name);
+    AddOperationInput(*conv_op, "weight", input_defs[1]->Name());
+
+    if (input_defs.size() > 2) {
+      AddOperationInput(*conv_op, "bias", input_defs[2]->Name());
     }
-    if (dilations.size() < 2) {
-      ORT_RETURN_IF_NOT(dilations.size() == 1, "dilations size does not equal 1 for Conv 1d");
-      dilations.push_back(1);
+
+    // we know the weight has a valid shape due to the check in IsOpSupportedImpl. ignore the N and C dims.
+    const auto num_spatial_dims = input_defs[1]->Shape()->dim_size() - 2;
+    const auto& op_type = conv_op->type();
+
+    // The spec says strides and dilations are optional, but in reality they're required for at least the iOS15
+    // target (CoreML 5).
+    const auto strides = helper.Get("strides", std::vector<int64_t>(num_spatial_dims, 1));
+    auto dilations = helper.Get("dilations", std::vector<int64_t>(num_spatial_dims, 1));
+    auto groups = helper.GetInt64("group");
+
+    AddOperationInput(*conv_op, "strides", model_builder.AddConstant(op_type, "strides", strides));
+    AddOperationInput(*conv_op, "dilations", model_builder.AddConstant(op_type, "dilations", dilations));
+
+    if (groups) {
+      AddOperationInput(*conv_op, "groups", model_builder.AddScalarConstant(op_type, "groups", *groups));
     }
-    if (onnx_pads.size() < 4) {
-      ORT_RETURN_IF_NOT(onnx_pads.size() == 2, "onnx_pads size does not equal 2 for Conv 1d");
-      onnx_pads.insert(onnx_pads.begin() + 1, 0);
-      onnx_pads.push_back(0);
+
+    AddPadTypeAndPads(*conv_op, model_builder, op_type, helper, num_spatial_dims);
+
+    AddOperationOutput(*conv_op, *node.OutputDefs()[0]);
+
+    model_builder.AddOperation(std::move(conv_op));
+  } else
+#endif  // defined(COREML_ENABLE_MLPROGRAM)
+  {
+    std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
+
+    auto strides = helper.Get("strides", std::vector<int64_t>{1, 1});
+    auto dilations = helper.Get("dilations", std::vector<int64_t>{1, 1});
+    auto onnx_pads = helper.Get("pads", std::vector<int64_t>{0, 0, 0, 0});
+    const auto group = helper.Get("group", static_cast<int64_t>(1));
+
+    std::vector<int64_t> input_shape;
+    ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
+
+    const auto& weight_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name());
+    std::vector<int64_t> weight_shape = {weight_tensor.dims().cbegin(), weight_tensor.dims().cend()};
+
+    const bool is_1d_conv = (weight_shape.size() == 3);
+
+    // add dummy 'W' dim with value of 1 so we can use 2D conv.
+    if (is_1d_conv) {
+      input_shape.push_back(1);
+      weight_shape.push_back(1);
+
+      // Strides/dilations for 1d conv are normally of length 1. Expand them by 1
+      // to meet the required length of 2 (for 2d conv they're normally length 2).
+      if (strides.size() < 2) {
+        ORT_RETURN_IF_NOT(strides.size() == 1, "strides size does not equal 1 for Conv 1d");
+        strides.push_back(1);
+      }
+
+      if (dilations.size() < 2) {
+        ORT_RETURN_IF_NOT(dilations.size() == 1, "dilations size does not equal 1 for Conv 1d");
+        dilations.push_back(1);
+      }
+
+      // Similarly 1d conv normally has a length 2 padding. Expand it to length 4 by adding additional zeros.
+      if (onnx_pads.size() < 4) {
+        ORT_RETURN_IF_NOT(onnx_pads.size() == 2, "onnx_pads size does not equal 2 for Conv 1d");
+        onnx_pads.insert(onnx_pads.begin() + 1, 0);
+        onnx_pads.push_back(0);
+      }
     }
-  }
-  const auto group = helper.Get("group", static_cast<int64_t>(1));
-
-  auto* coreml_conv = layer->mutable_convolution();
-
-  std::string expand_output_name = model_builder.GetUniqueName(node.Name() + "_expandDims");
-
-  if (is_1d_conv) {
-    const auto expand_layer_name = model_builder.GetUniqueName(MakeString(node.Name(), "_Conv_expand"));
-    std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> expand_layer = CreateNNLayer(expand_layer_name);
-    // Add an expanddims layer here. CoreML only supports 2d convolution, so for 1d Conv case
-    // we need to add an additional dimension here to the input to make it "2d Conv" like.
-    // NxCxH -> NxCxHx1
-    expand_layer->mutable_expanddims()->add_axes(-1);
-    *expand_layer->mutable_input()->Add() = input_name;
-    *expand_layer->mutable_output()->Add() = expand_output_name;
-    model_builder.AddLayer(std::move(expand_layer));
-  }
-  coreml_conv->set_outputchannels(weight_shape[0]);  // M
-  coreml_conv->set_kernelchannels(weight_shape[1]);  // C/Group
-  coreml_conv->add_kernelsize(weight_shape[2]);      // H
-  coreml_conv->add_kernelsize(weight_shape[3]);      // W
-  coreml_conv->set_ngroups(group);
-  *coreml_conv->mutable_stride() = {strides.cbegin(), strides.cend()};
-  *coreml_conv->mutable_dilationfactor() = {dilations.cbegin(), dilations.cend()};
-
-  coreml_conv->set_isdeconvolution(false);
-
-  // Add Padding
-  // Usually using autopadding is more efficient than using explicit padding
-  // Try to see if we can map explicit padding to auto padding
-  std::vector<int64_t> input_shape;
-  ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
-  AutoPadType auto_pad_type;
-  ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3],
-                                    onnx_pads, strides, dilations,
-                                    StringToAutoPadType(helper.Get("auto_pad", "NOTSET")),
-                                    auto_pad_type));
-
-  if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) {
-    auto* padding_type = coreml_conv->mutable_same();
-    if (AutoPadType::SAME_LOWER == auto_pad_type) {  // default is SAME_UPPER
-      padding_type->set_asymmetrymode(COREML_SPEC::SamePadding_SamePaddingMode_TOP_LEFT_HEAVY);
+
+    auto* coreml_conv = layer->mutable_convolution();
+
+    std::string expand_output_name = model_builder.GetUniqueName(node.Name() + "_expandDims");
+
+    if (is_1d_conv) {
+      // Add an expanddims layer here. CoreML only supports 2d convolution, so for 1d Conv case
+      // we need to add an additional dimension here to the input to make it "2d Conv" like.
+      // NxCxH -> NxCxHx1
+      auto expand_layer = model_builder.CreateNNLayer(node, "_Conv_expand");
+      expand_layer->mutable_expanddims()->add_axes(-1);
+      *expand_layer->mutable_input()->Add() = input_name;
+      *expand_layer->mutable_output()->Add() = expand_output_name;
+      model_builder.AddLayer(std::move(expand_layer));
     }
-  } else {
-    auto* padding_type = coreml_conv->mutable_valid();
-    if (AutoPadType::NOTSET == auto_pad_type && onnx_pads != std::vector<int64_t>{0, 0, 0, 0}) {
-      // NOTSET is adding the explicit padding to the ValidPadding.paddingAmounts
-      auto* height_border = padding_type->mutable_paddingamounts()->add_borderamounts();
-      height_border->set_startedgesize(onnx_pads[0]);
-      height_border->set_endedgesize(onnx_pads[2]);
-      auto* width_border = padding_type->mutable_paddingamounts()->add_borderamounts();
-      width_border->set_startedgesize(onnx_pads[1]);
-      width_border->set_endedgesize(onnx_pads[3]);
+
+    coreml_conv->set_outputchannels(weight_shape[0]);  // M
+    coreml_conv->set_kernelchannels(weight_shape[1]);  // C/Group
+    coreml_conv->add_kernelsize(weight_shape[2]);      // H
+    coreml_conv->add_kernelsize(weight_shape[3]);      // W
+    coreml_conv->set_ngroups(group);
+    *coreml_conv->mutable_stride() = {strides.cbegin(), strides.cend()};
+    *coreml_conv->mutable_dilationfactor() = {dilations.cbegin(), dilations.cend()};
+
+    coreml_conv->set_isdeconvolution(false);
+
+    // Add Padding
+    // Usually using autopadding is more efficient than using explicit padding
+    // Try to see if we can map explicit padding to auto padding
+    AutoPadType auto_pad_type;
+    ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3],
+                                      onnx_pads, strides, dilations,
+                                      StringToAutoPadType(helper.Get("auto_pad", "NOTSET")),
+                                      auto_pad_type));
+
+    if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) {
+      auto* padding_type = coreml_conv->mutable_same();
+      if (AutoPadType::SAME_LOWER == auto_pad_type) {  // default is SAME_UPPER
+        padding_type->set_asymmetrymode(COREML_SPEC::SamePadding_SamePaddingMode_TOP_LEFT_HEAVY);
+      }
+    } else {
+      auto* padding_type = coreml_conv->mutable_valid();
+      if (AutoPadType::NOTSET == auto_pad_type && onnx_pads != std::vector<int64_t>{0, 0, 0, 0}) {
+        // NOTSET is adding the explicit padding to the ValidPadding.paddingAmounts
+        auto* height_border = padding_type->mutable_paddingamounts()->add_borderamounts();
+        height_border->set_startedgesize(onnx_pads[0]);
+        height_border->set_endedgesize(onnx_pads[2]);
+        auto* width_border = padding_type->mutable_paddingamounts()->add_borderamounts();
+        width_border->set_startedgesize(onnx_pads[1]);
+        width_border->set_endedgesize(onnx_pads[3]);
+      }
     }
-  }
 
-  // Add weight
-  ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_conv->mutable_weights(), weight_tensor));
+    // Add weight
+    ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_conv->mutable_weights(), weight_tensor));
 
-  // Add bias if present
-  if (input_defs.size() > 2) {
-    coreml_conv->set_hasbias(true);
-    const auto& bias_tensor = *model_builder.GetInitializerTensors().at(input_defs[2]->Name());
-    ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_conv->mutable_bias(), bias_tensor));
-  }
+    // Add bias if present
+    if (input_defs.size() > 2) {
+      coreml_conv->set_hasbias(true);
+      const auto& bias_tensor = *model_builder.GetConstantInitializer(input_defs[2]->Name());
+      ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_conv->mutable_bias(), bias_tensor));
+    }
 
-  if (is_1d_conv) {
-    std::string conv_output_name = model_builder.GetUniqueName(node.Name() + "_conv_output");
-    *layer->mutable_input()->Add() = expand_output_name;
-    *layer->mutable_output()->Add() = conv_output_name;
-    model_builder.AddLayer(std::move(layer));
-
-    // Add a squeeze layer here. Since CoreML only supports 2d conv and we expanded the dimension by 1 before,
-    // we need to squeeze it back from NxCxHx1->NxCxH.
-    const auto squeeze_layer_name = model_builder.GetUniqueName(MakeString(node.Name(), "_Conv_squeeze"));
-    std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> squeeze_layer = CreateNNLayer(squeeze_layer_name);
-    squeeze_layer->mutable_squeeze()->add_axes(-1);
-    *squeeze_layer->mutable_input()->Add() = conv_output_name;
-    *squeeze_layer->mutable_output()->Add() = output_name;
-    model_builder.AddLayer(std::move(squeeze_layer));
-  } else {
-    *layer->mutable_input()->Add() = input_name;
-    *layer->mutable_output()->Add() = output_name;
-    model_builder.AddLayer(std::move(layer));
+    if (is_1d_conv) {
+      std::string conv_output_name = model_builder.GetUniqueName(node.Name() + "_conv_output");
+      *layer->mutable_input()->Add() = expand_output_name;
+      *layer->mutable_output()->Add() = conv_output_name;
+      model_builder.AddLayer(std::move(layer));
+
+      // Add a squeeze layer here. Since CoreML only supports 2d conv and we expanded the dimension by 1 before,
+      // we need to squeeze it back from NxCxHx1->NxCxH.
+      auto squeeze_layer = model_builder.CreateNNLayer(node, "_Conv_squeeze");
+      squeeze_layer->mutable_squeeze()->add_axes(-1);
+      *squeeze_layer->mutable_input()->Add() = conv_output_name;
+      *squeeze_layer->mutable_output()->Add() = output_name;
+      model_builder.AddLayer(std::move(squeeze_layer));
+    } else {
+      *layer->mutable_input()->Add() = input_name;
+      *layer->mutable_output()->Add() = output_name;
+      model_builder.AddLayer(std::move(layer));
+    }
   }
 
   return Status::OK();
 }
-#endif
-
-// Operator support related
 
 bool ConvOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                                       const logging::Logger& logger) const {
@@ -186,23 +223,73 @@ bool ConvOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara
   const auto& input_defs = node.InputDefs();
 
   const auto& weight_name = input_defs[1]->Name();
-  const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors();
-  if (Contains(initializers, weight_name)) {
-    const auto& tensor = *initializers.at(weight_name);
-    if (tensor.dims().size() != 4 && tensor.dims().size() != 3) {
-      LOGS(logger, VERBOSE) << "Conv [" << name << "] dimension: " << tensor.dims().size()
-                            << " Only conv 2d and conv 1d are supported.";
+  const auto* weight = input_params.graph_viewer.GetConstantInitializer(weight_name);
+
+#if defined(COREML_ENABLE_MLPROGRAM)
+  if (input_params.create_mlprogram) {
+    // ML Program supports non-const weight, 1D, 2D and 3D.
+    // keep to 1D and 2D for consistency with the NeuralNetwork implementation for now.
+    // add 3D support as/when needed.
+  } else
+#endif  // defined (COREML_ENABLE_MLPROGRAM)
+  {
+    if (!weight) {
+      LOGS(logger, VERBOSE) << "The weight of Conv [" << name << "] must be a constant initializer";
       return false;
     }
-  } else {
-    LOGS(logger, VERBOSE) << "The weight of Conv [" << name << "] must be known";
+  }
+
+  // use the weight for the shape as it should always be known
+  const auto* weight_shape = input_defs[1]->Shape();
+  int64_t num_dims = weight_shape ? weight_shape->dim_size() : -1;
+
+  // ONNX spec requires N and C as first 2 dims
+  if (num_dims != 3 && num_dims != 4) {
+    LOGS(logger, VERBOSE) << "Conv [" << name << "] is " << num_dims - 2 << "D. "
+                          << "Only 1D and 2D Conv are supported currently.";
     return false;
   }
 
-  if (input_defs.size() > 2) {
-    const auto& bias_name = input_defs[2]->Name();
-    if (!Contains(initializers, bias_name)) {
-      LOGS(logger, VERBOSE) << "The bias of Conv [" << name << "] must be a constant initializer";
+  if (input_defs.size() > 2 && !input_params.graph_viewer.GetConstantInitializer(input_defs[2]->Name())) {
+    LOGS(logger, VERBOSE) << "The bias of Conv [" << name << "] must be a constant initializer";
+    return false;
+  }
+
+  NodeAttrHelper helper(node);
+
+#if defined(COREML_ENABLE_MLPROGRAM)
+  // spec says same_lower is supported in CoreML 5. it lies. CoreML 6 is required otherwise you get
+  //   `Unexpected value for parameter pad_type[0] "same_lower" not in ("custom", "same", "valid").`
+  // We _could_ manually calculate the pads, but not implementing that until we have a real use case to justify
+  //  the effort as it's not clear how common usage of same_lower is.
+  if (input_params.create_mlprogram && input_params.coreml_version < 6) {
+    if (StringToAutoPadType(helper.Get("auto_pad", "NOTSET")) == AutoPadType::SAME_LOWER) {
+      LOGS(logger, VERBOSE) << "Pad type of SAME_LOWER  [" << name << "] is not supported until CoreML 6."
+                            << "Available version is CoreML " << input_params.coreml_version;
+      return false;
+    }
+  }
+#endif
+
+  // CoreML has no equivalent for manually specifying the kernel shape.
+  // It's OK if a specified kernel_shape attribute matches the kH and kW dims of the weight input.
+  auto kernel_shape = helper.GetInt64s("kernel_shape");
+  if (kernel_shape) {
+    bool valid = true;
+    if (static_cast<int64_t>(kernel_shape->size()) == num_dims - 2) {
+      for (int i = 0; i < num_dims - 2; ++i) {
+        // check the specified kernel shape matches the weight shape. skip the initial N and C dims in the latter.
+        if ((*kernel_shape)[i] != weight_shape->dim()[i + 2].dim_value()) {
+          valid = false;
+          break;
+        }
+      }
+    } else {
+      valid = false;
+    }
+
+    if (!valid) {
+      LOGS(logger, VERBOSE) << "Conv [" << name << "] kernel_shape attribute does not match the weight shape";
       return false;
     }
   }
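Note on the 1D Conv handling above: the ExpandDims(-1) / 2D Conv / Squeeze(-1) wrapping works because the dummy trailing W dimension stays 1 through the convolution, so squeezing it back restores NxCxH'. A small sketch of that shape bookkeeping, assuming the standard ONNX output-size formula; ConvOutDim is a hypothetical helper written for illustration, not part of the codebase.

    #include <cassert>
    #include <cstdint>

    // Standard conv output size along one axis (explicit pads, no auto-pad).
    static int64_t ConvOutDim(int64_t in, int64_t kernel, int64_t stride, int64_t dilation,
                              int64_t pad_begin, int64_t pad_end) {
      const int64_t effective_kernel = (kernel - 1) * dilation + 1;
      return (in + pad_begin + pad_end - effective_kernel) / stride + 1;
    }

    int main() {
      // 1D conv: input NxCxH with H=10, kernel 3, stride 2, dilation 1, pads {1, 1}.
      const int64_t h_out = ConvOutDim(10, 3, 2, 1, 1, 1);  // 5
      // Same conv as 2D after appending dummy W=1, kernel W=1, stride 1, dilation 1, pads {0, 0}.
      const int64_t w_out = ConvOutDim(1, 1, 1, 1, 0, 0);   // stays 1
      assert(h_out == 5 && w_out == 1);  // Squeeze(-1) then turns NxCx5x1 back into NxCx5.
      return 0;
    }

This is also why the W entries appended to strides/dilations are 1 and the W entries appended to the pads are 0 in the 1D path above.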
diff --git a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc
index a4ad1c31b5027..1eba312b2577b 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc
@@ -4,37 +4,26 @@
 #include "core/common/safeint.h"
 #include "core/providers/coreml/builders/helper.h"
 #include "core/providers/coreml/builders/impl/base_op_builder.h"
+#include "core/providers/coreml/builders/model_builder.h"
 #include "core/providers/coreml/builders/op_builder_factory.h"
 #include "core/providers/coreml/shape_utils.h"
 #include "core/providers/shared/utils/utils.h"
 
-#ifdef __APPLE__
-#include "core/providers/coreml/builders/model_builder.h"
-#endif
-
 namespace onnxruntime {
 namespace coreml {
 
 class DepthToSpaceOpBuilder : public BaseOpBuilder {
-  // Add operator related
- private:
-#ifdef __APPLE__
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override;
-#endif
 
-  // Operator support related
   bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                          const logging::Logger& logger) const override;
 };
 
-// Add operator related
-
-#ifdef __APPLE__
 Status DepthToSpaceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
                                                     const Node& node,
                                                     const logging::Logger& /* logger */) const {
-  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(model_builder, node);
+  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
 
   const auto& input_defs = node.InputDefs();
   const auto& output_defs = node.OutputDefs();
@@ -54,9 +43,6 @@ Status DepthToSpaceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   model_builder.AddLayer(std::move(layer));
   return Status::OK();
 }
-#endif
-
-// Operator support related
 
 bool DepthToSpaceOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/,
                                               const logging::Logger& logger) const {
diff --git a/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc
index b303fe7884cb1..f0adb70587bcf 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc
@@ -3,39 +3,26 @@
 
 #include "core/providers/coreml/builders/helper.h"
 #include "core/providers/coreml/builders/impl/base_op_builder.h"
+#include "core/providers/coreml/builders/model_builder.h"
 #include "core/providers/coreml/builders/op_builder_factory.h"
 #include "core/providers/coreml/shape_utils.h"
 #include "core/providers/shared/utils/utils.h"
 
-#ifdef __APPLE__
-#include "core/providers/coreml/builders/model_builder.h"
-#endif
-
 namespace onnxruntime {
 namespace coreml {
 
 class FlattenOpBuilder : public BaseOpBuilder {
-  // Add operator related
-#ifdef __APPLE__
- private:
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override;
-#endif
 
-  // Operator support related
- private:
   bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                          const logging::Logger& logger) const override;
 };
 
-// Add operator related
-
-#ifdef __APPLE__
-
 Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
                                                const Node& node,
-                                               const logging::Logger& logger) const {
-  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(model_builder, node);
+                                               const logging::Logger& /*logger*/) const {
+  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
 
   // Note: ONNX Flatten corresponds to CoreML FlattenTo2DLayerParams
   auto* coreml_flatten = layer->mutable_flattento2d();
@@ -51,9 +38,6 @@ Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
 
   return Status::OK();
 }
-#endif
-
-// Operator support related
 
 bool FlattenOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/,
                                          const logging::Logger& logger) const {
diff --git a/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc
index 9c7ec306ca093..7d32675e3e510 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc
@@ -2,34 +2,24 @@
 // Licensed under the MIT License.
 
 #include "core/providers/coreml/builders/impl/base_op_builder.h"
-
 #include "core/providers/coreml/builders/op_builder_factory.h"
+#include "core/providers/coreml/builders/model_builder.h"
 #include "core/providers/coreml/shape_utils.h"
 #include "core/providers/shared/utils/utils.h"
 
-#if defined(__APPLE__)
-#include "core/providers/coreml/builders/model_builder.h"
-#endif
-
 namespace onnxruntime::coreml {
 
 class GatherOpBuilder : public BaseOpBuilder {
-  // Add operator related
-#ifdef __APPLE__
- private:
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override;
-#endif
 
-  // Operator support related
- private:
-  bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const override;
+  bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params,
+                              const logging::Logger& logger) const override;
+
   bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                          const logging::Logger& logger) const override;
 };
 
-// Add operator related
-#if defined(__APPLE__)
 namespace {
 int64_t GetAxisAttribute(const Node& node) {
   NodeAttrHelper node_attr_helper{node};
@@ -38,8 +28,8 @@ int64_t GetAxisAttribute(const Node& node) {
 }  // namespace
 
 Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
-                                              const logging::Logger& logger) const {
-  auto layer = CreateNNLayer(model_builder, node);
+                                              const logging::Logger& /*logger*/) const {
+  auto layer = model_builder.CreateNNLayer(node);
   layer->mutable_gather()->set_axis(GetAxisAttribute(node));
   *layer->mutable_input()->Add() = node.InputDefs()[0]->Name();    // data
   *layer->mutable_input()->Add() = node.InputDefs()[1]->Name();    // indices
@@ -47,10 +37,9 @@ Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
   model_builder.AddLayer(std::move(layer));
   return Status::OK();
 }
-#endif  // defined(__APPLE__)
 
-// Operator support related
-bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const {
+bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& /*input_params*/,
+                                             const logging::Logger& logger) const {
   int32_t input_type;
   if (!GetType(*node.InputDefs()[0], input_type, logger))
     return false;
diff --git a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc
index 71b08db6d44d8..8daf64dc4a457 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc
@@ -7,46 +7,66 @@
 #include "core/providers/common.h"
 #include "core/providers/coreml/builders/helper.h"
 #include "core/providers/coreml/builders/impl/base_op_builder.h"
+#include "core/providers/coreml/builders/impl/builder_utils.h"
+#include "core/providers/coreml/builders/model_builder.h"
 #include "core/providers/coreml/builders/op_builder_factory.h"
 #include "core/providers/coreml/shape_utils.h"
 #include "core/providers/shared/utils/utils.h"
 
-#ifdef __APPLE__
-#include "core/providers/coreml/builders/impl/builder_utils.h"
-#include "core/providers/coreml/builders/model_builder.h"
-#endif
-
 namespace onnxruntime {
 namespace coreml {
 
 class GemmOpBuilder : public BaseOpBuilder {
-  // Add operator related
-#ifdef __APPLE__
- public:
   void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
 
- private:
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override;
-#endif
 
-  // Operator support related
- private:
-  bool IsOpSupportedImpl(const Node& /* node */, const OpBuilderInputParams& /* input_params */,
-                         const logging::Logger& /* logger */) const override;
-};
+  bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
+                         const logging::Logger& logger) const override;
 
-// Add operator related
+  bool SupportsMLProgram() const override { return true; }
+};
 
-#ifdef __APPLE__
 void GemmOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
   const auto& op = node.OpType();
   const auto& input_defs(node.InputDefs());
-  // We have already embedded the weights (matrix B and C(if any)) into the coreml layer
-  // No need to copy them later to reduce memory consumption
-  model_builder.AddInitializerToSkip(input_defs[1]->Name());
-  if (op == "Gemm" && input_defs.size() > 2) {
-    model_builder.AddInitializerToSkip(input_defs[2]->Name());
+  const bool is_gemm = op == "Gemm";
+
+#if defined(COREML_ENABLE_MLPROGRAM)
+  if (model_builder.CreateMLProgram()) {
+    // we have to transpose the weight input of Gemm if transB is false, and potentially override the bias shape
+    if (is_gemm) {
+      NodeAttrHelper helper(node);
+      const auto transB = helper.Get("transB", 0);
+      if (transB == 0) {
+        model_builder.AddInitializerToSkip(input_defs[1]->Name());
+      }
+
+      if (input_defs.size() > 2) {
+        // ONNX spec requires B to be 2D and we require it to be a constant initializer, so reading N this way is safe.
+        // B is {K, N} by default, or {N, K} if transB is true.
+        int N_dim = transB ? 0 : 1;
+        int64_t N = input_defs[1]->Shape()->dim().at(N_dim).dim_value();
+
+        const auto& bias_name = input_defs[2]->Name();
+        const auto& bias = *model_builder.GetConstantInitializer(bias_name);
+        if (bias.dims_size() != 1 || bias.dims(0) != N) {
+          // we have to override the shape/duplicate data to convert {}, {1} or {1, N} to 1D {N}
+          // when adding the Gemm operation so skip adding the original initializer
+          model_builder.AddInitializerToSkip(bias_name);
+        }
+      }
+    }
+  } else
+#endif  // defined(COREML_ENABLE_MLPROGRAM)
+  {
+    // We have already embedded the weights (matrix B and C(if any)) into the coreml layer
+    // No need to copy them later to reduce memory consumption
+    model_builder.AddInitializerToSkip(input_defs[1]->Name());
+    if (is_gemm && input_defs.size() > 2) {
+      model_builder.AddInitializerToSkip(input_defs[2]->Name());
+    }
   }
 }
 
@@ -70,156 +90,258 @@ static Status GetTensorFloatDataTransposed(const ONNX_NAMESPACE::TensorProto& te
 }
 
 Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
-                                            const logging::Logger& /* logger */) const {
-  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(model_builder, node);
+                                            const logging::Logger& logger) const {
+  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
 
   const auto& op_type = node.OpType();
   const auto& input_defs = node.InputDefs();
-  const auto& b_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name());
-  const auto& b_shape = b_tensor.dims();
-
-  auto* coreml_inner_product = layer->mutable_innerproduct();
-
-  // The coreml innerproduct weight (matrix B) is stored transposed
-  // - for MatMul and Gemm (transB = 0), the coreml weight is B'
-  // - for Gemm (transB = 1), the coreml weight is B
-  if (op_type == "MatMul") {
-    coreml_inner_product->set_inputchannels(b_shape[0]);
-    coreml_inner_product->set_outputchannels(b_shape[1]);
-    // Add weight (b of MatMul)
-    std::vector<float> b_transposed;
-    ORT_RETURN_IF_ERROR(GetTensorFloatDataTransposed(b_tensor, b_transposed));
-    CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), b_transposed);
-  } else {  // Gemm
-    NodeAttrHelper helper(node);
-    const auto transB = helper.Get("transB", 0);
-    if (transB == 0) {
-      coreml_inner_product->set_inputchannels(b_shape[0]);
-      coreml_inner_product->set_outputchannels(b_shape[1]);
+  const auto& a = *input_defs[0];
+  const auto& b = *input_defs[1];
+  const auto* b_initializer = model_builder.GetConstantInitializer(b.Name());  // MLProgram MatMul may not be constant
+
+  const bool is_matmul = op_type == "MatMul";
+  const bool is_gemm = op_type == "Gemm";
+
+  NodeAttrHelper helper(node);
+  const auto transB = is_gemm ? helper.Get("transB", 0) : 0;
+
+  std::vector<int64_t> b_shape;
+  ORT_IGNORE_RETURN_VALUE(GetShape(b, b_shape, logger));
+  int64_t b0 = -1, b1 = -1;
+
+  // ML Program MatMul supports N-D input
+  if (model_builder.CreateMLProgram() && is_matmul) {
+    if (b_shape.size() == 1) {
+      // B is treated as {b_shape[0], 1} according to the numpy rules.
+      b0 = b_shape[0];
+      b1 = 1;
+    } else {
+      // last 2 dims are used
+      b0 = b_shape[b_shape.size() - 2];
+      b1 = b_shape[b_shape.size() - 1];
+    }
+  } else {
+    // we only support 2D input
+    b0 = b_shape[0];
+    b1 = b_shape[1];
+  }
+
+  // B is {K, N} in ONNX spec by default, or {N, K} in Gemm if transB is true
+  const auto K = transB ? b1 : b0;
+  const auto N = transB ? b0 : b1;
+
+#if defined(COREML_ENABLE_MLPROGRAM)
+  if (model_builder.CreateMLProgram()) {
+    using namespace CoreML::Specification::MILSpec;
+
+    if (is_gemm) {
+      // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.linear.linear
+      auto gemm_op = model_builder.CreateOperation(node, "linear");
+      AddOperationInput(*gemm_op, "x", a.Name());
+
+      // CoreML takes weight input as {N, K} which is the reverse of ONNX.
+      // if transB is true the input weight is {N, K} so can be added directly.
+      if (transB) {
+        AddOperationInput(*gemm_op, "weight", b.Name());
+      } else {
+        // transpose from {K, N} to {N, K}
+        std::vector<float> weight_nk;
+        std::vector<int64_t> weight_nk_shape = {N, K};
+        ORT_RETURN_IF_ERROR(GetTensorFloatDataTransposed(*b_initializer, weight_nk));
+
+        AddOperationInput(*gemm_op, "weight",
+                          model_builder.AddConstant(gemm_op->type(), b.Name() + "_t", weight_nk, weight_nk_shape));
+      }
+
+      if (input_defs.size() == 3) {
+        const auto& bias_arg = *input_defs[2];
+        const auto& bias = *model_builder.GetConstantInitializer(bias_arg.Name());
+
+        // The CoreML linear op requires bias to be a 1D tensor of size N
+        if (bias.dims_size() == 1 && bias.dims().at(0) == N) {
+          // can use existing initializer
+          AddOperationInput(*gemm_op, "bias", bias_arg.Name());
+        } else {
+          Initializer unpacked_tensor(bias);
+          auto bias_data = unpacked_tensor.DataAsSpan<float>();
+          std::string_view bias_data_name;
+          if (bias_data.size() == 1) {
+            // expand scalar to N
+            std::vector<float> expanded_bias_data(N, bias_data[0]);
+            bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", expanded_bias_data);
+          } else {
+            // can use data as-is but need to adjust shape (inferred by AddConstant as {bias_data.size()})
+            bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", bias_data);
+          }
+
+          AddOperationInput(*gemm_op, "bias", bias_data_name);
+        }
+      }
+
+      AddOperationOutput(*gemm_op, *node.OutputDefs()[0]);
+      model_builder.AddOperation(std::move(gemm_op));
+    } else {
+      // CoreML implementation is the same as ONNX MatMul.
+      // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.linear.matmul
+      auto matmul_op = model_builder.CreateOperation(node, "matmul");
+      AddOperationInput(*matmul_op, "x", a.Name());
+      AddOperationInput(*matmul_op, "y", b.Name());
+
+      // once again the spec lies and says transpose_y and transpose_x are optional...
+      auto false_value_name = model_builder.AddScalarConstant(matmul_op->type(), "false", false);
+      AddOperationInput(*matmul_op, "transpose_x", false_value_name);
+      AddOperationInput(*matmul_op, "transpose_y", false_value_name);
+
+      AddOperationOutput(*matmul_op, *node.OutputDefs()[0]);
+      model_builder.AddOperation(std::move(matmul_op));
+    }
+  } else
+#endif  // defined(COREML_ENABLE_MLPROGRAM)
+  {
+    auto* coreml_inner_product = layer->mutable_innerproduct();
+
+    *layer->mutable_input()->Add() = a.Name();
+
+    coreml_inner_product->set_inputchannels(K);
+    coreml_inner_product->set_outputchannels(N);
+
+    // CoreML takes weight input as {N, K} which is the reverse of ONNX.
+    // if Gemm's transB is true the input weight is {N, K} and can be added directly.
+    if (transB) {
+      ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), *b_initializer));
+    } else {
       std::vector<float> b_transposed;
-      ORT_RETURN_IF_ERROR(GetTensorFloatDataTransposed(b_tensor, b_transposed));
+      ORT_RETURN_IF_ERROR(GetTensorFloatDataTransposed(*b_initializer, b_transposed));
       CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), b_transposed);
-    } else {
-      coreml_inner_product->set_inputchannels(b_shape[1]);
-      coreml_inner_product->set_outputchannels(b_shape[0]);
-      // Add weight (b of MatMul)
-      ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), b_tensor));
     }
 
-    // Add bias if present
-    if (input_defs.size() > 2) {
+    if (is_gemm && input_defs.size() > 2) {
+      // Add bias
       coreml_inner_product->set_hasbias(true);
-      const auto& bias_tensor = *model_builder.GetInitializerTensors().at(input_defs[2]->Name());
-      ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_inner_product->mutable_bias(), bias_tensor));
+      const auto& bias_tensor = *model_builder.GetConstantInitializer(input_defs[2]->Name());
+
+      // if the bias is a scalar or a single value, expand it to a 1D tensor of size N.
+      // IsOpSupportedImpl enforces the bias is scalar, {1}, {N}, or {1, N}.
+      Initializer unpacked_tensor(bias_tensor);
+      auto bias_data = unpacked_tensor.DataAsSpan<float>();
+      if (bias_data.size() == 1 && N > 1) {
+        std::vector<float> expanded_bias_data(N, bias_data[0]);
+        CreateCoreMLWeight(*coreml_inner_product->mutable_bias(), expanded_bias_data);
+      } else {
+        CreateCoreMLWeight(*coreml_inner_product->mutable_bias(), bias_data);
+      }
     }
-  }
 
-  *layer->mutable_input()->Add() = input_defs[0]->Name();
-  *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();
+    *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();
+    model_builder.AddLayer(std::move(layer));
+  }
 
-  model_builder.AddLayer(std::move(layer));
   return Status::OK();
 }
-#endif
-
-// Operator support related
 
 bool GemmOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                                       const logging::Logger& logger) const {
   const auto& op_type = node.OpType();
   const auto& input_defs(node.InputDefs());
+  const bool is_matmul = op_type == "MatMul";
+  const bool is_gemm = op_type == "Gemm";
+
   size_t a_idx = 0, b_idx = 1, c_idx = 2;  // A*B+C
 
-  const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors();
-  if (!Contains(initializers, input_defs[b_idx]->Name())) {
-    LOGS(logger, VERBOSE) << "B of Gemm/Matmul must be an initializer tensor";
+  std::vector<int64_t> a_shape;
+  if (!GetShape(*input_defs[a_idx], a_shape, logger)) {
     return false;
   }
 
-  std::vector<int64_t> a_shape;
-  {
-    if (!GetShape(*input_defs[a_idx], a_shape, logger))
-      return false;
-
-    if (a_shape.size() != 2) {
-      LOGS(logger, VERBOSE) << "A must be 2D";
-      return false;
-    }
+  std::vector<int64_t> b_shape;
+  if (!GetShape(*input_defs[b_idx], b_shape, logger)) {
+    return false;
+  }
 
-    // TODO is it ok if the shape is dynamic and empty?
-    if (Product(a_shape) == 0) {
-      LOGS(logger, VERBOSE) << "A must be non-empty";
+  if (!input_params.graph_viewer.GetConstantInitializer(input_defs[b_idx]->Name())) {
+    if (input_params.create_mlprogram && is_matmul) {
+      // ML Program MatMul allows non-constant B input
+    } else {
+      LOGS(logger, VERBOSE) << op_type << " B input must be a constant initializer";
       return false;
     }
   }
 
-  std::vector<int64_t> b_shape;
-  {
-    if (!GetShape(*input_defs[b_idx], b_shape, logger))
-      return false;
-
-    if (b_shape.size() != 2) {
-      LOGS(logger, VERBOSE) << "B must be 2D";
-      return false;
-    }
+  if (is_matmul) {
+    if (input_params.create_mlprogram) {
+      // ML Program matmul op has numpy semantics the same as the ONNX spec so we can use directly
+    } else {
+      // we could potentially support 1D and 3D if required. beyond 3D the dims that merge diverge.
+      // https://github.com/apple/coremltools/blob/1931758aae383c83daddfc56f11a24a9d2bf4b87/coremltools/converters/onnx/_operators.py#L1607
+      // https://github.com/apple/coremltools/blob/1931758aae383c83daddfc56f11a24a9d2bf4b87/coremltools/converters/mil/backend/nn/op_mapping.py#L1374
+      // https://apple.github.io/coremltools/mlmodel/Format/NeuralNetwork.html#innerproductlayerparams
+      if (a_shape.size() != 2 || b_shape.size() != 2) {
+        LOGS(logger, VERBOSE) << "a and b inputs must be 2D. ";
+        return false;
+      }
 
-    if (Product(b_shape) == 0) {
-      LOGS(logger, VERBOSE) << "B must be non-empty";
-      return false;
+      if (input_defs.size() > 2) {
+        LOGS(logger, VERBOSE) << "MatMul with C input is not supported";
+        return false;
+      }
     }
   }
 
-  if (op_type == "Gemm") {
+  if (is_gemm) {
+    // A and B are 2D due to the ONNX spec
     NodeAttrHelper helper(node);
     const auto transA = helper.Get("transA", 0);
     const auto transB = helper.Get("transB", 0);
     const auto alpha = helper.Get("alpha", 1.0f);
     const auto beta = helper.Get("beta", 1.0f);
+
+    // TODO: We can support transA, alpha and beta by using multiple layers/operations if needed.
     if (!(transA == 0 && alpha == 1.f && beta == 1.f)) {
-      LOGS(logger, VERBOSE) << "Only transA == 0, alpha == 1.0 "
-                            << "and beta == 1.0 is supported."
+      LOGS(logger, VERBOSE) << "Only support for transA == 0, alpha == 1.0 "
+                            << "and beta == 1.0 is currently implemented."
                             << " transA " << transA
                             << " alpha " << alpha
                             << " beta " << beta;
       return false;
     }
 
-    // C of Gemm
-    // For now we only support {n} or {1,n} tensor
     if (input_defs.size() == 3) {
-      if (!Contains(initializers, input_defs[c_idx]->Name())) {
-        LOGS(logger, VERBOSE) << "C of Gemm must be an initializer tensor";
+      if (!input_params.graph_viewer.GetConstantInitializer(input_defs[c_idx]->Name())) {
+        LOGS(logger, VERBOSE) << "C of Gemm must be a constant initializer";
         return false;
       }
 
       std::vector<int64_t> c_shape;
-      if (!GetShape(*input_defs[c_idx], c_shape, logger))
+      if (!GetShape(*input_defs[c_idx], c_shape, logger)) {
         return false;
+      }
 
-      size_t c_dim = c_shape.size();
+      // B is {K, N} in ONNX spec by default, or {N, K} in Gemm if transB is true
+      const auto N = transB ? b_shape[0] : b_shape[1];
 
-      if (c_dim == 0) {
-        LOGS(logger, VERBOSE) << "C of Gemm cannot be a scalar";
-        return false;
-      }
+      size_t c_rank = c_shape.size();
 
-      if (c_dim != 1) {
-        // If C is a (2+)d tensor, it must have the format {1, 1, ..., 1, n}
-        // where every except the last dimension should be 1
-        for (size_t i = 0; i < c_dim - 1; ++i) {
-          if (c_shape[i] != 1) {
-            LOGS(logger, VERBOSE) << "C of Gemm must be a vector or a tensor with only last dimension != 1";
-            return false;
+      // allowed: a scalar, a 1D tensor with dim value of 1 or N, or a 2D tensor with shape {1, N}
+      bool c_valid = false;
+      switch (c_rank) {
+        case 0:
+          c_valid = true;
+          break;
+        case 1:
+          if (c_shape[0] == 1 || c_shape[0] == N) {
+            c_valid = true;
           }
-        }
+          break;
+        case 2:
+          if (c_shape[0] == 1 && c_shape[1] == N) {
+            c_valid = true;
+          }
+          break;
       }
 
-      auto c_size = c_shape[c_dim - 1];
-      if (c_size != (transB == 0 ? b_shape[1] : b_shape[0])) {
-        LOGS(logger, VERBOSE) << "C of Gemm must be a vector of b_shape["
-                              << (transB == 0 ? "1" : "0") << "]"
-                              << " b_shape: [" << b_shape[0] << ", " << b_shape[1] << "]"
-                              << " c_size: " << c_size;
+      if (!c_valid) {
+        LOGS(logger, VERBOSE) << "Shape of C Gemm input must be {}, {1}, {N}, or {1, N}. N:" << N << " C shape:"
+                              << Shape2String(c_shape);
 
         return false;
       }
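Note on the Gemm/MatMul weight handling above: both the NeuralNetwork innerproduct and the ML Program linear op take the weight as {N, K}, the reverse of ONNX's default {K, N} layout when transB is 0, which is why the builder transposes the constant B data (GetTensorFloatDataTransposed). A minimal sketch of that row-major transpose on a toy matrix; TransposeKNtoNK is a hypothetical helper written for illustration, not the actual ONNX Runtime function.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Transpose row-major {K, N} float data into row-major {N, K}.
    static std::vector<float> TransposeKNtoNK(const std::vector<float>& data, int64_t K, int64_t N) {
      std::vector<float> out(static_cast<size_t>(K * N));
      for (int64_t k = 0; k < K; ++k) {
        for (int64_t n = 0; n < N; ++n) {
          out[n * K + k] = data[k * N + n];
        }
      }
      return out;
    }

    int main() {
      // B is {K=2, N=3}: [[1, 2, 3], [4, 5, 6]]
      const std::vector<float> b = {1, 2, 3, 4, 5, 6};
      const auto b_nk = TransposeKNtoNK(b, 2, 3);
      // Expected {N=3, K=2}: [[1, 4], [2, 5], [3, 6]]
      assert((b_nk == std::vector<float>{1, 4, 2, 5, 3, 6}));
      return 0;
    }

When transB is 1 the ONNX weight is already {N, K} and is passed through unchanged, and a scalar or {1} bias is expanded to a length-N vector before being added, matching the branches above.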
diff --git a/onnxruntime/core/providers/coreml/builders/impl/pad_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/pad_op_builder.cc
index ba12600e8bc40..99d6f01cb8c5b 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/pad_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/pad_op_builder.cc
@@ -7,30 +7,20 @@
 #include "core/providers/common.h"
 #include "core/providers/coreml/builders/helper.h"
 #include "core/providers/coreml/builders/impl/base_op_builder.h"
+#include "core/providers/coreml/builders/model_builder.h"
 #include "core/providers/coreml/builders/op_builder_factory.h"
 #include "core/providers/coreml/shape_utils.h"
 #include "core/providers/shared/utils/utils.h"
 
-#ifdef __APPLE__
-#include "core/providers/coreml/builders/model_builder.h"
-#endif
-
 namespace onnxruntime {
 namespace coreml {
 
 class PadOpBuilder : public BaseOpBuilder {
-  // Add operator related
-#ifdef __APPLE__
- public:
   void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
 
- private:
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override;
-#endif
 
-  // Operator support related
- private:
   bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                          const logging::Logger& logger) const override;
 
@@ -64,9 +54,6 @@ static InlinedVector<int64_t> GetPaddingAxesData(const InitializedTensorSet& ini
   return axes_tensor_data;
 }
 
-// Add operator related
-
-#ifdef __APPLE__
 void PadOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
   model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name());  //  pads
   model_builder.AddInitializerToSkip(node.InputDefs()[2]->Name());  //  constant_value
@@ -78,7 +65,7 @@ void PadOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node
 Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
                                            const Node& node,
                                            const logging::Logger& logger) const {
-  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(model_builder, node);
+  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
 
   auto* coreml_pad = layer->mutable_padding();
   auto* constant_padding_type = coreml_pad->mutable_constant();  // CoreML::Specification::PaddingLayerParams_PaddingConstant
@@ -122,9 +109,6 @@ Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
 
   return Status::OK();
 }
-#endif
-
-// Operator support related
 
 bool PadOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                                      const logging::Logger& logger) const {
diff --git a/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc
index fd1c77c851e6f..17910ba6fd486 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc
@@ -4,132 +4,191 @@
 #include "core/providers/common.h"
 #include "core/providers/coreml/builders/helper.h"
 #include "core/providers/coreml/builders/impl/base_op_builder.h"
+#include "core/providers/coreml/builders/impl/builder_utils.h"
+#include "core/providers/coreml/builders/model_builder.h"
 #include "core/providers/coreml/builders/op_builder_factory.h"
 #include "core/providers/coreml/shape_utils.h"
 #include "core/providers/shared/utils/utils.h"
 
-#ifdef __APPLE__
-#include "core/providers/coreml/builders/impl/builder_utils.h"
-#include "core/providers/coreml/builders/model_builder.h"
-#endif
-
 namespace onnxruntime {
 namespace coreml {
 
 class PoolOpBuilder : public BaseOpBuilder {
-  // Add operator related
- private:
-#ifdef __APPLE__
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override;
-#endif
 
-  // Operator support related
   bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                          const logging::Logger& logger) const override;
-};
 
-// Add operator related
+  bool SupportsMLProgram() const override { return true; }
+};
 
-#ifdef __APPLE__
 Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
                                             const Node& node,
                                             const logging::Logger& logger) const {
-  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(model_builder, node);
-
-  auto* coreml_pool = layer->mutable_pooling();
   const auto& op_type = node.OpType();
   const auto& input_defs = node.InputDefs();
 
-  bool is_global_pooling = false;
-  if (op_type == "GlobalAveragePool") {
-    is_global_pooling = true;
-    coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_AVERAGE);
-  } else if (op_type == "GlobalMaxPool") {
-    is_global_pooling = true;
-    coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_MAX);
-  } else if (op_type == "AveragePool") {
-    coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_AVERAGE);
-  } else if (op_type == "MaxPool") {
-    coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_MAX);
-  } else {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "PoolOpBuilder, unknown op: ", op_type);
-  }
+#if defined(COREML_ENABLE_MLPROGRAM)
+  if (model_builder.CreateMLProgram()) {
+    using namespace CoreML::Specification::MILSpec;
+
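+    // Map the ONNX op type to a MIL op. GlobalAveragePool/GlobalMaxPool are lowered to reduce_mean/reduce_max
+    // over the spatial dims (see the 'axes'/'keep_dims' inputs below); AveragePool/MaxPool map directly to
+    // avg_pool/max_pool.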
+    std::string_view coreml_op_type;
+    bool is_global = false;
+    bool is_avg_pool = false;
+    if (op_type == "GlobalAveragePool") {
+      // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.reduction.reduce_mean
+      coreml_op_type = "reduce_mean";
+      is_global = true;
+    } else if (op_type == "GlobalMaxPool") {
+      // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.reduction.reduce_max
+      coreml_op_type = "reduce_max";
+      is_global = true;
+    } else if (op_type == "AveragePool") {
+      // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.pool.avg_pool
+      coreml_op_type = "avg_pool";
+      is_avg_pool = true;
+    } else if (op_type == "MaxPool") {
+      // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.pool.max_pool
+      coreml_op_type = "max_pool";
+    } else {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "PoolOpBuilder, unexpected op: ", op_type);
+    }
 
-  if (is_global_pooling) {
-    coreml_pool->set_globalpooling(true);
-    coreml_pool->mutable_valid();
-  } else {  // AveragePool or MaxPool
-    NodeAttrHelper helper(node);
-    const auto kernel_shape = helper.Get("kernel_shape", std::vector<int64_t>{0, 0});
-    const auto strides = helper.Get("strides", std::vector<int64_t>{1, 1});
-    const auto onnx_pads = helper.Get("pads", std::vector<int64_t>{0, 0, 0, 0});
-
-    coreml_pool->add_kernelsize(kernel_shape[0]);
-    coreml_pool->add_kernelsize(kernel_shape[1]);
-    coreml_pool->add_stride(strides[0]);
-    coreml_pool->add_stride(strides[1]);
-    coreml_pool->set_avgpoolexcludepadding(helper.Get("count_include_pad", 0) == 0);
-    coreml_pool->set_globalpooling(false);
-
-    // Add Padding
-    // Usually using autopadding is more efficient than using explicit padding
-    // Try to see if we can map explicit padding to auto padding
-    std::vector<int64_t> input_shape;
-    ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
-    AutoPadType auto_pad_type;
-    ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, kernel_shape[0], kernel_shape[1],
-                                      onnx_pads, strides, {1, 1} /* dilations */,
-                                      StringToAutoPadType(helper.Get("auto_pad", "NOTSET")),
-                                      auto_pad_type));
-
-    if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) {
-      auto* padding_type = coreml_pool->mutable_same();
-      if (AutoPadType::SAME_LOWER == auto_pad_type) {  // default is SAME_UPPER
-        padding_type->set_asymmetrymode(COREML_SPEC::SamePadding_SamePaddingMode_TOP_LEFT_HEAVY);
+    std::unique_ptr<Operation> op = model_builder.CreateOperation(node, coreml_op_type);
+
+    AddOperationInput(*op, "x", input_defs[0]->Name());
+
+    if (is_global) {
+      // keep N and C dims, reduce the rest with keepdims=True. equivalent to the ONNX Global*Pool ops.
+      std::vector<int64_t> axes{2, 3};  // we only support 4D input currently.
+      AddOperationInput(*op, "axes", model_builder.AddConstant(op->type(), "axes", axes));
+      AddOperationInput(*op, "keep_dims", model_builder.AddScalarConstant(op->type(), "keep_dims", true));
+    } else {
+      NodeAttrHelper helper(node);
+      constexpr int num_spatial_dims = 2;  // we only support 4D. -2 for N and C dims.
+
+      AddPadTypeAndPads(*op, model_builder, op->type(), helper, num_spatial_dims);
+
+      const auto kernel_shape = helper.GetInt64s("kernel_shape");  // required
+      AddOperationInput(*op, "kernel_sizes", model_builder.AddConstant(op->type(), "kernel_sizes", *kernel_shape));
+
+      // in theory all these values are optional according to the CoreML spec, but it's simpler to just provide
+      // default values as the actual model compilation tends to require them.
+      const auto strides = helper.Get("strides", std::vector<int64_t>(num_spatial_dims, 1));
+      const bool ceil_mode = helper.Get("ceil_mode", int64_t(0));  // convert int64_t to bool
+
+      AddOperationInput(*op, "strides", model_builder.AddConstant(op->type(), "strides", strides));
+      AddOperationInput(*op, "ceil_mode", model_builder.AddScalarConstant(op->type(), "ceil_mode", ceil_mode));
+
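+      // ONNX 'count_include_pad' defaults to 0 (padding excluded from the average); CoreML expresses the same
+      // behavior via the 'exclude_padding_from_average' input.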
+      if (is_avg_pool) {
+        const bool count_exclude_pad = helper.Get("count_include_pad", int64_t(0)) == 0;
+        AddOperationInput(*op, "exclude_padding_from_average",
+                          model_builder.AddScalarConstant(op->type(), "count_exclude_pad", count_exclude_pad));
       }
+    }
+
+    AddOperationOutput(*op, *node.OutputDefs()[0]);
+    model_builder.AddOperation(std::move(op));
+
+  } else
+#endif  // defined(COREML_ENABLE_MLPROGRAM)
+  {
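+    // NeuralNetwork model path: use a pooling layer.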
+    std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
+
+    auto* coreml_pool = layer->mutable_pooling();
+
+    bool is_global_pooling = false;
+    if (op_type == "GlobalAveragePool") {
+      is_global_pooling = true;
+      coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_AVERAGE);
+    } else if (op_type == "GlobalMaxPool") {
+      is_global_pooling = true;
+      coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_MAX);
+    } else if (op_type == "AveragePool") {
+      coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_AVERAGE);
+    } else if (op_type == "MaxPool") {
+      coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_MAX);
     } else {
-      auto* padding_type = coreml_pool->mutable_valid();
-      if (AutoPadType::NOTSET == auto_pad_type && onnx_pads != std::vector<int64_t>{0, 0, 0, 0}) {
-        // NOTSET is adding the explicit padding to the ValidPadding.paddingAmounts
-        auto* height_border = padding_type->mutable_paddingamounts()->add_borderamounts();
-        height_border->set_startedgesize(onnx_pads[0]);
-        height_border->set_endedgesize(onnx_pads[2]);
-        auto* width_border = padding_type->mutable_paddingamounts()->add_borderamounts();
-        width_border->set_startedgesize(onnx_pads[1]);
-        width_border->set_endedgesize(onnx_pads[3]);
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "PoolOpBuilder, unexpected op: ", op_type);
+    }
+
+    if (is_global_pooling) {
+      coreml_pool->set_globalpooling(true);
+      coreml_pool->mutable_valid();
+    } else {  // AveragePool or MaxPool
+      NodeAttrHelper helper(node);
+      const auto kernel_shape = helper.Get("kernel_shape", std::vector<int64_t>{0, 0});
+      const auto strides = helper.Get("strides", std::vector<int64_t>{1, 1});
+      const auto onnx_pads = helper.Get("pads", std::vector<int64_t>{0, 0, 0, 0});
+
+      coreml_pool->add_kernelsize(kernel_shape[0]);
+      coreml_pool->add_kernelsize(kernel_shape[1]);
+      coreml_pool->add_stride(strides[0]);
+      coreml_pool->add_stride(strides[1]);
+      coreml_pool->set_avgpoolexcludepadding(helper.Get("count_include_pad", 0) == 0);
+      coreml_pool->set_globalpooling(false);
+
+      // Add Padding
+      // Usually using autopadding is more efficient than using explicit padding
+      // Try to see if we can map explicit padding to auto padding
+      std::vector<int64_t> input_shape;
+      ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
+      AutoPadType auto_pad_type;
+      ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, kernel_shape[0], kernel_shape[1],
+                                        onnx_pads, strides, {1, 1} /* dilations */,
+                                        StringToAutoPadType(helper.Get("auto_pad", "NOTSET")),
+                                        auto_pad_type));
+
+      if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) {
+        auto* padding_type = coreml_pool->mutable_same();
+        if (AutoPadType::SAME_LOWER == auto_pad_type) {  // default is SAME_UPPER
+          padding_type->set_asymmetrymode(COREML_SPEC::SamePadding_SamePaddingMode_TOP_LEFT_HEAVY);
+        }
+      } else {
+        auto* padding_type = coreml_pool->mutable_valid();
+        if (AutoPadType::NOTSET == auto_pad_type && onnx_pads != std::vector<int64_t>{0, 0, 0, 0}) {
+          // NOTSET is adding the explicit padding to the ValidPadding.paddingAmounts
+          auto* height_border = padding_type->mutable_paddingamounts()->add_borderamounts();
+          height_border->set_startedgesize(onnx_pads[0]);
+          height_border->set_endedgesize(onnx_pads[2]);
+          auto* width_border = padding_type->mutable_paddingamounts()->add_borderamounts();
+          width_border->set_startedgesize(onnx_pads[1]);
+          width_border->set_endedgesize(onnx_pads[3]);
+        }
       }
     }
-  }
 
-  *layer->mutable_input()->Add() = input_defs[0]->Name();
-  *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();
+    *layer->mutable_input()->Add() = input_defs[0]->Name();
+    *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();
+
+    model_builder.AddLayer(std::move(layer));
+  }
 
-  model_builder.AddLayer(std::move(layer));
   return Status::OK();
 }
-#endif
 
-// Operator support related
-bool PoolOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /* input_params */,
+bool PoolOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                                       const logging::Logger& logger) const {
   const auto& op_type = node.OpType();
   const auto& input_defs = node.InputDefs();
 
   std::vector<int64_t> input_shape;
-  if (!GetShape(*input_defs[0], input_shape, logger))
+  if (!GetShape(*input_defs[0], input_shape, logger)) {
     return false;
+  }
 
+  // TODO: ML Program supports 3D and 5D. Add if we have a use case for that.
   const auto input_size = input_shape.size();
   if (input_size != 4) {
-    LOGS(logger, VERBOSE)
-        << op_type << " only supports rank-4 tensor, input ["
-        << input_defs[0]->Name() << "] has actual dim count " << input_size;
+    LOGS(logger, VERBOSE) << op_type << " only supports rank-4 tensor, input ["
+                          << input_defs[0]->Name() << "] has actual dim count " << input_size;
     return false;
   }
 
   if (op_type == "AveragePool" || op_type == "MaxPool") {
     NodeAttrHelper helper(node);
+
     const auto storage_order = helper.Get("storage_order", 0);
     if (storage_order == 1) {
       LOGS(logger, VERBOSE) << "storage_order == 1 is not supported";
@@ -141,12 +200,14 @@ bool PoolOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara
       return false;
     }
 
-    // TODO, add support of the ceil_mode by adjusting the padding
-    // See https://stackoverflow.com/questions/59906456/in-pytorchs-maxpool2d-is-padding-added-depending-on-ceil-mode
-    // and https://github.com/apple/coremltools/blob/1931758aae383c83daddfc56f11a24a9d2bf4b87/coremltools/converters/mil/frontend/torch/ops.py#L621-L644
-    if (helper.Get("ceil_mode", 0) == 1) {
-      LOGS(logger, VERBOSE) << "ceil_mode == 1 is not supported for pooling";
-      return false;
+    if (!input_params.create_mlprogram) {
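+      // The ML Program pooling ops accept 'ceil_mode' directly (see AddToModelBuilderImpl), so only the
+      // NeuralNetwork path needs to reject it.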
+      // TODO, add support of the ceil_mode by adjusting the padding
+      // See https://stackoverflow.com/questions/59906456/in-pytorchs-maxpool2d-is-padding-added-depending-on-ceil-mode
+      // and https://github.com/apple/coremltools/blob/1931758aae383c83daddfc56f11a24a9d2bf4b87/coremltools/converters/mil/frontend/torch/ops.py#L621-L644
+      if (helper.Get("ceil_mode", 0) == 1) {
+        LOGS(logger, VERBOSE) << "ceil_mode == 1 is not supported for pooling";
+        return false;
+      }
     }
 
     if (helper.Get("dilations", std::vector<int32_t>{1, 1}) !=
diff --git a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc
index 6a2014e7952a2..32378b1f654d8 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc
@@ -1,36 +1,27 @@
 // Copyright (c) Shukant Pal.
 // Licensed under the MIT License.
 
+#include "core/optimizer/initializer.h"
 #include "core/providers/common.h"
-#include "core/providers/shared/utils/utils.h"
-
-#ifdef __APPLE__
-#include "core/providers/coreml/builders/model_builder.h"
-#endif
 #include "core/providers/coreml/builders/helper.h"
+#include "core/providers/coreml/builders/impl/base_op_builder.h"
+#include "core/providers/coreml/builders/model_builder.h"
 #include "core/providers/coreml/builders/op_builder_factory.h"
-#include "core/optimizer/initializer.h"
-
-#include "base_op_builder.h"
+#include "core/providers/shared/utils/utils.h"
 
 namespace onnxruntime {
 namespace coreml {
 
 class ReductionOpBuilder : public BaseOpBuilder {
-#ifdef __APPLE__
- public:
   void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
 
- private:
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override;
-#endif
- private:
+
   bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                          const logging::Logger& logger) const override;
 };
 
-#ifdef __APPLE__
 namespace {
 template <typename T>
 void AddReductionParams(T* params, const std::vector<int64_t>& axes, bool keepdims, bool noop_with_empty_axes) {
@@ -76,7 +67,7 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co
   const bool keepdims = helper.Get("keepdims", 1) != 0;
   const bool noop_with_empty_axes = helper.Get("noop_with_empty_axes", 0) != 0;
 
-  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(model_builder, node);
+  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
 
   if (op_type == "ReduceSum") {
     AddReductionParams(layer->mutable_reducesum(), axes, keepdims, noop_with_empty_axes);
@@ -93,7 +84,6 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co
   model_builder.AddLayer(std::move(layer));
   return Status::OK();
 }
-#endif
 
 bool ReductionOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                                            const logging::Logger& logger) const {
@@ -124,4 +114,4 @@ void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations
 }
 
 }  // namespace coreml
-}  // namespace onnxruntime
\ No newline at end of file
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc
index 67aee73630cdb..27d24d9c21893 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc
@@ -1,90 +1,96 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#include "core/framework/tensorprotoutils.h"
 #include "core/optimizer/initializer.h"
-#include "core/providers/common.h"
 #include "core/providers/coreml/builders/helper.h"
 #include "core/providers/coreml/builders/impl/base_op_builder.h"
+#include "core/providers/coreml/builders/impl/builder_utils.h"
+#include "core/providers/coreml/builders/model_builder.h"
 #include "core/providers/coreml/builders/op_builder_factory.h"
 #include "core/providers/coreml/shape_utils.h"
 #include "core/providers/cpu/tensor/reshape_helper.h"
 #include "core/providers/shared/utils/utils.h"
 
-#ifdef __APPLE__
-#include "core/providers/coreml/builders/model_builder.h"
-#endif
-
 namespace onnxruntime {
 namespace coreml {
 
 class ReshapeOpBuilder : public BaseOpBuilder {
-  // Add operator related
-#ifdef __APPLE__
- public:
   void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
 
- private:
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override;
-#endif
 
-  // Operator support related
- private:
   bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                          const logging::Logger& logger) const override;
 
   // Reshape opset 4- uses attributes for new shape which we do not support for now
   int GetMinSupportedOpSet(const Node& /* node */) const override { return 5; }
-};
 
-// Add operator related
+  bool SupportsMLProgram() const override { return true; }
+};
 
-#ifdef __APPLE__
 void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
+  // Skip the second input, which is the new shape. We always have to create a new version of it because the
+  // CoreML rules differ from the ONNX ones.
   model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name());
 }
 
 Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
                                                const Node& node,
                                                const logging::Logger& logger) const {
-  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(model_builder, node);
-
   const auto& input_defs = node.InputDefs();
-  const auto& initializers(model_builder.GetInitializerTensors());
-  const auto& target_shape_tensor = *initializers.at(input_defs[1]->Name());
-  const int64_t* raw_target_shape = target_shape_tensor.int64_data().empty()
-                                        ? reinterpret_cast<const int64_t*>(target_shape_tensor.raw_data().data())
-                                        : target_shape_tensor.int64_data().data();
-
-  const auto size = target_shape_tensor.dims()[0];
-  TensorShapeVector target_shape{raw_target_shape, raw_target_shape + size};
   std::vector<int64_t> input_shape;
-  ORT_RETURN_IF_NOT(GetStaticShape(*input_defs[0], input_shape, logger), "Cannot get shape");
-  ReshapeHelper helper(TensorShape(input_shape), target_shape);
-  *layer->mutable_reshapestatic()->mutable_targetshape() = {target_shape.cbegin(), target_shape.cend()};
-  *layer->mutable_input()->Add() = input_defs[0]->Name();
-  *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();
+  ORT_RETURN_IF_NOT(GetStaticShape(*input_defs[0], input_shape, logger), "Cannot get shape of data");
+
+  const auto& data_name = input_defs[0]->Name();
+  const auto& new_shape_name = input_defs[1]->Name();
+  Initializer unpacked_tensor(*model_builder.GetConstantInitializer(new_shape_name));
+  TensorShapeVector new_shape = ToShapeVector(unpacked_tensor.DataAsSpan<int64_t>());
+
+  // ReshapeHelper applies the ONNX rules to create the concrete output shape
+  ReshapeHelper helper(TensorShape(input_shape), new_shape);
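+  // new_shape is updated in place with the resolved dims, so it can be passed to CoreML as-is below.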
 
-  model_builder.AddLayer(std::move(layer));
+#if defined(COREML_ENABLE_MLPROGRAM)
+  if (model_builder.CreateMLProgram()) {
+    using namespace CoreML::Specification::MILSpec;
+
+    // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.tensor_transformation.reshape
+    std::unique_ptr<Operation> reshape_op = model_builder.CreateOperation(node, "reshape");
+
+    AddOperationInput(*reshape_op, "x", data_name);
+    AddOperationInput(*reshape_op, "shape",
+                      model_builder.AddConstant(reshape_op->type(), "shape", ToConstSpan(new_shape)));
+
+    AddOperationOutput(*reshape_op, *node.OutputDefs()[0]);
+
+    model_builder.AddOperation(std::move(reshape_op));
+  } else
+#endif  // defined(COREML_ENABLE_MLPROGRAM)
+  {
+    std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
+
+    *layer->mutable_reshapestatic()->mutable_targetshape() = {new_shape.cbegin(), new_shape.cend()};
+    *layer->mutable_input()->Add() = data_name;
+    *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();
+
+    model_builder.AddLayer(std::move(layer));
+  }
   return Status::OK();
 }
-#endif
-
-// Operator support related
 
 bool ReshapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                                          const logging::Logger& logger) const {
   const auto& input_defs = node.InputDefs();
   const auto& new_shape_name = input_defs[1]->Name();
-  const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors();
-  if (!Contains(initializers, new_shape_name)) {
+  const auto* new_shape_tensor = input_params.graph_viewer.GetConstantInitializer(new_shape_name);
+  if (!new_shape_tensor) {
+    // ONNX has different rules around how -1 and 0 values are used/combined, and
+    // we can't check if those can be translated to CoreML if the shape is unknown.
     LOGS(logger, VERBOSE) << "New shape of reshape must be a constant initializer";
     return false;
   }
 
-  const auto& new_shape_tensor = *initializers.at(new_shape_name);
-  Initializer unpacked_tensor(new_shape_tensor);
+  Initializer unpacked_tensor(*new_shape_tensor);
   auto new_shape = unpacked_tensor.DataAsSpan<int64_t>();
   if (new_shape.empty()) {
     LOGS(logger, VERBOSE) << "New shape of reshape cannot be empty";
@@ -100,7 +106,7 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputP
     return false;
   }
 
-  // CoreML reshape doesn't support new shape with more than 5 dimensions
+  // CoreML reshape doesn't support new shape with more than 5 dimensions.
   if (new_shape.size() > 5) {
     LOGS(logger, VERBOSE) << "Reshape does not support new shape with rank greater than 5. Input shape: "
                           << Shape2String(input_shape) << ", new shape: " << Shape2String(new_shape);
@@ -109,7 +115,7 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputP
 
   // CoreML reshape does not support 0 as dimension
   NodeAttrHelper helper(node);
-  const bool allow_zero = helper.Get("allowzero ", 0) == 1;
+  const bool allow_zero = helper.Get("allowzero", 0) == 1;
   if (allow_zero) {
     if (std::find(new_shape.begin(), new_shape.end(), int64_t{0}) != new_shape.end()) {
       LOGS(logger, VERBOSE) << "Reshape does not support new shape with 0 as dimension when allowzero is enabled. "
diff --git a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc
index 5f963dc30dd8f..6c2fcc2ace856 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc
@@ -8,31 +8,21 @@
 #include "core/providers/common.h"
 #include "core/providers/coreml/builders/helper.h"
 #include "core/providers/coreml/builders/impl/base_op_builder.h"
+#include "core/providers/coreml/builders/model_builder.h"
 #include "core/providers/coreml/builders/op_builder_factory.h"
 #include "core/providers/coreml/shape_utils.h"
 #include "core/providers/cpu/tensor/reshape_helper.h"
 #include "core/providers/shared/utils/utils.h"
 
-#ifdef __APPLE__
-#include "core/providers/coreml/builders/model_builder.h"
-#endif
-
 namespace onnxruntime {
 namespace coreml {
 
 class ResizeOpBuilder : public BaseOpBuilder {
-  // Add operator related
-#ifdef __APPLE__
- public:
   void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
 
- private:
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override;
-#endif
 
-  // Operator support related
- private:
   bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                          const logging::Logger& logger) const override;
 
@@ -41,7 +31,7 @@ class ResizeOpBuilder : public BaseOpBuilder {
   int GetMinSupportedOpSet(const Node& /* node */) const override { return 11; }
 };
 
-// Helper functions
+namespace {
 bool GetResizeScales(const InitializedTensorSet& initializers,
                      const Node& node, std::vector<float>& scales,
                      const logging::Logger&) {
@@ -73,10 +63,8 @@ bool GetResizeOutputSizes(const InitializedTensorSet& initializers,
   sizes = std::vector<int64_t>(sizes_data.begin(), sizes_data.end());
   return true;
 }
+}  // namespace
 
-// Add operator related
-
-#ifdef __APPLE__
 void ResizeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
   // We don't really use ROI here, so add it to skipped list if it's an initializer tensor
   model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name());  // ROI
@@ -96,7 +84,7 @@ void ResizeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const N
 Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
                                               const Node& node,
                                               const logging::Logger& logger) const {
-  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(model_builder, node);
+  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
 
   auto* coreml_upsample = layer->mutable_upsample();
   NodeAttrHelper helper(node);
@@ -110,7 +98,7 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   const auto& input_defs = node.InputDefs();
   const auto& initializers(model_builder.GetInitializerTensors());
 
-  if (input_defs.size() == 3) {  // use scales
+  if (input_defs.size() >= 3 && input_defs[2]->Exists()) {  // use scales
     std::vector<float> scales;
     ORT_RETURN_IF_NOT(GetResizeScales(initializers, node, scales, logger), "Error getting resize scales");
     coreml_upsample->add_scalingfactor(static_cast<int64_t>(scales[2]));
@@ -131,9 +119,6 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   model_builder.AddLayer(std::move(layer));
   return Status::OK();
 }
-#endif
-
-// Operator support related
 
 bool ResizeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                                         const logging::Logger& logger) const {
@@ -197,20 +182,24 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPa
       return false;
     }
 
+    bool using_scales = input_defs.size() >= 3 && input_defs[2]->Exists();
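+    // Resize takes either 'scales' (input 2) or 'sizes' (input 3); whichever is provided must be a constant
+    // initializer so the values can be validated here.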
     // scales
-    if (input_defs.size() == 3 && !Contains(initializers, input_defs[2]->Name())) {
-      LOGS(logger, VERBOSE) << "Input scales of Resize must be known";
+    if (using_scales && !input_params.graph_viewer.GetConstantInitializer(input_defs[2]->Name())) {
+      LOGS(logger, VERBOSE) << "scales input of Resize must be a constant initializer";
       return false;
     }
 
     // sizes
-    if (input_defs.size() > 3 && !Contains(initializers, input_defs[3]->Name())) {
-      LOGS(logger, VERBOSE) << "Input sizes of Resize must be known";
+    if (!using_scales &&
+        (input_defs.size() < 4 ||
+         !input_defs[3]->Exists() ||
+         !input_params.graph_viewer.GetConstantInitializer(input_defs[3]->Name()))) {
+      LOGS(logger, VERBOSE) << "sizes input of Resize must be a constant initializer";
       return false;
     }
 
     // We want to check if the scales or sizes are not trying to resize on N/C channels here
-    if (input_defs.size() == 3) {  // we are using scales
+    if (using_scales) {
       std::vector<float> scales;
       if (!GetResizeScales(initializers, node, scales, logger))
         return false;
diff --git a/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc
index fd64153ffd283..a86e3d9538d87 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc
@@ -2,44 +2,30 @@
 // Licensed under the MIT License.
 
 #include "core/providers/coreml/builders/impl/base_op_builder.h"
-
+#include "core/providers/coreml/builders/model_builder.h"
 #include "core/providers/coreml/builders/op_builder_factory.h"
 #include "core/providers/shared/utils/utils.h"  // for NodeAttrHelper
 
-#if defined(__APPLE__)
-#include "core/providers/coreml/builders/model_builder.h"
-#endif
-
 namespace onnxruntime::coreml {
 
 class ShapeOpBuilder : public BaseOpBuilder {
-  // Add operator related
-#ifdef __APPLE__
- private:
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override;
-#endif
 
-  // Operator support related
- private:
   bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                          const logging::Logger& logger) const override;
 };
 
-// Add operator related
-#if defined(__APPLE__)
 Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
-                                             const logging::Logger& logger) const {
-  auto layer = CreateNNLayer(model_builder, node);
+                                             const logging::Logger& /*logger*/) const {
+  auto layer = model_builder.CreateNNLayer(node);
   layer->mutable_getshape();
   *layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
   *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();
   model_builder.AddLayer(std::move(layer));
   return Status::OK();
 }
-#endif  // defined(__APPLE__)
 
-// Operator support related
 bool ShapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/,
                                        const logging::Logger& logger) const {
   NodeAttrHelper node_attr_helper{node};
diff --git a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc
index 2c250b3cc9f5a..39bfbfe5bba1f 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc
@@ -1,39 +1,31 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#include "core/providers/coreml/builders/impl/base_op_builder.h"
-
 #include "core/optimizer/initializer.h"
 #include "core/providers/coreml/builders/helper.h"
+#include "core/providers/coreml/builders/impl/base_op_builder.h"
+#include "core/providers/coreml/builders/model_builder.h"
 #include "core/providers/coreml/builders/op_builder_factory.h"
 #include "core/providers/coreml/shape_utils.h"
 #include "core/providers/cpu/tensor/slice_helper.h"
 #include "core/providers/shared/utils/utils.h"
 
-#if defined(__APPLE__)
-#include "core/providers/coreml/builders/model_builder.h"
-#endif
-
 namespace onnxruntime::coreml {
 
 class SliceOpBuilder : public BaseOpBuilder {
-  // Add operator related
-#ifdef __APPLE__
- private:
   void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
 
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override;
-#endif
 
-  // Operator support related
- private:
   int GetMinSupportedOpSet(const Node& /* node */) const override {
     // Before Slice-10, some inputs were attributes instead. We don't support that for now.
     return 10;
   }
 
-  bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const override;
+  bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params,
+                              const logging::Logger& logger) const override;
+
   bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& builder_params,
                          const logging::Logger& logger) const override;
 };
@@ -62,7 +54,7 @@ Status PrepareSliceComputeMetadataFromConstantInitializers(const Node& slice_nod
       return Status::OK();
     }
 
-    const auto* tensor_proto = graph_viewer.GetConstantInitializer(input_defs[input_idx]->Name(), true);
+    const auto* tensor_proto = graph_viewer.GetConstantInitializer(input_defs[input_idx]->Name());
     ORT_RETURN_IF_NOT(tensor_proto, "Failed to get constant initializer.");
     Initializer unpacked_tensor(*tensor_proto, graph_viewer.ModelPath());
     const auto data_type = unpacked_tensor.data_type();
@@ -107,9 +99,6 @@ bool ValidateSliceComputeMetadataForCoreML(const SliceOp::PrepareForComputeMetad
 }
 }  // namespace
 
-// Add operator related
-#if defined(__APPLE__)
-
 void SliceOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
   const auto& input_defs = node.InputDefs();
 
@@ -132,7 +121,7 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
   ORT_RETURN_IF_ERROR(PrepareSliceComputeMetadataFromConstantInitializers(node, model_builder.GetGraphViewer(),
                                                                           compute_metadata));
 
-  auto layer = CreateNNLayer(model_builder, node);
+  auto layer = model_builder.CreateNNLayer(node);
   *layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
   *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();
   auto* slice_static = layer->mutable_slicestatic();
@@ -163,10 +152,8 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
   return Status::OK();
 }
 
-#endif  // defined(__APPLE__)
-
-// Operator support related
-bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const {
+bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& /*input_params*/,
+                                            const logging::Logger& logger) const {
   int32_t input_type;
   if (!GetType(*node.InputDefs()[0], input_type, logger))
     return false;
diff --git a/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc
index c454a2a779f6e..d6584124c6aba 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc
@@ -1,43 +1,29 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#include "core/providers/coreml/builders/impl/base_op_builder.h"
-
 #include "core/framework/tensorprotoutils.h"
 #include "core/providers/common.h"
-#include "core/providers/coreml/shape_utils.h"
-#include "core/providers/shared/utils/utils.h"
-
-#ifdef __APPLE__
+#include "core/providers/coreml/builders/impl/base_op_builder.h"
 #include "core/providers/coreml/builders/model_builder.h"
-#endif
 #include "core/providers/coreml/builders/op_builder_factory.h"
+#include "core/providers/coreml/shape_utils.h"
+#include "core/providers/shared/utils/utils.h"
 
 namespace onnxruntime {
 namespace coreml {
 
 class SoftmaxOpBuilder : public BaseOpBuilder {
-  // Add operator related
-#ifdef __APPLE__
- private:
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override;
-#endif
 
-  // Operator support related
- private:
   bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                          const logging::Logger& logger) const override;
 };
 
-// Add operator related
-
-#ifdef __APPLE__
-
 Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
                                                const Node& node,
                                                const logging::Logger& logger) const {
-  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(model_builder, node);
+  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
   const auto& input_name = node.InputDefs()[0]->Name();
   const auto& output_name = node.OutputDefs()[0]->Name();
 
@@ -66,17 +52,15 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
     target_shape.push_back(size_to_dimension);
     target_shape.push_back(size_from_dimension);
 
-    const auto reshape1_output_name = model_builder.GetUniqueName(MakeString(node.Name(), "reshape1_output"));
+    const auto reshape1_output_name = model_builder.GetUniqueName(node, "reshape1_output");
     {  // Add reshape layer
-      const auto softmax_reshape1_layer_name =
-          model_builder.GetUniqueName(MakeString(node.Name(), "_Softmax_reshape1"));
-      auto reshape_layer = CreateNNLayer(softmax_reshape1_layer_name);
+      auto reshape_layer = model_builder.CreateNNLayer(node, "_Softmax_reshape1");
       *reshape_layer->mutable_reshapestatic()->mutable_targetshape() = {target_shape.cbegin(), target_shape.cend()};
       *reshape_layer->mutable_input()->Add() = input_name;
       *reshape_layer->mutable_output()->Add() = reshape1_output_name;
       model_builder.AddLayer(std::move(reshape_layer));
     }
-    const auto softmax_output_name = model_builder.GetUniqueName(MakeString(node.Name(), "softmax_output"));
+    const auto softmax_output_name = model_builder.GetUniqueName(node, "softmax_output");
     {
       auto* coreml_softmaxnd = layer->mutable_softmaxnd();
       coreml_softmaxnd->set_axis(-1);
@@ -86,9 +70,7 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
     }
     {
       // Add reshape back layer
-      const auto softmax_reshape2_layer_name =
-          model_builder.GetUniqueName(MakeString(node.Name(), "_Softmax_reshape2"));
-      auto reshape_layer = CreateNNLayer(softmax_reshape2_layer_name);
+      auto reshape_layer = model_builder.CreateNNLayer(node, "_Softmax_reshape2");
       *reshape_layer->mutable_reshapestatic()->mutable_targetshape() = {data_shape.cbegin(), data_shape.cend()};
       *reshape_layer->mutable_input()->Add() = softmax_output_name;
       *reshape_layer->mutable_output()->Add() = output_name;
@@ -99,10 +81,6 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   return Status::OK();
 }
 
-#endif
-
-// Operator support related
-
 bool SoftmaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /* input_params */,
                                          const logging::Logger& logger) const {
   const auto& input_defs = node.InputDefs();
diff --git a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc
index 56c87c883156b..0497357c45c54 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc
@@ -1,35 +1,24 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#include "core/providers/coreml/builders/impl/base_op_builder.h"
-
 #include "core/optimizer/initializer.h"
 #include "core/providers/common.h"
 #include "core/providers/coreml/builders/helper.h"
+#include "core/providers/coreml/builders/impl/base_op_builder.h"
+#include "core/providers/coreml/builders/model_builder.h"
 #include "core/providers/coreml/builders/op_builder_factory.h"
 #include "core/providers/coreml/shape_utils.h"
 #include "core/providers/shared/utils/utils.h"
 
-#if defined(__APPLE__)
-#include "core/providers/coreml/builders/model_builder.h"
-#endif
-
 namespace onnxruntime {
 namespace coreml {
 
 class SplitOpBuilder : public BaseOpBuilder {
-  // Add operator related
-#ifdef __APPLE__
- private:
   void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
 
- private:
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override;
-#endif
 
-  // Operator support related
- private:
   bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                          const logging::Logger& logger) const override;
 
@@ -37,10 +26,6 @@ class SplitOpBuilder : public BaseOpBuilder {
   int GetMinSupportedOpSet(const Node& /* node */) const override { return 13; }
 };
 
-// Add operator related
-
-#ifdef __APPLE__
-
 void SplitOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
   const auto& input_defs = node.InputDefs();
 
@@ -63,7 +48,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   // attribute introduced since opset 18
   uint64_t num_outputs;
 
-  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(model_builder, node);
+  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
   auto* coreml_splitnd = layer->mutable_splitnd();
   coreml_splitnd->set_axis(axis);
 
@@ -82,7 +67,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
     coreml_splitnd->set_numsplits(num_outputs);
   } else {
     // note: for opset 18+ 'num_outputs' is a required attribute
-    num_outputs = narrow<uint64_t>(helper.GetInt("num_outputs").value());
+    num_outputs = narrow<uint64_t>(helper.GetInt64("num_outputs").value());
     // note: checked in IsOpSupportedImpl that ensures the dim value at splitting axis exists
     auto split_dim_size = data_shape[HandleNegativeAxis(axis, data_shape.size())];
     uint64_t chunk_size = narrow<uint64_t>((split_dim_size + num_outputs - 1) / num_outputs);
@@ -111,10 +96,6 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   return Status::OK();
 }
 
-#endif
-
-// Operator support related
-
 bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                                        const logging::Logger& logger) const {
   const auto& input_defs = node.InputDefs();
@@ -159,7 +140,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar
     }
   } else {
     if (node.SinceVersion() >= 18) {
-      const auto num_outputs = helper.GetInt("num_outputs");
+      const auto num_outputs = helper.GetInt64("num_outputs");
       if (!num_outputs.has_value()) {
         LOGS(logger, VERBOSE) << "No 'num_outputs' provided. For split 18+, num_outputs is a required attribute.";
         return false;
@@ -169,9 +150,10 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar
                               << "CoreML SplitND requires at least 2 outputs. num_outputs: " << num_outputs.value();
         return false;
       }
-      if (num_outputs.value() != static_cast<int32_t>(node.OutputDefs().size()) || num_outputs.value() > split_dims_at_axis) {
-        LOGS(logger, VERBOSE) << "Invalid num_outputs provided.\n."
-                              << "The value should be smaller or equal to the size of dimension being split. num_outputs: "
+      if (num_outputs.value() != static_cast<int32_t>(node.OutputDefs().size()) ||
+          num_outputs.value() > split_dims_at_axis) {
+        LOGS(logger, VERBOSE) << "Invalid num_outputs provided.\nThe value should be smaller than or equal to the "
+                                 "size of the dimension being split. num_outputs: "
                               << num_outputs.value();
         return false;
       }
diff --git a/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc
index 2e14c85ce69c1..e9cc1c2dbf638 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc
@@ -1,48 +1,30 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
-#include <core/common/safeint.h>
+
+#include "core/common/safeint.h"
 #include "core/framework/tensorprotoutils.h"
 #include "core/providers/common.h"
-#include "core/providers/shared/utils/utils.h"
-#include "core/optimizer/initializer.h"
-
-#ifdef __APPLE__
+#include "core/providers/coreml/builders/impl/base_op_builder.h"
 #include "core/providers/coreml/builders/model_builder.h"
-#endif
 #include "core/providers/coreml/builders/op_builder_factory.h"
-
-#include "base_op_builder.h"
+#include "core/providers/shared/utils/utils.h"
+#include "core/optimizer/initializer.h"
 
 namespace onnxruntime {
 namespace coreml {
 
 class SqueezeOpBuilder : public BaseOpBuilder {
-  // Add operator related
-#ifdef __APPLE__
- public:
   void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
 
- private:
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override;
-#endif
 
-  // Operator support related
- private:
   bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                          const logging::Logger& logger) const override;
 };
 
-// Add operator related
-
-#ifdef __APPLE__
-void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
-  if (node.SinceVersion() > 12 && node.InputDefs().size() > 1) {
-    model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name());
-  }
-}
-
-/* static */ Status GetAxes(ModelBuilder& model_builder, const Node& node, std::vector<int64_t>& axes) {
+namespace {
+Status GetAxes(ModelBuilder& model_builder, const Node& node, std::vector<int64_t>& axes) {
   // Squeeze opset 13 use input as axes
   if (node.SinceVersion() > 12) {
     // If axes is not provided, return an empty axes as default to squeeze all
@@ -62,11 +44,18 @@ void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const
 
   return Status::OK();
 }
+}  // namespace
+
+void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
+  if (node.SinceVersion() > 12 && node.InputDefs().size() > 1) {
+    model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name());
+  }
+}
 
 Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
                                                const Node& node,
                                                const logging::Logger& /* logger */) const {
-  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(model_builder, node);
+  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
 
   auto* coreml_squeeze = layer->mutable_squeeze();
   std::vector<int64_t> axes;
@@ -84,9 +73,6 @@ Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   model_builder.AddLayer(std::move(layer));
   return Status::OK();
 }
-#endif
-
-// Operator support related
 
 bool SqueezeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                                          const logging::Logger& /*logger*/) const {
diff --git a/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc
index 7d5018a19f74c..f6a61d55a3d63 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc
@@ -3,33 +3,23 @@
 
 #include "core/providers/coreml/builders/helper.h"
 #include "core/providers/coreml/builders/impl/base_op_builder.h"
+#include "core/providers/coreml/builders/model_builder.h"
 #include "core/providers/coreml/builders/op_builder_factory.h"
 #include "core/providers/coreml/shape_utils.h"
 #include "core/providers/shared/utils/utils.h"
 
-#ifdef __APPLE__
-#include "core/providers/coreml/builders/model_builder.h"
-#endif
-
 namespace onnxruntime {
 namespace coreml {
 
 class TransposeOpBuilder : public BaseOpBuilder {
-  // Add operator related
-#ifdef __APPLE__
- private:
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override;
-#endif
 };
 
-// Add operator related
-
-#ifdef __APPLE__
 Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
                                                  const Node& node,
                                                  const logging::Logger& logger) const {
-  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(model_builder, node);
+  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
 
   NodeAttrHelper helper(node);
   std::vector<int64_t> perm = helper.Get("perm", std::vector<int64_t>());
@@ -51,7 +41,6 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   model_builder.AddLayer(std::move(layer));
   return Status::OK();
 }
-#endif
 
 void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
   op_registrations.builders.push_back(std::make_unique<TransposeOpBuilder>());
diff --git a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc
index 660755b43c043..3403378d59114 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc
@@ -3,32 +3,25 @@
 
 #include "core/providers/common.h"
 
-#ifdef __APPLE__
-#include "core/providers/coreml/builders/model_builder.h"
-#endif
 #include "core/providers/coreml/builders/helper.h"
+#include "core/providers/coreml/builders/impl/base_op_builder.h"
+#include "core/providers/coreml/builders/model_builder.h"
 #include "core/providers/coreml/builders/op_builder_factory.h"
 
-#include "base_op_builder.h"
-
 namespace onnxruntime {
 namespace coreml {
 
 class UnaryOpBuilder : public BaseOpBuilder {
- private:
-#ifdef __APPLE__
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override;
-#endif
 };
 
-#ifdef __APPLE__
 Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                              const logging::Logger& /* logger */) const {
   const auto& op_type(node.OpType());
   const auto& input_defs(node.InputDefs());
 
-  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = CreateNNLayer(model_builder, node);
+  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
 
   if (op_type == "Sqrt") {
     layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::SQRT);
@@ -45,9 +38,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
   model_builder.AddLayer(std::move(layer));
   return Status::OK();
 }
-#endif
-
-// Operator support related
 
 void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
   op_registrations.builders.push_back(std::make_unique<UnaryOpBuilder>());
@@ -55,4 +45,4 @@ void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op
 }
 
 }  // namespace coreml
-}  // namespace onnxruntime
\ No newline at end of file
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.cc b/onnxruntime/core/providers/coreml/builders/model_builder.cc
index 9c8b7bce507e4..eb4723a3b9746 100644
--- a/onnxruntime/core/providers/coreml/builders/model_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/model_builder.cc
@@ -2,56 +2,675 @@
 // Licensed under the MIT License.
 
 #include <fstream>
-#include <core/common/safeint.h>
-
-#include "model_builder.h"
-#include "helper.h"
-#include "op_builder_factory.h"
 
+#include "core/common/safeint.h"
+#include "core/framework/tensorprotoutils.h"
+#include "core/platform/env.h"
 #include "core/providers/common.h"
+#include "core/providers/coreml/builders/model_builder.h"
+#include "core/providers/coreml/builders/helper.h"
+#include "core/providers/coreml/builders/op_builder_factory.h"
 #include "core/providers/coreml/builders/impl/builder_utils.h"
+#include "core/providers/coreml/coreml_provider_factory.h"
 #include "core/providers/coreml/model/host_utils.h"
-#include "core/providers/coreml/model/model.h"
 #include "core/providers/coreml/shape_utils.h"
 
+#if defined(COREML_ENABLE_MLPROGRAM)
+// includes from coremltools-src in _deps
+#include "modelpackage/src/ModelPackage.hpp"
+#include "mlmodel/src/MILBlob/Blob/StorageWriter.hpp"
+using MILBlob::Blob::StorageWriter;
+#endif
+
+using namespace CoreML::Specification;
+
 namespace onnxruntime {
 namespace coreml {
 
-ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, uint32_t coreml_flags)
+namespace {
+#if defined(COREML_ENABLE_MLPROGRAM)
+// Should the initializer be written to file or kept as an immediate value
+bool ShouldWriteInitializerToWeightsFile(const ONNX_NAMESPACE::TensorProto& tensor_proto) {
+  // https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/load.py#L51-L57
+
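+  // Matching the coremltools behavior linked above: only FLOAT, FLOAT16, INT8 and UINT8 tensors with at least
+  // 10 elements go to the weights file; everything else is kept as an immediate value in the model.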
+  bool use_weight_file = false;
+
+  switch (tensor_proto.data_type()) {
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
+    case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
+      auto num_elements = TensorShape(utils::GetTensorShapeFromTensorProto(tensor_proto)).Size();
+      use_weight_file = num_elements >= 10;
+      break;
+    }
+    default:
+      break;
+  }
+
+  return use_weight_file;
+}
+
+// copy from the ONNX TensorProto to a CoreML field.
+// T1 is the source type. T2 is the target type. If the types differ, T1 must be smaller than T2.
+// e.g. uint32_t data can be written to RepeatedField<uint64_t>
+template <typename T1, typename T2 = T1>
+void CopyRawDataToRepeatedField(const ONNX_NAMESPACE::TensorProto& tensor_proto,
+                                google::protobuf::RepeatedField<T2>& repeated_field) {
+  const auto& raw_data = tensor_proto.raw_data();
+  const T1* data = reinterpret_cast<const T1*>(raw_data.data());
+  const T1* data_end = data + (raw_data.size() / sizeof(T1));
+  if constexpr (sizeof(T1) == sizeof(T2)) {
+    repeated_field.Add(data, data_end);
+  } else {
+    static_assert(sizeof(T1) < sizeof(T2));
+    // we need to iterate over the data and copy to the repeated field, converting to T2 as we go.
+    repeated_field.Resize(data_end - data, T2(0));
+    for (int i = 0; data != data_end; ++data, ++i) {
+      repeated_field[i] = static_cast<T2>(*data);
+    }
+  }
+}
+
+// copy T data from the TensorProto.int32_data field to TensorValue.bytes
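+// (used for the narrow types that ONNX widens into int32_data but CoreML stores as raw bytes, e.g. fp16/int8/uint8)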
+template <typename T>
+void CopyInt32DataToBytes(const ONNX_NAMESPACE::TensorProto& tensor_proto, MILSpec::TensorValue tensor_value) {
+  const int num_entries = tensor_proto.int32_data_size();
+  std::string& bytes = *tensor_value.mutable_bytes()->mutable_values();
+  bytes.resize(num_entries * sizeof(T));
+  T* out = reinterpret_cast<T*>(bytes.data());
+
+  const int32_t* in = tensor_proto.int32_data().data();
+  for (int i = 0; i < num_entries; ++i) {
+    out[i] = static_cast<T>(in[i]);
+  }
+}
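+// e.g. (illustrative, matches the FLOAT16 handling below) pack the low 16 bits of each int32_data entry into bytes:
+//   CopyInt32DataToBytes<uint16_t>(tensor_proto, tensor_value);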
+
+// copy T data from the TensorProto.uint64_data field to TensorValue.bytes
+template <typename T>
+void CopyUInt64DataToBytes(const ONNX_NAMESPACE::TensorProto& tensor_proto, MILSpec::TensorValue tensor_value) {
+  const int num_entries = tensor_proto.uint64_data_size();
+  std::string& bytes = *tensor_value.mutable_bytes()->mutable_values();
+  bytes.resize(num_entries * sizeof(T));
+  T* out = reinterpret_cast<T*>(bytes.data());
+
+  const uint64_t* in = tensor_proto.uint64_data().data();
+  for (int i = 0; i < num_entries; ++i) {
+    out[i] = static_cast<T>(in[i]);
+  }
+}
+
+// NOTE: This supports all the ONNX data types. Weights in CoreML may not need all these
+void CopyOnnxTensorToCoreMLTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto,
+                                  MILSpec::TensorValue& tensor_value) {
+  bool has_raw_data = tensor_proto.has_raw_data();
+  auto data_type = tensor_proto.data_type();
+
+  // handling based on
+  // ONNX TensorProto field usage
+  // https://github.com/onnx/onnx/blob/b86cc54efce19530fb953e4b21f57e6b3888534c/onnx/onnx.proto#L544-L572
+  // CoreMLTools conversion implementation that maps data types to fields
+  // https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/helper.py#L98
+  // along with some special cased types that are stored in bytes
+  // https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/helper.py#L23
+  //   IMMEDIATE_VALUE_TYPES_IN_BYTES = (types.fp16, types.int8, types.uint8, types.uint32)
+
+  switch (data_type) {
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
+      // from: float_data/raw, to: floats
+      if (has_raw_data) {
+        CopyRawDataToRepeatedField<float>(tensor_proto, *tensor_value.mutable_floats()->mutable_values());
+      } else {
+        tensor_value.mutable_floats()->mutable_values()->CopyFrom(tensor_proto.float_data());
+      }
+      break;
+    }
+    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {
+      // from: double_data/raw, to: doubles
+      if (has_raw_data) {
+        CopyRawDataToRepeatedField<double>(tensor_proto, *tensor_value.mutable_doubles()->mutable_values());
+      } else {
+        tensor_value.mutable_doubles()->mutable_values()->CopyFrom(tensor_proto.double_data());
+      }
+      break;
+    }
+    case ONNX_NAMESPACE::TensorProto_DataType_INT32: {
+      // from: int32_data/raw, to: ints
+      if (has_raw_data) {
+        CopyRawDataToRepeatedField<int32_t>(tensor_proto, *tensor_value.mutable_ints()->mutable_values());
+      } else {
+        tensor_value.mutable_ints()->mutable_values()->CopyFrom(tensor_proto.int32_data());
+      }
+      break;
+    }
+    case ONNX_NAMESPACE::TensorProto_DataType_INT64: {
+      // enable the commented-out handling below if this proves not to be the case
+      ORT_THROW(
+          "INT64 is unexpected as CoreML uses 32-bit int for indices. "
+          "Most likely an initializer that should have been skipped was not.");
+      //// from: int64_data/raw, to: longints
+      // if (has_raw_data) {
+      //   CopyRawDataToRepeatedField<int64_t>(tensor_proto, *tensor_value.mutable_longints()->mutable_values());
+
+      //} else {
+      //  tensor_value.mutable_longints()->mutable_values()->CopyFrom(tensor_proto.int64_data());
+      //}
+      // break;
+    }
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: {
+      // from: int32_data/raw, to: bytes
+      if (has_raw_data) {
+        *tensor_value.mutable_bytes()->mutable_values() = tensor_proto.raw_data();
+      } else {
+        // iterate the int32_data, taking the 16-bits from each entry, and copying to the bytes.
+        // we use uint16_t as only the size of the data type matters
+        CopyInt32DataToBytes<uint16_t>(tensor_proto, tensor_value);
+      }
+      break;
+    }
+    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT8: {
+      // from: int32_data/raw, to: bytes
+      if (has_raw_data) {
+        *tensor_value.mutable_bytes()->mutable_values() = tensor_proto.raw_data();
+      } else {
+        // copy from int32_data to bytes. uint8_t for both as only the size of the data type matters when copying
+        CopyInt32DataToBytes<uint8_t>(tensor_proto, tensor_value);
+      }
+      break;
+    }
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT32: {
+      // from: uint64_data/raw, to: bytes
+      if (has_raw_data) {
+        *tensor_value.mutable_bytes()->mutable_values() = tensor_proto.raw_data();
+      } else {
+        // copy uint32_t values from TensorProto.uint64_data
+        CopyUInt64DataToBytes<uint32_t>(tensor_proto, tensor_value);
+      }
+      break;
+    }
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT64: {
+      // enable the commented-out handling below if this proves not to be the case
+      ORT_THROW(
+          "UINT64 is unexpected as CoreML uses 32-bit int for indices. "
+          "Most likely an initializer that should have been skipped was not.");
+      //// from: uint64_data/raw, to: longints
+      // if (has_raw_data) {
+      //   CopyRawDataToRepeatedField<uint64_t>(tensor_proto, *tensor_value.mutable_longints()->mutable_values());
+      // } else {
+      //   // TODO: Is this safe? Need to check the CopyFrom implementation. As it's a straight copy of bytes this
+      //   // hopefully can do it as one block instead of iterating and potentially doing a static_cast of each
+      //   // individual value.
+      //   tensor_value.mutable_longints()->mutable_values()->CopyFrom(
+      //       reinterpret_cast<const google::protobuf::RepeatedField<int64_t>&>(tensor_proto.uint64_data()));
+      // }
+
+      // break;
+    }
+    case ONNX_NAMESPACE::TensorProto_DataType_BOOL: {
+      // from: int32_data/raw, to: bools
+      if (has_raw_data) {
+        CopyRawDataToRepeatedField<bool>(tensor_proto, *tensor_value.mutable_bools()->mutable_values());
+      } else {
+        const auto& int32s = tensor_proto.int32_data();
+        auto& bools = *tensor_value.mutable_bools()->mutable_values();
+        const int num_entries = int32s.size();
+        bools.Reserve(num_entries);
+        const int32_t* in = int32s.data();
+        for (int i = 0; i < num_entries; ++i) {
+          *bools.AddAlreadyReserved() = *in++;
+        }
+      }
+
+      break;
+    }
+    case ONNX_NAMESPACE::TensorProto_DataType_STRING: {
+      // from: string_data (which is protobuf type bytes), to: strings (protobuf type string)
+      // due to the protobuf type mismatch we need to iterate and copy
+      auto& in = tensor_proto.string_data();
+      auto& out = *tensor_value.mutable_strings()->mutable_values();
+      out.Reserve(in.size());
+      for (const auto& iter : in) {
+        *out.Add() = iter;
+      }
+
+      break;
+    }
+    /* Not clear if there's an actual use-case for 16-bit int data currently, so leaving commented out
+    case ONNX_NAMESPACE::TensorProto_DataType_INT16:
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT16: {
+      // from: int32_data/raw, to: ints
+      // WARNING: This may change to write to mutable_bytes
+      // https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/helper.py#L113-L115
+      if (has_raw_data) {
+          CopyRawDataToRepeatedField<uint16_t, int32_t>(tensor_proto, *tensor_value.mutable_ints()->mutable_values());
+      } else {
+          tensor_value.mutable_ints()->mutable_values()->CopyFrom(tensor_proto.int32_data());
+      }
+      break;
+    } */
+    default:
+      ORT_THROW("AddTensorProtoDataToMILSpecTensorValue: Unsupported data type: ", data_type);
+  }
+}
+
+template <typename T>
+uint64_t WriteRawDataUsingStorageWriter(const onnx::TensorProto& tensor_proto,
+                                        MILBlob::Blob::StorageWriter& writer) {
+  MILBlob::Util::Span<const T> data(reinterpret_cast<const T*>(tensor_proto.raw_data().data()),
+                                    tensor_proto.raw_data().size() / sizeof(T));
+  return writer.WriteData(data);
+}
+
+// Write T1 data from the TensorProto.int32_data field using StorageWriter.
+// Currently int32_data can have any of these data types:
+//   INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16,
+//   FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ
+// T1 provides the size of the ONNX data type. T2 is the CoreML type.
+// The sizes and layout of T1 and T2 must match as we simply cast the bytes to T2.
+template <typename T1, typename T2 = T1>
+uint64_t WriteFromInt32DataUsingStorageWriter(const onnx::TensorProto& tensor_proto,
+                                              MILBlob::Blob::StorageWriter& writer) {
+  static_assert(sizeof(T1) == sizeof(T2), "Data sizes must match");
+
+  // need to copy to temporary data as we have to extract a subset of bytes from each int32_t entry.
+  // works better to extract the ONNX type first with static_cast, and reinterpret_cast to the CoreML type at the end.
+  std::vector<T1> values;
+  const int num_values = tensor_proto.int32_data_size();
+  values.resize(num_values);  // resize so we're not updating the length inside the copy loop
+
+  const int32_t* in = tensor_proto.int32_data().data();
+  for (int i = 0; i < num_values; ++i) {
+    values[i] = static_cast<T1>(in[i]);
+  }
+
+  MILBlob::Util::Span<const T2> data(reinterpret_cast<const T2*>(values.data()),
+                                     num_values);
+  return writer.WriteData(data);
+}
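+// e.g. (illustrative, matches the FLOAT16 handling below) write fp16 values stored in int32_data as MILBlob::Fp16:
+//   WriteFromInt32DataUsingStorageWriter<uint16_t, MILBlob::Fp16>(tensor_proto, writer);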
+
+// write the initializer to weight.bin and return the offset
+// StorageWriter is currently limited to fp32, fp16, bfloat16, uint8/int8, uint16/int16.
+// AFAIK we don't use bfloat16/int16/uint16 for weights in ONNX, so limit handling to fp32, fp16, uint8/int8
+uint64_t CopyOnnxTensorToCoreMLWeightsFile(const onnx::TensorProto& tensor_proto,
+                                           MILBlob::Blob::StorageWriter& writer) {
+  bool has_raw_data = tensor_proto.has_raw_data();
+  auto data_type = tensor_proto.data_type();
+
+  uint64_t offset = 0;
+
+  // See CopyOnnxTensorToCoreMLTensor for links to sources for info on where the different typed data is
+  // stored for ONNX and CoreML
+
+  switch (data_type) {
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
+      // from: float_data/raw, to: floats
+      if (has_raw_data) {
+        offset = WriteRawDataUsingStorageWriter<float>(tensor_proto, writer);
+      } else {
+        MILBlob::Util::Span<const float> data(tensor_proto.float_data().data(), tensor_proto.float_data().size());
+        offset = writer.WriteData(data);
+      }
+      break;
+    }
+    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: {
+      // from: int32_data/raw, to: bytes
+      if (has_raw_data) {
+        offset = WriteRawDataUsingStorageWriter<MILBlob::Fp16>(tensor_proto, writer);
+      } else {
+        offset = WriteFromInt32DataUsingStorageWriter<uint16_t, MILBlob::Fp16>(tensor_proto, writer);
+      }
+
+      break;
+    }
+
+    case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
+      // from: int32_data/raw, to: bytes
+      if (has_raw_data) {
+        offset = WriteRawDataUsingStorageWriter<int8_t>(tensor_proto, writer);
+      } else {
+        offset = WriteFromInt32DataUsingStorageWriter<int8_t>(tensor_proto, writer);
+      }
+      break;
+    }
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT8: {
+      // from: int32_data/raw, to: bytes
+      if (has_raw_data) {
+        offset = WriteRawDataUsingStorageWriter<uint8_t>(tensor_proto, writer);
+
+      } else {
+        offset = WriteFromInt32DataUsingStorageWriter<uint8_t>(tensor_proto, writer);
+      }
+      break;
+    }
+    default:
+      ORT_THROW("AddWeightToFile: Unsupported data type: ", data_type);
+  }
+
+  return offset;
+}
+
+MILSpec::Value OnnxTensorToCoreMLTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto,
+                                        MILBlob::Blob::StorageWriter& weights_file_writer) {
+  MILSpec::Value value;
+
+  // populate ValueType with tensor data type, dims and rank
+  MILSpec::ValueType& value_type = *value.mutable_type();
+  MILSpec::TensorType& tensor_type = *value_type.mutable_tensortype();
+  tensor_type.set_datatype(OnnxDataTypeToMILSpec(tensor_proto.data_type()));
+
+  tensor_type.set_rank(tensor_proto.dims().size());
+  for (const auto& dim : tensor_proto.dims()) {
+    tensor_type.add_dimensions()->mutable_constant()->set_size(dim);
+  }
+
+  // add data to either weights.bin or as an immediate value
+  if (ShouldWriteInitializerToWeightsFile(tensor_proto)) {
+    uint64_t offset = CopyOnnxTensorToCoreMLWeightsFile(tensor_proto, weights_file_writer);
+
+    auto* file_value = value.mutable_blobfilevalue();
+    // Filename copied from
+    // https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/helper.py#L329
+    file_value->set_filename("@model_path/weights/weight.bin");
+    file_value->set_offset(offset);
+  } else {
+    MILSpec::TensorValue& tensor_value = *value.mutable_immediatevalue()->mutable_tensor();
+    CopyOnnxTensorToCoreMLTensor(tensor_proto, tensor_value);
+  }
+
+  return value;
+}
+
+void CreateEmptyFile(const std::string& filename) {
+  std::ofstream file(filename, std::ofstream::out | std::ofstream::binary);
+  ORT_ENFORCE(file.is_open(), "Failed to open file ", filename);
+}
+
+#endif  // defined(COREML_ENABLE_MLPROGRAM)
+
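+// Returns the path the model will be written to. e.g. (paths illustrative) "/tmp/abc123" used as the ML Package
+// directory for an ML Program, vs. "/tmp/abc123.model.mlmodel" for a NeuralNetwork model.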
+std::string GetModelOutputPath(bool create_ml_program) {
+  // path is used to create the ML Package directory for ML Program, and for the model directly otherwise.
+  auto path = util::GetTemporaryFilePath();
+  if (!create_ml_program) {
+    path += ".model.mlmodel";
+  }
+
+  return path;
+}
+}  // namespace
+
+ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger,
+                           int32_t coreml_version, uint32_t coreml_flags,
+                           std::vector<std::string>&& onnx_input_names,
+                           std::vector<std::string>&& onnx_output_names)
     : graph_viewer_(graph_viewer),
       logger_(logger),
-      coreml_flags_(coreml_flags) {
-}
+      coreml_version_(coreml_version),
+      coreml_flags_(coreml_flags),
+      create_ml_program_((coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0),
+      model_output_path_(GetModelOutputPath(create_ml_program_)),
+      onnx_input_names_(std::move(onnx_input_names)),
+      onnx_output_names_(std::move(onnx_output_names)),
+      coreml_model_(std::make_unique<CoreML::Specification::Model>()) {
+  if (create_ml_program_) {
+#if defined(COREML_ENABLE_MLPROGRAM)
+    coreml_model_->set_specificationversion(CoreMLSpecVersion());
+    MILSpec::Program& mlprogram = *coreml_model_->mutable_mlprogram();
+    mlprogram.set_version(1);
+    mlprogram_main_fn_ = &(*mlprogram.mutable_functions())["main"];
 
-Status ModelBuilder::Initialize() {
-  coreml_model_ = std::make_unique<CoreML::Specification::Model>();
-  {  // initialize CoreML model
+    const std::string coreml_opset = "CoreML" + std::to_string(CoreMLVersion());
+    *mlprogram_main_fn_->mutable_opset() = coreml_opset;
+    mlprogram_main_block_ = &(*mlprogram_main_fn_->mutable_block_specializations())[coreml_opset];
+
+    // create the ModelPackage. this creates the output directory.
+    mlpackage_ = std::make_unique<MPL::ModelPackage>(model_output_path_, /* create */ true);
+
+    // ModelPackage::addItem does a copy of the file. Due to this we 'add' an empty file first,
+    // and do the actual writes to the file created in the package.
+    // We can't use ModelPackage::createFile as we have to add a directory for the weights.
+    std::string tmp_dir = model_output_path_ + "/tmp";
+    ORT_THROW_IF_ERROR(Env::Default().CreateFolder(ToPathString(tmp_dir)));
+    CreateEmptyFile(tmp_dir + "/weight.bin");
+
+    std::string weights_id = mlpackage_->addItem(tmp_dir, "weights", "com.microsoft.OnnxRuntime",
+                                                 "CoreML Model Weights");
+    auto weights_info = mlpackage_->findItem(weights_id);
+    weights_file_writer_ = std::make_unique<StorageWriter>(weights_info->path() + "/weight.bin");
+#else
+    // should never happen due to handling in coreml_execution_provider.cc
+    // throw here so all other code in this class can assume create_ml_program_ is only ever true in a build
+    // where ML Program support is enabled.
+    ORT_THROW("ML Program is not enabled in this build");
+#endif
+  } else {
     // We support CoreML Specification Version 4 (Core ML 3)
     coreml_model_->set_specificationversion(4);
     auto* neural_network = coreml_model_->mutable_neuralnetwork();
-    neural_network->set_arrayinputshapemapping(::CoreML::Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING);
+    neural_network->set_arrayinputshapemapping(
+        CoreML::Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING);
   }
 
-  PreprocessInitializers();
-  ORT_RETURN_IF_ERROR(RegisterInitializers());
-  ORT_RETURN_IF_ERROR(RegisterModelInputs());
-  ORT_RETURN_IF_ERROR(AddOperations());
-  ORT_RETURN_IF_ERROR(RegisterModelOutputs());
+  // populate names.
+  const auto& initializers = graph_viewer_.GetAllInitializedTensors();
+  const auto& inputs = graph_viewer_.GetInputs();
+  // rough guess to try and avoid reallocs. most nodes produce one output but some have more, and attributes
+  // also need to be converted to constants, so allow extra capacity for both.
+  unique_names_.reserve(initializers.size() + inputs.size() + size_t(graph_viewer_.NumberOfNodes() * 1.5));
+  for (const auto& pair : initializers) {
+    unique_names_.insert(pair.first);
+  }
 
-  return Status::OK();
+  for (const auto* input : inputs) {
+    unique_names_.insert(input->Name());
+  }
+
+  for (const auto& node : graph_viewer_.Nodes()) {
+    for (const auto& def : node.OutputDefs()) {
+      if (def->Exists()) {
+        unique_names_.insert(def->Name());
+      }
+    }
+  }
 }
 
-/* static */ const IOpBuilder* ModelBuilder::GetOpBuilder(const Node& node) {
-  const auto& op_builders = GetOpBuilders();
-  const auto it = op_builders.find(node.OpType());
-  if (it != op_builders.cend())
-    return it->second;
+ModelBuilder::~ModelBuilder() = default;
 
-  return nullptr;
+/*
+ * NeuralNetwork related helpers
+ */
+std::unique_ptr<NeuralNetworkLayer> ModelBuilder::CreateNNLayer(const Node& node, std::string_view suffix) {
+  auto layer_name = GetUniqueName(node, suffix);
+
+  std::unique_ptr<NeuralNetworkLayer> layer = std::make_unique<NeuralNetworkLayer>();
+  layer->set_name(layer_name);
+  return layer;
+}
+
+void ModelBuilder::AddLayer(std::unique_ptr<NeuralNetworkLayer> layer) {
+  auto* neural_network = coreml_model_->mutable_neuralnetwork();
+  neural_network->mutable_layers()->AddAllocated(layer.release());
 }
 
+/*
+ * ML Program related helpers
+ */
+#if defined(COREML_ENABLE_MLPROGRAM)
+const std::string& ModelBuilder::GetSafeName(const std::string& name) {
+  // Check the name is valid according to the MILSpec rules
+  // `Identifiers, generally used for names and keys, must match the regular expression [A-Za-z\_][A-Za-z0-9\_@]*.`
+  //
+  // There is a secondary list of reserved words that the coremltools python uses, but it's not clear if those are
+  // required here, or if we will ever hit a model that uses one of them. Due to that, skip checking them for now as
+  // it adds cost and code complexity
+  // https://github.com/apple/coremltools/blob/8b37641f243b1a3e81452feea311c6e30dcc9287/coremltools/converters/mil/mil/passes/defs/preprocess.py#L151C1-L175C10
+  // static InlinedHashSet<std::string> reserved_names =
+  //    {"any", "bool", "program", "func", "tensor", "list", "dict", "tuple", "true", "false",
+  //     "string", "bf16", "fp16", "fp32", "fp64", "int8", "int16", "int32", "int64",
+  //     "uint8", "uint16", "uint32", "uint64"};
+
+  // handle empty name. shouldn't happen but code below assumes name is not empty
+  if (name.empty()) {
+    return name;
+  }
+
+  // We don't use '@' even though it's allowed. Optimize for a good name that does not need to be changed.
+
+  // has been sanitized and changed already
+  const auto entry = values_to_rename_.find(name);
+  if (entry != values_to_rename_.end()) {
+    return entry->second;
+  }
+
+  // Replace anything but a good char with '_'. If the first char is 0-9 we prefix with '_'.
+  bool changed = false;
+  std::string result = name;
+
+  if (std::isdigit(result[0])) {
+    changed = true;
+    result = '_' + name;
+  }
+
+  for (char& c : result) {
+    if (!std::isalnum(c) && c != '_') {
+      changed = true;
+      c = '_';
+    }
+  }
+
+  if (!changed) {
+    return name;  // return original as the return value is a reference that must remain valid
+  }
+
+  return (values_to_rename_[name] = GetUniqueName(result));
+}
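+// e.g. (illustrative, assuming no collisions in GetUniqueName) "1x1_conv" -> "_1x1_conv" and "conv/bias" -> "conv_bias",
+// while a name that already matches the MILSpec pattern is returned unchanged.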
+
+void ModelBuilder::SanitizeNames() {
+  // ML Model level inputs/outputs
+  auto* desc = coreml_model_->mutable_description();
+  for (auto& input : *desc->mutable_input()) {
+    input.set_name(GetSafeName(input.name()));
+  }
+
+  for (auto& output : *desc->mutable_output()) {
+    output.set_name(GetSafeName(output.name()));
+  }
+
+  // main function inputs/outputs.
+  for (auto& input : *mlprogram_main_fn_->mutable_inputs()) {
+    input.set_name(GetSafeName(input.name()));
+  }
+
+  // outputs from block with operations for current coreml version
+  for (auto& output : *mlprogram_main_block_->mutable_outputs()) {
+    output = GetSafeName(output);
+  }
+
+  // iterate operations changing input/output/node names
+  for (auto& op : *mlprogram_main_block_->mutable_operations()) {
+    for (auto& input : *op.mutable_inputs()) {
+      for (auto& arg : *input.second.mutable_arguments()) {
+        arg.set_name(GetSafeName(arg.name()));
+      }
+    }
+
+    for (auto& output : *op.mutable_outputs()) {
+      output.set_name(GetSafeName(output.name()));
+    }
+  }
+}
+
+std::unique_ptr<COREML_SPEC::MILSpec::Operation> ModelBuilder::CreateOperation(const Node& node,
+                                                                               std::string_view op_type,
+                                                                               std::string_view suffix) {
+  std::string operation_name = GetUniqueName(node, suffix);
+
+  std::unique_ptr<MILSpec::Operation> op = std::make_unique<MILSpec::Operation>();
+  op->set_type(std::string(op_type));
+  (*op->mutable_attributes())["name"] = CreateScalarTensorValue(operation_name);
+
+  return op;
+}
+
+const std::string& ModelBuilder::AddConstantOperation(std::string_view name, MILSpec::Value&& coreml_tensor) {
+  // Replicates coremltools/converters/mil/backend/mil/load.py translate_const logic
+  MILSpec::Operation& const_op = *mlprogram_main_block_->mutable_operations()->Add();
+  const_op.set_type("const");
+
+  MILSpec::NamedValueType& output = *const_op.mutable_outputs()->Add();
+  output.set_name(std::string(name));
+  *output.mutable_type() = coreml_tensor.type();
+
+  auto& attr_map = *const_op.mutable_attributes();
+  // the operation name doesn't really matter as it isn't used elsewhere, so sanitize name now
+  attr_map["name"] = CreateScalarTensorValue(GetSafeName(output.name()));
+  attr_map["val"] = std::move(coreml_tensor);
+
+  return output.name();
+}
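+// Example (illustrative): AddConstantOperation("conv.weight", std::move(value)) appends a 'const' operation whose
+// single output is named "conv.weight"; the output name itself is made MILSpec-safe later by SanitizeNames.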
+
+// Add operation to the Block for the main function in the ML Program
+void ModelBuilder::AddOperation(std::unique_ptr<COREML_SPEC::MILSpec::Operation> operation) {
+  mlprogram_main_block_->mutable_operations()->AddAllocated(operation.release());
+}
+
+const std::string& ModelBuilder::AddTensorValueAsConstantOperation(std::string_view op_type,
+                                                                   std::string_view value_type,
+                                                                   MILSpec::Value&& input_value) {
+  auto unique_value_name = GetUniqueName(MakeString(op_type, "_", value_type));
+  return AddConstantOperation(unique_value_name, std::move(input_value));
+}
+
+template <typename T>
+std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type,
+                                               gsl::span<const T> value,
+                                               std::optional<gsl::span<const int64_t>> shape) {
+  // add specialization below
+  static_assert(false_for_T<T>, "Missing specialization for value type");
+
+  return "ModelBuilder::AddConstant error";  // unreachable
+}
+
+template <>
+std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type,
+                                               gsl::span<const float> value,
+                                               std::optional<gsl::span<const int64_t>> shape) {
+  auto input_value = CreateTensorValue<float>(value, shape);
+  return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value));
+}
+
+template <>
+std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type,
+                                               gsl::span<const int64_t> value,
+                                               std::optional<gsl::span<const int64_t>> shape) {
+  auto input_value = CreateTensorValue<int64_t, int32_t>(value, shape);  // CoreML uses int32
+  return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value));
+}
+
+template <>
+std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type,
+                                               gsl::span<const bool> value,
+                                               std::optional<gsl::span<const int64_t>> shape) {
+  auto input_value = CreateTensorValue<bool>(value, shape);
+  return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value));
+}
+
+template <>
+std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type,
+                                               gsl::span<const std::string> value,
+                                               std::optional<gsl::span<const int64_t>> shape) {
+  auto input_value = CreateTensorValue<std::string>(value, shape);
+  return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value));
+}
+
+#endif  // defined(COREML_ENABLE_MLPROGRAM)
+
+/*
+ * General implementation
+ */
 void ModelBuilder::PreprocessInitializers() {
-  // TODO: We should be using GetConstantInitializer not GetAllInitializedTensors in all places
+  // TODO: We should be using GetConstantInitializer not GetAllInitializedTensors in all places.
+  // non-constant initializers need to be passed in as model inputs in case they're overridden at runtime.
   const auto& initializers = graph_viewer_.GetAllInitializedTensors();
   const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder();
 
@@ -64,6 +683,7 @@ void ModelBuilder::PreprocessInitializers() {
         initializer_usage_[input->Name()]++;
       }
     }
+
     if (const auto* op_builder = GetOpBuilder(node)) {
       op_builder->AddInitializersToSkip(*this, node);
     }
@@ -77,27 +697,36 @@ Status ModelBuilder::RegisterInitializers() {
 
     // skip initializer if there is no remaining usage
     auto usage_count = initializer_usage_[name];
-    if (usage_count == 0)
+    if (usage_count == 0) {
       continue;
+    }
 
-    std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = std::make_unique<COREML_SPEC::NeuralNetworkLayer>();
-    layer->set_name(GetUniqueName("initializer_" + name));
+#if defined(COREML_ENABLE_MLPROGRAM)
+    if (create_ml_program_) {
+      MILSpec::Value coreml_tensor = OnnxTensorToCoreMLTensor(tensor, *weights_file_writer_);
+      ORT_IGNORE_RETURN_VALUE(AddConstantOperation(name, std::move(coreml_tensor)));
+    } else
+#endif
+    {
+      std::unique_ptr<NeuralNetworkLayer> layer = std::make_unique<NeuralNetworkLayer>();
+      layer->set_name(GetUniqueName("initializer_" + name));
 
-    // TODO,look at using LoadConstantLayer instead of LoadConstantNDLayer
-    auto* constant_tensor = layer->mutable_loadconstantnd();
-    const auto& shape = tensor.dims();
-    if (shape.empty()) {
-      // This is a scalar initializer, CoreML constant layer requires a shape, make this a {1} tensor
-      constant_tensor->mutable_shape()->Add(1);
-    } else {
-      std::transform(shape.cbegin(), shape.cend(),
-                     google::protobuf::RepeatedFieldBackInserter(constant_tensor->mutable_shape()),
-                     [](int64_t dim) -> uint64_t { return SafeInt<uint64_t>(dim); });
-    }
+      // TODO: look at using LoadConstantLayer instead of LoadConstantNDLayer
+      auto* constant_tensor = layer->mutable_loadconstantnd();
+      const auto& shape = tensor.dims();
+      if (shape.empty()) {
+        // This is a scalar initializer, CoreML constant layer requires a shape, make this a {1} tensor
+        constant_tensor->mutable_shape()->Add(1);
+      } else {
+        std::transform(shape.cbegin(), shape.cend(),
+                       google::protobuf::RepeatedFieldBackInserter(constant_tensor->mutable_shape()),
+                       [](int64_t dim) -> uint64_t { return SafeInt<uint64_t>(dim); });
+      }
 
-    ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*constant_tensor->mutable_data(), tensor));
-    *layer->mutable_output()->Add() = name;
-    AddLayer(std::move(layer));
+      ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*constant_tensor->mutable_data(), tensor));
+      *layer->mutable_output()->Add() = name;
+      AddLayer(std::move(layer));
+    }
   }
 
   return Status::OK();
@@ -109,32 +738,33 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i
 
   if (is_input) {
     // input should not be an initializer
-    if (Contains(GetInitializerTensors(), name))
+    if (Contains(GetInitializerTensors(), name)) {
       return Status::OK();
+    }
 
     // This input will not be used
-    if (Contains(skipped_inputs_, name))
+    if (Contains(skipped_inputs_, name)) {
       return Status::OK();
+    }
   }
 
   auto* model_description = coreml_model_->mutable_description();
-  auto& input_output = is_input
-                           ? *model_description->mutable_input()->Add()
-                           : *model_description->mutable_output()->Add();
+  auto& input_output = is_input ? *model_description->mutable_input()->Add()
+                                : *model_description->mutable_output()->Add();
 
   input_output.set_name(name);
+
   auto* multi_array = input_output.mutable_type()->mutable_multiarraytype();
 
   std::vector<int64_t> shape;
-  ORT_RETURN_IF_NOT(GetShape(node_arg, shape, logger_),
-                    "Unable to get shape for ", input_output_type, ": ", name);
+  ORT_RETURN_IF_NOT(GetShape(node_arg, shape, logger_), "Unable to get shape for ", input_output_type, ": ", name);
 
   if (shape.empty()) {
-    // If we have an empty shape, this is a scalar input,
-    // Since all the input output of CoreML EP is MultiArray, we will make the scalar input output as a {1} MultiArray
+    // If we have an empty shape, this is a scalar
+    // Since all the input/output of CoreML EP is MultiArray, we will make the scalar input/output a {1} MultiArray
     shape.push_back(1);
 
-    // we need to change the shapes of these scalar outputs back to {} when CoreML EP returns these values to ORT
+    // we need to change the shapes of scalar outputs back to {} when CoreML EP returns values to ORT
     if (!is_input) {
       AddScalarOutput(name);
     }
@@ -179,15 +809,15 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i
     data_type = type_proto->tensor_type().elem_type();
     switch (data_type) {
       case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
-        multi_array->set_datatype(COREML_SPEC::ArrayFeatureType::FLOAT32);
+        multi_array->set_datatype(ArrayFeatureType::FLOAT32);
         break;
       case ONNX_NAMESPACE::TensorProto_DataType_INT32:
-        multi_array->set_datatype(COREML_SPEC::ArrayFeatureType::INT32);
+        multi_array->set_datatype(ArrayFeatureType::INT32);
         break;
       case ONNX_NAMESPACE::TensorProto_DataType_INT64:
         // If we have an int64 input/output type, since COREML_SPEC::ArrayFeatureType does not support INT64
         // we assign it to be INT32 here
-        multi_array->set_datatype(COREML_SPEC::ArrayFeatureType::INT32);
+        multi_array->set_datatype(ArrayFeatureType::INT32);
         if (!is_input) {
           // Record the output names and we need to change them back to Int64 when CoreML EP returns these values to ORT
           AddInt64Output(name);
@@ -204,6 +834,26 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i
 
   input_output_info_.emplace(name, OnnxTensorInfo{data_type, shape});
 
+#if defined(COREML_ENABLE_MLPROGRAM)
+  if (create_ml_program_) {
+    if (is_input) {
+      // the model inputs need to be wired up as args to the 'main' function.
+      auto tensor_value_type = CreateNamedTensorValueType(node_arg);
+      tensor_value_type.set_name(name);
+      if (node_arg.Shape()->dim_size() == 0) {
+        // update shape from {} to {1} (same change we made at the model input level above).
+        tensor_value_type.mutable_type()->mutable_tensortype()->set_rank(1);
+        tensor_value_type.mutable_type()->mutable_tensortype()->add_dimensions()->mutable_constant()->set_size(1);
+      }
+
+      mlprogram_main_fn_->mutable_inputs()->Add(std::move(tensor_value_type));
+    } else {
+      // the model outputs need to be set as outputs of the Block for the 'main' function
+      *mlprogram_main_block_->mutable_outputs()->Add() = name;
+    }
+  }
+#endif  // defined(COREML_ENABLE_MLPROGRAM)
+
   return Status::OK();
 }
 
@@ -215,16 +865,16 @@ Status ModelBuilder::RegisterModelInputs() {
   return Status::OK();
 }
 
-Status ModelBuilder::AddOperations() {
-  const auto builder_params = MakeOpBuilderParams(graph_viewer_, coreml_flags_);
-  const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder();
-  for (size_t i = 0; i < node_indices.size(); i++) {
-    const auto* node(graph_viewer_.GetNode(node_indices[i]));
-    if (const auto* op_builder = GetOpBuilder(*node)) {
-      ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(*this, *node, builder_params, logger_));
+Status ModelBuilder::ProcessNodes() {
+  for (const auto node_idx : graph_viewer_.GetNodesInTopologicalOrder()) {
+    const auto& node = *graph_viewer_.GetNode(node_idx);
+    if (const auto* op_builder = GetOpBuilder(node)) {
+      ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(*this, node, logger_));
     } else {
+      // This shouldn't happen as this is called from CoreMLExecutionProvider::Compile and should only be processing
+      // nodes that we said were supported and were returned from CoreMLExecutionProvider::GetCapability.
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                             "Node [", node->Name(), "], type [", node->OpType(), "] is not supported");
+                             "Node [", node.Name(), "], type [", node.OpType(), "] was not able to be processed");
     }
   }
 
@@ -239,29 +889,121 @@ Status ModelBuilder::RegisterModelOutputs() {
   return Status::OK();
 }
 
-Status ModelBuilder::Compile(std::unique_ptr<Model>& model, const std::string& path) {
-  ORT_RETURN_IF_ERROR(SaveCoreMLModel(path));
-  model.reset(new Model(path, logger_, coreml_flags_));
-  model->SetScalarOutputs(std::move(scalar_outputs_));
-  model->SetInt64Outputs(std::move(int64_outputs_));
-  model->SetInputOutputInfo(std::move(input_output_info_));
-  return model->LoadModel();
+Status ModelBuilder::CreateModel() {
+  PreprocessInitializers();
+
+  ORT_RETURN_IF_ERROR(RegisterInitializers());
+  ORT_RETURN_IF_ERROR(RegisterModelInputs());
+  ORT_RETURN_IF_ERROR(ProcessNodes());
+  ORT_RETURN_IF_ERROR(RegisterModelOutputs());
+
+#if defined(COREML_ENABLE_MLPROGRAM)
+  if (create_ml_program_) {
+    SanitizeNames();
+  }
+#endif
+
+  return Status::OK();
 }
 
-Status ModelBuilder::SaveCoreMLModel(const std::string& path) {
-  ORT_RETURN_IF_ERROR(Initialize());
-  std::ofstream stream(path, std::ofstream::out | std::ofstream::binary);
-  ORT_RETURN_IF_NOT(coreml_model_->SerializeToOstream(&stream), "Save the CoreML model failed");
+Status ModelBuilder::SaveModel() {
+  std::string output_path = model_output_path_;
+
+#if defined(COREML_ENABLE_MLPROGRAM)
+  if (create_ml_program_) {
+    std::string tmp_model_path = model_output_path_ + "/tmp/model.mlmodel";
+    CreateEmptyFile(tmp_model_path);
+
+    std::string model_id = mlpackage_->setRootModel(tmp_model_path, "model.mlmodel", "com.microsoft.OnnxRuntime",
+                                                    "CoreML Model Specification");
+    auto model_info = mlpackage_->findItem(model_id);
+    output_path = model_info->path();
+  }
+#endif
 
-  // TODO, Delete, debug only
-  if (const char* path = std::getenv("ORT_COREML_EP_CONVERTED_MODEL_PATH")) {
-    std::ofstream temp_stream(path, std::ofstream::out | std::ofstream::binary);
-    ORT_RETURN_IF_NOT(coreml_model_->SerializeToOstream(&temp_stream), "Save the CoreML model failed");
+  // scope this so the stream is closed and flushed by the ofstream dtor
+  {
+    LOGS(logger_, INFO) << "Writing CoreML Model to " << output_path;
+    std::ofstream stream(output_path, std::ofstream::out | std::ofstream::binary);
+    ORT_RETURN_IF_NOT(coreml_model_->SerializeToOstream(&stream), "Saving the CoreML model failed. Path=", output_path);
   }
 
+#if defined(COREML_ENABLE_MLPROGRAM)
+  // need to delete the ModelPackage instance for it to write out the manifest. clear out the other ML Program
+  // related types as well.
+  mlprogram_main_block_ = nullptr;
+  mlpackage_.reset();
+  weights_file_writer_.reset();
+#endif
+
   return Status::OK();
 }
 
+Status ModelBuilder::LoadModel(std::unique_ptr<Model>& model) {
+#if defined(COREML_ENABLE_MLPROGRAM)
+  if (create_ml_program_) {
+    // we need to provide the sanitized names for model inputs/outputs so that info is captured.
+    // the input/output matching when we execute the model from the CoreML EP is based on order, so the change
+    // to the names doesn't matter for that.
+    auto get_sanitized_names = [this](std::vector<std::string>&& names) -> std::vector<std::string> {
+      std::vector<std::string> output(std::move(names));
+
+      for (std::string& name : output) {
+        name = GetSafeName(name);
+      }
+
+      return output;
+    };
+
+    // also need to update the keys in input_output_info_
+    auto get_sanitized_io_info = [this](std::unordered_map<std::string, OnnxTensorInfo>&& info) {
+      std::unordered_map<std::string, OnnxTensorInfo> output;
+      output.reserve(info.size());
+
+      for (auto entry = info.begin(), end = info.end(); entry != end; ++entry) {
+        output.emplace(GetSafeName(entry->first), std::move(entry->second));
+      }
+
+      return output;
+    };
+
+    model = std::make_unique<Model>(model_output_path_,
+                                    get_sanitized_names(std::move(onnx_input_names_)),
+                                    get_sanitized_names(std::move(onnx_output_names_)),
+                                    get_sanitized_io_info(std::move(input_output_info_)),
+                                    std::move(scalar_outputs_),
+                                    std::move(int64_outputs_),
+                                    logger_, coreml_flags_);
+  } else
+#endif
+  {
+    model = std::make_unique<Model>(model_output_path_,
+                                    std::move(onnx_input_names_),
+                                    std::move(onnx_output_names_),
+                                    std::move(input_output_info_),
+                                    std::move(scalar_outputs_),
+                                    std::move(int64_outputs_),
+                                    logger_, coreml_flags_);
+  }
+
+  return model->LoadModel();  // load using CoreML API, including compilation
+}
+
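+// Typical call site (from CoreMLExecutionProvider::Compile; variable names illustrative):
+//   std::unique_ptr<Model> model;
+//   ORT_RETURN_IF_ERROR(ModelBuilder::Build(graph_viewer, logger, coreml_version, coreml_flags,
+//                                           std::move(onnx_input_names), std::move(onnx_output_names), model));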
+// static
+Status ModelBuilder::Build(const GraphViewer& graph_viewer, const logging::Logger& logger,
+                           int32_t coreml_version, uint32_t coreml_flags,
+                           std::vector<std::string>&& onnx_input_names,
+                           std::vector<std::string>&& onnx_output_names,
+                           std::unique_ptr<Model>& model) {
+  ModelBuilder builder(graph_viewer, logger, coreml_version, coreml_flags,
+                       std::move(onnx_input_names), std::move(onnx_output_names));
+
+  ORT_RETURN_IF_ERROR(builder.CreateModel());
+  ORT_RETURN_IF_ERROR(builder.SaveModel());
+
+  return builder.LoadModel(model);
+}
+
 void ModelBuilder::AddScalarOutput(const std::string& output_name) {
   scalar_outputs_.insert(output_name);
 }
@@ -270,11 +1012,6 @@ void ModelBuilder::AddInt64Output(const std::string& output_name) {
   int64_outputs_.insert(output_name);
 }
 
-void ModelBuilder::AddLayer(std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer) {
-  auto* neural_network = coreml_model_->mutable_neuralnetwork();
-  neural_network->mutable_layers()->AddAllocated(layer.release());
-}
-
 void ModelBuilder::AddInitializerToSkip(const std::string& tensor_name) {
   // decrement usage count if this is a known initializer.
   // For simplicity the OpBuilder::AddInitializersToSkip implementations may call this for arbitrary input names
@@ -289,16 +1026,34 @@ void ModelBuilder::AddInputToSkip(const std::string& input_name) {
   skipped_inputs_.insert(input_name);
 }
 
-std::string ModelBuilder::GetUniqueName(const std::string& base_name) {
+const std::string& ModelBuilder::GetUniqueName(const std::string& base_name) {
+  if (unique_names_.find(base_name) == unique_names_.end()) {
+    return *unique_names_.insert(base_name).first;
+  }
+
   std::string unique_name;
-  do {
-    std::ostringstream os;
-    os << base_name << "_token_" << name_token_++;
-    unique_name = os.str();
-  } while (Contains(unique_names_, unique_name));
+  std::string suffix;
+
+  // supports up to 1000 unique names without having to grow in the loop
+  unique_name.reserve(base_name.size() + 5);
+  unique_name = base_name;
+
+  while (Contains(unique_names_, unique_name)) {
+    // assign followed by += to avoid creating temporary strings.
+    unique_name = base_name;
+    unique_name += "__";
+    unique_name += std::to_string(name_token_++);
+  }
 
-  return unique_name;
+  return *unique_names_.insert(unique_name).first;
 }
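+// e.g. (illustrative) if "conv_out" is already in use, GetUniqueName("conv_out") returns something like
+// "conv_out__0"; name_token_ is a single counter shared across all base names.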
 
+const std::string& ModelBuilder::GetUniqueName(const Node& node, std::string_view suffix) {
+  if (node.Name().empty()) {
+    return GetUniqueName(MakeString(node.OpType(), "_", node.Index(), suffix));
+  } else {
+    return GetUniqueName(node.Name() + std::string(suffix));
+  }
+}
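+// e.g. (illustrative) an unnamed Conv node with index 3 and suffix "_weight" produces "Conv_3_weight";
+// a node named "conv1" produces "conv1_weight" (both subject to the uniqueness handling above).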
 }  // namespace coreml
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.h b/onnxruntime/core/providers/coreml/builders/model_builder.h
index af2d5437be8d1..8f85ab2c09e7c 100644
--- a/onnxruntime/core/providers/coreml/builders/model_builder.h
+++ b/onnxruntime/core/providers/coreml/builders/model_builder.h
@@ -3,57 +3,175 @@
 
 #pragma once
 
+#include "core/common/span_utils.h"
 #include "core/graph/graph_viewer.h"
 #include "core/providers/coreml/builders/coreml_spec.h"
+#include "core/providers/coreml/model/model.h"
+
+#if defined(COREML_ENABLE_MLPROGRAM)
+// coremltools classes
+namespace MPL {
+class ModelPackage;
+}
+
+namespace MILBlob {
+namespace Blob {
+class StorageWriter;
+}
+}  // namespace MILBlob
+#endif
 
 namespace onnxruntime {
 namespace coreml {
 
 class IOpBuilder;
-class Model;
-struct OnnxTensorInfo;
 
 class ModelBuilder {
+ private:
+  ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger,
+               int32_t coreml_version, uint32_t coreml_flags,
+               std::vector<std::string>&& onnx_input_names,
+               std::vector<std::string>&& onnx_output_names);
+
  public:
-  ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, uint32_t coreml_flags);
-  ~ModelBuilder() = default;
+  // Create the CoreML model, serialize to disk, load and compile using the CoreML API and return in `model`
+  static Status Build(const GraphViewer& graph_viewer, const logging::Logger& logger,
+                      int32_t coreml_version, uint32_t coreml_flags,
+                      std::vector<std::string>&& onnx_input_names,
+                      std::vector<std::string>&& onnx_output_names,
+                      std::unique_ptr<Model>& model);
 
-  Status Compile(std::unique_ptr<Model>& model, const std::string& path);
-  Status SaveCoreMLModel(const std::string& path);
+  ~ModelBuilder();
 
-  // Accessors for members
   const GraphViewer& GetGraphViewer() const { return graph_viewer_; }
   const InitializedTensorSet& GetInitializerTensors() const { return graph_viewer_.GetAllInitializedTensors(); }
-
+  const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name) const {
+    return graph_viewer_.GetConstantInitializer(name, true);
+  }
+
+  // Since CoreML 2 the spec version is the CoreML version + 1, as CoreML 1.1 was spec version 2.
+  // We only support CoreML 3 and later, so the spec version is always coreml_version_ + 1.
+  int32_t CoreMLVersion() const { return coreml_version_; }
+  int32_t CoreMLSpecVersion() const { return coreml_version_ + 1; }
+
+  // Returns true if we are creating an ML Program
+  bool CreateMLProgram() const {
+#if defined(COREML_ENABLE_MLPROGRAM)
+    return create_ml_program_;
+#else
+    return false;
+#endif
+  }
+
+  /*
+   * NeuralNetworkLayer helpers
+   */
+
+  // Create a NeuralNetwork layer using the node name and optional suffix for the name.
+  // If Node has no name a unique name will be generated from the node index and operator.
+  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> CreateNNLayer(const Node& node, std::string_view suffix = "");
+
+  // Add layer to the Core ML NeuralNetwork model
   void AddLayer(std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer);
 
-  // The initializer will be processed separately, skip it as an initializer
+#if defined(COREML_ENABLE_MLPROGRAM)
+  /*
+   * MLProgram helpers
+   */
+
+  // Create Operation, set type and the unique name attribute.
+  std::unique_ptr<COREML_SPEC::MILSpec::Operation> CreateOperation(const Node& node, std::string_view op_type,
+                                                                   std::string_view suffix = "");
+
+  //
+  // Helpers for adding attributes from ONNX nodes as inputs to an ML Program Operation
+  //
+
+  /// <summary>
+  /// Add a value as a 'const' operation, generating a unique name for the value from op_type and value_type.
+  /// Use for values that were not initializers in the original ONNX model. e.g. attributes from ONNX nodes.
+  /// Add existing initializers using AddConstant with the TensorProto.
+  ///
+  /// e.g. adding the bias input of Gemm would have op_type='gemm' and value_type='bias'.
+  /// </summary>
+  /// <typeparam name="T">Value type.</typeparam>
+  /// <param name="op_type">Typically MILSpec::Operation.type().</param>
+  /// <param name="value_type">Typically the input name of the operation that will consume the value.</param>
+  /// <param name="value">Value to add.</param>
+  /// <param name="shape">Optional shape for the value.
+  /// If T is a primitive type `shape` is ignored and the value is treated as a scalar.
+  /// For a container type, if `shape` is not provided the shape is inferred to be 1-D of {value.size()}.
+  /// </param>
+  /// <returns>Unique name generated for value.</returns>
+  template <typename T>
+  std::string_view AddConstant(std::string_view op_type, std::string_view value_type, gsl::span<const T> value,
+                               std::optional<gsl::span<const int64_t>> shape = std::nullopt) {
+    static_assert(std::is_same_v<T, float> ||
+                      std::is_same_v<T, int64_t> ||
+                      std::is_same_v<T, std::string> ||
+                      std::is_same_v<T, bool>,
+                  // add specialization in AddConstantImpl for new types if needed
+                  "AddConstant currently supports float, int64_t, std::string and bool.");
+    return AddConstantImpl(op_type, value_type, value, shape);
+  }
+
+  template <typename T>
+  std::string_view AddConstant(std::string_view op_type, std::string_view value_type, const std::vector<T>& value,
+                               std::optional<gsl::span<const int64_t>> shape = std::nullopt) {
+    return AddConstant(op_type, value_type, AsSpan(value), shape);
+  }
+
+  /// <summary>
+  /// Add a scalar value as a 'const' operation. See AddConstant for details.
+  /// </summary>
+  template <typename T>
+  std::string_view AddScalarConstant(std::string_view op_type, std::string_view value_type, const T& value) {
+    return AddConstant(op_type, value_type, AsSpan({value}), AsSpan<const int64_t>({}));
+  }
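+
+  // Example usage from an op builder (op/value names illustrative):
+  //   std::vector<float> bias{0.f, 0.f, 0.f};
+  //   std::string_view bias_name = AddConstant("gemm", "bias", bias);
+  //   std::string_view axis_name = AddScalarConstant("concat", "axis", int64_t{0});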
+
+  // add the operation to the main function
+  void AddOperation(std::unique_ptr<COREML_SPEC::MILSpec::Operation> operation);
+#endif
+
+  /*
+   * General helpers
+   */
+
+  // The initializer is processed separately (e.g. layout is transformed) by the operator builder,
+  // so we don't do a copy of the original initializer into the model.
   void AddInitializerToSkip(const std::string& tensor_name);
 
   // Some inputs will not be used; add them to a list of inputs that will not
   // be added to the CoreML model, since CoreML does not allow unused inputs
   void AddInputToSkip(const std::string& input_name);
 
-  std::string GetUniqueName(const std::string& base_name);
-
- private:
-  const GraphViewer& graph_viewer_;
-  const logging::Logger& logger_;
-  uint32_t coreml_flags_;
-
-  std::unique_ptr<CoreML::Specification::Model> coreml_model_;
-  std::unordered_set<std::string> scalar_outputs_;
-  std::unordered_set<std::string> int64_outputs_;
-  std::unordered_map<std::string, OnnxTensorInfo> input_output_info_;
+  const std::string& GetUniqueName(const std::string& base_name);
+  const std::string& GetUniqueName(const Node& node, std::string_view suffix);
 
-  std::unordered_map<std::string, int> initializer_usage_;
-  std::unordered_set<std::string> skipped_inputs_;
+  const logging::Logger& Logger() const { return logger_; }
 
-  uint32_t name_token_{0};
-  std::unordered_set<std::string> unique_names_;
-
-  // Convert the onnx model to CoreML::Specification::Model
-  Status Initialize();
+ private:
+#if defined(COREML_ENABLE_MLPROGRAM)
+  template <typename T>
+  std::string_view AddConstantImpl(std::string_view op_type, std::string_view value_type, gsl::span<const T> value,
+                                   std::optional<gsl::span<const int64_t>> shape = std::nullopt);
+
+  // apply the CoreML naming rules and fix any invalid names.
+  const std::string& GetSafeName(const std::string& name);
+  // sanitize all the names in the ML Model
+  void SanitizeNames();
+
+  // add Value as a const operation. return value name in case sanitization changed it
+  const std::string& AddConstantOperation(std::string_view name, COREML_SPEC::MILSpec::Value&& initializer);
+  const std::string& AddTensorValueAsConstantOperation(std::string_view op_type, std::string_view value_type,
+                                                       COREML_SPEC::MILSpec::Value&& input_value);
+#endif
+
+  // Convert the ONNX model in graph_viewer_ to a CoreML::Specification::Model and serialize to disk.
+  // We then load it using CoreML in order to compile it.
+  Status CreateModel();
+  Status SaveModel();
+  Status LoadModel(std::unique_ptr<Model>& model);
 
   // If a CoreML operation will use initializers directly, we will add the initializers to the skip list
   void PreprocessInitializers();
@@ -61,7 +179,7 @@ class ModelBuilder {
   // Copy and process all the initializers to CoreML model
   Status RegisterInitializers();
 
-  Status AddOperations();
+  Status ProcessNodes();
   Status RegisterModelInputs();
   Status RegisterModelOutputs();
   Status RegisterModelInputOutput(const NodeArg& node_arg, bool is_input);
@@ -72,7 +190,45 @@ class ModelBuilder {
   // Record the onnx int64 type output names
   void AddInt64Output(const std::string& output_name);
 
-  static const IOpBuilder* GetOpBuilder(const Node& node);
+  const GraphViewer& graph_viewer_;
+  const logging::Logger& logger_;
+  const int32_t coreml_version_;
+  const uint32_t coreml_flags_;
+  const bool create_ml_program_;         // ML Program (CoreML5, iOS 15+, macOS 12+) or NeuralNetwork (old)
+  const std::string model_output_path_;  // create_ml_program_ ? dir for mlpackage : filename for mlmodel
+
+  std::vector<std::string> onnx_input_names_;
+  std::vector<std::string> onnx_output_names_;
+
+  std::unique_ptr<CoreML::Specification::Model> coreml_model_;
+  std::unordered_set<std::string> scalar_outputs_;
+  std::unordered_set<std::string> int64_outputs_;
+  std::unordered_map<std::string, OnnxTensorInfo> input_output_info_;
+
+  std::unordered_map<std::string, int> initializer_usage_;
+  std::unordered_set<std::string> skipped_inputs_;
+
+  uint32_t name_token_{0};
+  std::unordered_set<std::string> unique_names_;
+
+#if defined(COREML_ENABLE_MLPROGRAM)
+  // mlprogram_main_fn_ and mlprogram_main_block_ point into the 'main' function of the CoreML ML Program.
+  // They are set in the ModelBuilder constructor to the Model.mlprogram.functions['main'] entry and the
+  // block_specializations['CoreML<ver>'] entry we create within it.
+  COREML_SPEC::MILSpec::Function* mlprogram_main_fn_{nullptr};  // Function that contains a Block with the operations
+  COREML_SPEC::MILSpec::Block* mlprogram_main_block_{nullptr};  // Block that all the operations are added to
+  std::unique_ptr<MPL::ModelPackage> mlpackage_;
+  std::unique_ptr<MILBlob::Blob::StorageWriter> weights_file_writer_;
+
+  // Value names must match [a-zA-Z_][a-zA-Z0-9_@]*.
+  // Additionally they can't be in a list of reserved words.
+  // Renames are generated lazily by GetSafeName and cached here, and are applied to the whole model by
+  // SanitizeNames once all operations have been added.
+  // This means an op builder author doesn't need to be aware of the renaming.
+  // https://github.com/apple/coremltools/blob/8b37641f243b1a3e81452feea311c6e30dcc9287/coremltools/converters/mil/mil/passes/defs/preprocess.py#L146-L149
+  std::unordered_map<std::string, std::string> values_to_rename_;
+#endif
 };
 
 }  // namespace coreml
diff --git a/onnxruntime/core/providers/coreml/builders/op_builder.h b/onnxruntime/core/providers/coreml/builders/op_builder.h
index 79de6438c9700..0bb7f280c33e6 100644
--- a/onnxruntime/core/providers/coreml/builders/op_builder.h
+++ b/onnxruntime/core/providers/coreml/builders/op_builder.h
@@ -11,36 +11,39 @@ namespace coreml {
 class ModelBuilder;
 
 struct OpBuilderInputParams {
-  OpBuilderInputParams(const GraphViewer& graph_viewer, bool only_allow_static_input_shapes)
+  OpBuilderInputParams(const GraphViewer& graph_viewer,
+                       int32_t coreml_version,
+                       bool only_allow_static_input_shapes,
+                       bool create_mlprogram)
       : graph_viewer(graph_viewer),
-        only_allow_static_input_shapes(only_allow_static_input_shapes) {}
+        coreml_version(coreml_version),
+        only_allow_static_input_shapes(only_allow_static_input_shapes),
+        create_mlprogram(create_mlprogram) {}
 
   const GraphViewer& graph_viewer;
+  const int32_t coreml_version;  // required to determine which version of an operation can be used.
   const bool only_allow_static_input_shapes;
+  const bool create_mlprogram;  // whether to create ML Program (Core ML 5+) or NeuralNetwork (Core ML 3+)
 };
 
 class IOpBuilder {
  public:
   virtual ~IOpBuilder() = default;
 
-  // Add operator related
-#ifdef __APPLE__
- public:
   // Check if the initializers of this operator need preprocessing,
   // in which case they will not be copied into the model
   virtual void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const = 0;
 
   // Add the operator to CoreML model
   virtual Status AddToModelBuilder(ModelBuilder& model_builder, const Node& node,
-                                   const OpBuilderInputParams& input_params,
                                    const logging::Logger& logger) const = 0;
-#endif
 
-  // Operator support related
- public:
   // Check if an operator is supported
   virtual bool IsOpSupported(const Node& node, const OpBuilderInputParams& input_params,
                              const logging::Logger& logger) const = 0;
+
+  // Does the builder implementation support creating an ML Program?
+  virtual bool SupportsMLProgram() const = 0;
 };
 
 }  // namespace coreml
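
For orientation, a minimal sketch of how a builder implementation is expected to consume the new OpBuilderInputParams fields and the SupportsMLProgram() query. ExampleOpBuilder and kMinRequiredCoreMLVersion are illustrative names and are not part of this change:

// Sketch of a hypothetical NeuralNetwork-only builder, for illustration only.
class ExampleOpBuilder : public IOpBuilder {
 public:
  void AddInitializersToSkip(ModelBuilder& /*model_builder*/, const Node& /*node*/) const override {
    // nothing to skip for this example op
  }

  Status AddToModelBuilder(ModelBuilder& /*model_builder*/, const Node& /*node*/,
                           const logging::Logger& /*logger*/) const override {
    // add the NeuralNetwork layer (or the ML Program operation for builders that support it)
    return Status::OK();
  }

  bool IsOpSupported(const Node& /*node*/, const OpBuilderInputParams& input_params,
                     const logging::Logger& /*logger*/) const override {
    // reject the node if the caller wants an ML Program but this builder cannot produce one
    if (input_params.create_mlprogram && !SupportsMLProgram()) {
      return false;
    }

    // gate on the runtime Core ML version, e.g. an op that needs Core ML 4 semantics
    constexpr int32_t kMinRequiredCoreMLVersion = 4;
    return input_params.coreml_version >= kMinRequiredCoreMLVersion;
  }

  bool SupportsMLProgram() const override { return false; }
};
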
diff --git a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h
index d72420bcfff88..6469b4cefa5ea 100644
--- a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h
+++ b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h
@@ -3,7 +3,7 @@
 
 #pragma once
 
-#include "op_builder.h"
+#include "core/providers/coreml/builders/op_builder.h"
 
 namespace onnxruntime {
 namespace coreml {
diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc
index c133f7b82aba4..0ba715cc7c6d9 100644
--- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc
+++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc
@@ -2,9 +2,11 @@
 // Licensed under the MIT License.
 
 #include "core/providers/coreml/coreml_execution_provider.h"
+#include "core/providers/coreml/coreml_provider_factory.h"  // defines flags
 
 #include <algorithm>
 
+#include "core/common/logging/logging.h"
 #include "core/framework/compute_capability.h"
 #include "core/framework/tensorprotoutils.h"
 #include "core/graph/graph_viewer.h"
@@ -12,12 +14,10 @@
 #include "core/providers/partitioning_utils.h"
 #include "core/session/onnxruntime_cxx_api.h"
 
-#ifdef __APPLE__
 #include "core/providers/coreml/builders/model_builder.h"
 #include "core/providers/coreml/model/host_utils.h"
 #include "core/providers/coreml/model/model.h"
 #include "core/providers/coreml/shape_utils.h"
-#endif
 
 namespace onnxruntime {
 
@@ -25,7 +25,24 @@ constexpr const char* COREML = "CoreML";
 
 CoreMLExecutionProvider::CoreMLExecutionProvider(uint32_t coreml_flags)
     : IExecutionProvider{onnxruntime::kCoreMLExecutionProvider},
-      coreml_flags_(coreml_flags) {
+      coreml_flags_(coreml_flags),
+      coreml_version_(coreml::util::CoreMLVersion()) {
+  if (coreml_version_ < MINIMUM_COREML_VERSION) {
+    LOGS_DEFAULT(ERROR) << "CoreML EP is not supported on this platform.";
+  }
+
+#if defined(COREML_ENABLE_MLPROGRAM)
+  if (coreml_version_ < MINIMUM_COREML_MLPROGRAM_VERSION &&
+      (coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0) {
+    LOGS_DEFAULT(WARNING) << "ML Program is not supported on this OS version. Falling back to NeuralNetwork.";
+    coreml_flags_ ^= COREML_FLAG_CREATE_MLPROGRAM;
+  }
+#else
+  if ((coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0) {
+    LOGS_DEFAULT(WARNING) << "ML Program is not supported in this build. Falling back to NeuralNetwork.";
+    coreml_flags_ ^= COREML_FLAG_CREATE_MLPROGRAM;
+  }
+#endif
 }
 
 CoreMLExecutionProvider::~CoreMLExecutionProvider() {}
@@ -35,28 +52,34 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
                                        const IKernelLookup& /*kernel_lookup*/) const {
   std::vector<std::unique_ptr<ComputeCapability>> result;
 
-  // We do not run CoreML EP on subgraph, instead we cover this in the control flow nodes
-  // TODO investigate whether we want to support subgraph using CoreML EP
-  if (graph_viewer.IsSubgraph() && !(coreml_flags_ & COREML_FLAG_ENABLE_ON_SUBGRAPH)) {
+  if (coreml_version_ < MINIMUM_COREML_VERSION) {
     return result;
   }
 
   const auto& logger = *GetLogger();
 
+  // We do not run CoreML EP on subgraph, instead we cover this in the control flow nodes
+  // TODO investigate whether we want to support subgraph using CoreML EP. May simply require processing the
+  // implicit inputs of the control flow node that contains the subgraph as inputs to the CoreML model we generate.
+  if (graph_viewer.IsSubgraph() && !(coreml_flags_ & COREML_FLAG_ENABLE_ON_SUBGRAPH)) {
+    return result;
+  }
+
   const bool has_neural_engine = coreml::HasNeuralEngine(logger);
   if ((coreml_flags_ & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) && !has_neural_engine) {
-    LOGS(logger, VERBOSE) << "The current system does not have Apple Neural Engine";
+    LOGS(logger, WARNING) << "The current system does not have Apple Neural Engine. CoreML EP will not be used.";
     return result;
   }
 
-  const auto builder_params = coreml::MakeOpBuilderParams(graph_viewer, coreml_flags_);
+  const auto builder_params = coreml::MakeOpBuilderParams(graph_viewer, coreml_version_, coreml_flags_);
   const auto supported_nodes = coreml::GetSupportedNodes(graph_viewer, builder_params, logger);
 
-  const auto gen_metadef_name = [&]() {
-    HashValue model_hash;
-    int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
-    return MakeString(COREML, "_", model_hash, "_", metadef_id);
-  };
+  const auto gen_metadef_name =
+      [&]() {
+        HashValue model_hash;
+        int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
+        return MakeString(COREML, "_", model_hash, "_", metadef_id);
+      };
 
   result = utils::CreateSupportedPartitions(graph_viewer, supported_nodes, {},
                                             gen_metadef_name, COREML, kCoreMLExecutionProvider);
@@ -86,34 +109,32 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
   return result;
 }
 
-#ifdef __APPLE__
+#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
 common::Status CoreMLExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
                                                 std::vector<NodeComputeInfo>& node_compute_funcs) {
   for (const auto& fused_node_and_graph : fused_nodes_and_graphs) {
     Node& fused_node = fused_node_and_graph.fused_node;
-    const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);
 
-    coreml::ModelBuilder builder(graph_viewer, *GetLogger(), coreml_flags_);
     std::unique_ptr<coreml::Model> coreml_model;
-    const std::string coreml_model_file_path = coreml::util::GetTemporaryFilePath();
-    ORT_RETURN_IF_ERROR(builder.Compile(coreml_model, coreml_model_file_path));
-
     {
-      const auto& input_defs = fused_node.InputDefs();
-      std::vector<std::string> onnx_input_names(input_defs.size());
-      for (size_t i = 0, end = input_defs.size(); i < end; ++i) {
-        onnx_input_names[i] = input_defs[i]->Name();
-      }
-      coreml_model->SetOnnxInputs(std::move(onnx_input_names));
-    }
+      auto get_names = [](const ConstPointerContainer<std::vector<NodeArg*>>& args) -> std::vector<std::string> {
+        std::vector<std::string> names;
+        names.reserve(args.size());
 
-    {
-      const auto& output_defs = fused_node.OutputDefs();
-      std::vector<std::string> onnx_output_names(output_defs.size());
-      for (size_t i = 0, end = output_defs.size(); i < end; ++i) {
-        onnx_output_names[i] = output_defs[i]->Name();
-      }
-      coreml_model->SetOnnxOutputs(std::move(onnx_output_names));
+        for (const NodeArg* def : args) {
+          names.push_back(def->Name());
+        }
+
+        return names;
+      };
+
+      std::vector<std::string> onnx_input_names = get_names(fused_node.InputDefs());
+      std::vector<std::string> onnx_output_names = get_names(fused_node.OutputDefs());
+
+      const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);
+      ORT_RETURN_IF_ERROR(coreml::ModelBuilder::Build(graph_viewer, *GetLogger(), coreml_version_, coreml_flags_,
+                                                      std::move(onnx_input_names), std::move(onnx_output_names),
+                                                      coreml_model));
     }
 
     coreml_models_.emplace(fused_node.Name(), std::move(coreml_model));
@@ -131,13 +152,14 @@ common::Status CoreMLExecutionProvider::Compile(const std::vector<FusedNodeAndGr
 
     compute_info.compute_func = [](FunctionState state, const OrtApi* /* api */, OrtKernelContext* context) {
       Ort::KernelContext ctx(context);
-
       const size_t num_inputs = ctx.GetInputCount();
       const size_t num_outputs = ctx.GetOutputCount();
 
       coreml::Model* model = reinterpret_cast<coreml::Model*>(state);
-      const auto& model_inputs = model->GetOnnxInputs();
-      const auto& model_outputs = model->GetOnnxOutputs();
+
+      // input/output names used by the CoreML model in the order that matches the fused_node InputDefs/OutputDefs
+      const auto& model_inputs = model->GetOrderedInputs();
+      const auto& model_outputs = model->GetOrderedOutputs();
 
       ORT_RETURN_IF_NOT(model_inputs.size() <= num_inputs, "Inconsistent input sizes");
       ORT_RETURN_IF_NOT(model_outputs.size() == num_outputs, "Inconsistent output sizes");
@@ -160,28 +182,25 @@ common::Status CoreMLExecutionProvider::Compile(const std::vector<FusedNodeAndGr
 
         // Disallow inputs with dynamic shape which actually have zero elements.
         // CoreML doesn't consistently handle this well (e.g., there may be runtime errors).
-        {
-          const auto& inferred_shape = input_info->shape;
-          ORT_RETURN_IF(!coreml::IsStaticShape(inferred_shape) && coreml::DoesShapeSpecifyZeroElements(shape),
-                        "Input (", input_name, ") has a dynamic shape (", coreml::Shape2String(inferred_shape),
-                        ") but the runtime shape (", coreml::Shape2String(shape),
-                        ") has zero elements. This is not supported by the CoreML EP.");
-        }
+        const auto& inferred_shape = input_info->shape;
+        ORT_RETURN_IF(!coreml::IsStaticShape(inferred_shape) && coreml::DoesShapeSpecifyZeroElements(shape),
+                      "Input (", input_name, ") has a dynamic shape (", coreml::Shape2String(inferred_shape),
+                      ") but the runtime shape (", coreml::Shape2String(shape),
+                      ") has zero elements. This is not supported by the CoreML EP.");
 
         // If we have an empty shape, this is a scalar input,
         // Since all the input output of CoreML EP is MultiArray, we will make the scalar input as a {1} MultiArray
-        if (shape.empty())
+        if (shape.empty()) {
           shape.push_back(1);
+        }
 
         // CoreML MLMultiArray API expect input to be non-const
         // https://developer.apple.com/documentation/coreml/mlmultiarray/2881219-initwithdatapointer?language=objc
         void* inputBuffer = const_cast<void*>(input_tensor.GetTensorRawData());
-        inputs.emplace(
-            input_name,
-            coreml::OnnxTensorData{
-                coreml::OnnxTensorInfo{tensor_info.GetElementType(), shape},
-                inputBuffer,
-            });
+        inputs.emplace(input_name, coreml::OnnxTensorData{
+                                       coreml::OnnxTensorInfo{tensor_info.GetElementType(), shape},
+                                       inputBuffer,
+                                   });
       }
 
       // From this point we will need to take the exclusive lock on the model until the Predict is
@@ -193,14 +212,13 @@ common::Status CoreMLExecutionProvider::Compile(const std::vector<FusedNodeAndGr
         outputs.reserve(model_outputs.size());
 
         coreml::GetOutputTensorMutableRawDataFn get_output_tensor_mutable_raw_data_fn =
-            [&ctx, &model_outputs](
-                const std::string& name,
-                int32_t requested_onnx_tensor_element_type,
-                gsl::span<const int64_t> static_shape) -> void* {
+            [&ctx, &model_outputs](const std::string& name,
+                                   int32_t requested_onnx_tensor_element_type,
+                                   gsl::span<const int64_t> static_shape) -> void* {
           const auto model_output_it = std::find(model_outputs.begin(), model_outputs.end(), name);
           ORT_ENFORCE(model_output_it != model_outputs.end(), "Failed to find CoreML model output name: ", name);
-          const auto output_idx = gsl::narrow_cast<size_t>(std::distance(model_outputs.begin(), model_output_it));
 
+          const auto output_idx = gsl::narrow_cast<size_t>(std::distance(model_outputs.begin(), model_output_it));
           auto output_tensor = ctx.GetOutput(output_idx, static_shape.data(), static_shape.size());
 
           const auto type_and_shape_info = output_tensor.GetTensorTypeAndShapeInfo();
@@ -221,13 +239,15 @@ common::Status CoreMLExecutionProvider::Compile(const std::vector<FusedNodeAndGr
 
           // Since CoreML EP use {1} MLMultiArray as scalar, if the model output should have empty shape
           // We are going to replace the {1} shape of the output back to {}
-          if (model->IsScalarOutput(output_name))
+          if (model->IsScalarOutput(output_name)) {
             output_shape.clear();
+          }
 
           // Since CoreML EP only accepts int32 output type and onnx requires int64 output,
           // We are going to set the model output (from int32) ->int64
-          if (model->IsInt64Output(output_name))
+          if (model->IsInt64Output(output_name)) {
             output_type = ONNX_NAMESPACE::TensorProto_DataType_INT64;
+          }
 
           outputs.emplace(output_name, coreml::OnnxTensorInfo{output_type, output_shape});
         }
@@ -241,22 +261,6 @@ common::Status CoreMLExecutionProvider::Compile(const std::vector<FusedNodeAndGr
 
   return Status::OK();
 }
-#else
-common::Status CoreMLExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
-                                                std::vector<NodeComputeInfo>& node_compute_funcs) {
-  for (const auto& fused_node_and_graph : fused_nodes_and_graphs) {
-    ORT_UNUSED_PARAMETER(fused_node_and_graph);
-    NodeComputeInfo compute_info;
-    compute_info.create_state_func = [](ComputeContext* /*context*/, FunctionState* /*state*/) { return 0; };
-    compute_info.release_state_func = [](FunctionState /*state*/) {};
-    compute_info.compute_func = [](FunctionState /* state */, const OrtApi* /* api */,
-                                   OrtKernelContext* /* context */) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Compute is not supported in this build.");
-    };
-    node_compute_funcs.push_back(compute_info);
-  }
-  return Status::OK();
-}
-#endif  //__APPLE__
+#endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
 
 }  // namespace onnxruntime
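
A usage sketch of the flag handling above from the application side; the header paths and model path are illustrative. Because of the fallback in the constructor, requesting ML Program is safe even on older OS versions or builds without ML Program support:

#include "onnxruntime_cxx_api.h"
#include "coreml_provider_factory.h"  // COREML_FLAG_* values

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "coreml_example"};
  Ort::SessionOptions session_options;

  // Request the ML Program format. Per the constructor above, the EP falls back to
  // NeuralNetwork when the OS or the build does not support ML Program.
  uint32_t coreml_flags = COREML_FLAG_CREATE_MLPROGRAM;
  Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, coreml_flags));

  Ort::Session session{env, "model.onnx", session_options};  // "model.onnx" is a placeholder path
  return 0;
}
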
diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.h b/onnxruntime/core/providers/coreml/coreml_execution_provider.h
index 0201739547dd1..24a001280eef5 100644
--- a/onnxruntime/core/providers/coreml/coreml_execution_provider.h
+++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.h
@@ -3,9 +3,9 @@
 
 #pragma once
 
+#include "core/common/inlined_containers.h"
 #include "core/framework/execution_provider.h"
 #include "core/framework/model_metadef_id_generator.h"
-#include "core/providers/coreml/coreml_provider_factory.h"
 
 namespace onnxruntime {
 namespace coreml {
@@ -26,15 +26,14 @@ class CoreMLExecutionProvider : public IExecutionProvider {
                          std::vector<NodeComputeInfo>& node_compute_funcs) override;
 #endif
 
+ private:
   // The bit flags which define bool options for COREML EP, bits are defined as
   // COREMLFlags in include/onnxruntime/core/providers/coreml/coreml_provider_factory.h
-  const uint32_t coreml_flags_;
-
- private:
-// <fused_node_name, <coreml_model_file_path, compiled_coreml_model>>
-#ifdef __APPLE__
-  std::unordered_map<std::string, std::unique_ptr<onnxruntime::coreml::Model>> coreml_models_;
-#endif
+  uint32_t coreml_flags_;
+  const int32_t coreml_version_;
   ModelMetadefIdGenerator metadef_id_generator_;
+
+  // map of fused_node_name to compiled_coreml_model
+  InlinedHashMap<std::string, std::unique_ptr<onnxruntime::coreml::Model>> coreml_models_;
 };
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/coreml/dump_mlprogram_model.py b/onnxruntime/core/providers/coreml/dump_mlprogram_model.py
new file mode 100644
index 0000000000000..a3ceee70684dc
--- /dev/null
+++ b/onnxruntime/core/providers/coreml/dump_mlprogram_model.py
@@ -0,0 +1,27 @@
+import sys
+
+import coremltools as ct
+
+if len(sys.argv) < 2:
+    print(f"Usage: {sys.argv[0]} <path to model.mlmodel in ML Package>")
+    print("If generated by onnxruntime this will be <ML Package root>/Data/com.microsoft.onnxruntime/model.mlmodel")
+    sys.exit(-1)
+
+model_path = sys.argv[1]
+m = ct.models.MLModel(model_path)
+
+spec = m.get_spec()
+print(spec)
+
+# Example code if you want to filter output or do more advanced things
+# main = spec.mlProgram.functions["main"]
+# block = main.block_specializations[main.opset]
+# print(f"{len(block.operations)} operators")
+# for op in block.operations:
+#     if op.type == 'const':
+#         if op.attributes["name"].immediateValue.tensor.strings.values[0] == "conv_0_pad_type_0":
+#             print(f"Conv pad_type={op.attributes['val'].immediateValue.tensor.strings.values}")
+#
+#     if op.type == 'conv':
+#         # print(op)
+#         pass
diff --git a/onnxruntime/core/providers/coreml/model/host_utils.h b/onnxruntime/core/providers/coreml/model/host_utils.h
index f7f45bce087bc..a9991ccb945ce 100644
--- a/onnxruntime/core/providers/coreml/model/host_utils.h
+++ b/onnxruntime/core/providers/coreml/model/host_utils.h
@@ -8,10 +8,50 @@
 
 #include <string>
 
-#define API_AVAILABLE_OS_VERSIONS API_AVAILABLE(macos(10.15), ios(13))
+#if defined(__APPLE__)
+// See https://apple.github.io/coremltools/mlmodel/Format/Model.html for the info on each CoreML specification version.
+// See https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html for the list of ops
+// in each CoreML specification version.
 
-// Base requireed OS to run CoreML Specification Version 4 (Core ML 3)
-#define HAS_VALID_BASE_OS_VERSION @available(macOS 10.15, iOS 13, *)
+// Specification Versions : OS Availability(Core ML Version)
+//
+// 4 : iOS 13, macOS 10.15, tvOS 13, watchOS 6 (Core ML 3)
+//     - initial version of CoreML EP
+// 5 : iOS 14, macOS 11, tvOS 14, watchOS 7 (Core ML 4)
+//     - additional layers in NeuralNetwork but currently none are implemented by the CoreML EP
+// 6 : iOS 15, macOS 12, tvOS 15, watchOS 8 (Core ML 5)
+//     - adds MLProgram (MILSpec.Program)
+//     - iOS 15 ops
+// 7 : iOS 16, macOS 13, tvOS 16, watchOS 9 (Core ML 6)
+//     - iOS 16 ops
+// 8 : iOS 17, macOS 14, tvOS 17, watchOS 10 (Core ML 7)
+//     - iOS 17 ops
+//
+// **NOTE** We use the Core ML version, not the spec version.
+//
+// e.g. iOS 13 has Core ML 3 (which is Core ML Specification version 4); the related macros are
+// API_AVAILABLE_COREML3 and HAS_COREML3_OR_LATER, and onnxruntime::coreml::util::CoreMLVersion() will return 3.
+
+// https://developer.apple.com/documentation/swift/marking-api-availability-in-objective-c
+// API_AVAILABLE is used to decorate Objective-C APIs
+#define API_AVAILABLE_COREML3 API_AVAILABLE(macos(10.15), ios(13))
+#define API_AVAILABLE_COREML4 API_AVAILABLE(macos(11), ios(14))
+#define API_AVAILABLE_COREML5 API_AVAILABLE(macos(12), ios(15))
+#define API_AVAILABLE_COREML6 API_AVAILABLE(macos(13), ios(16))
+#define API_AVAILABLE_COREML7 API_AVAILABLE(macos(14), ios(17))
+
+// @available is used in implementation code
+// Base required OS to run CoreML Specification Version 4 (Core ML 3)
+#define HAS_COREML3_OR_LATER @available(macOS 10.15, iOS 13, *)
+#define HAS_COREML4_OR_LATER @available(macOS 11, iOS 14, *)
+#define HAS_COREML5_OR_LATER @available(macOS 12, iOS 15, *)
+#define HAS_COREML6_OR_LATER @available(macOS 13, iOS 16, *)
+#define HAS_COREML7_OR_LATER @available(macOS 14, iOS 17, *)
+
+#endif
+
+#define MINIMUM_COREML_VERSION 3            // first version we support
+#define MINIMUM_COREML_MLPROGRAM_VERSION 5  // first version where ML Program was available
 
 namespace onnxruntime {
 namespace coreml {
@@ -21,9 +61,18 @@ namespace util {
 // This corresponds to [CoreML Specification Version 4 (Core ML 3)]
 bool HasRequiredBaseOS();
 
+// Returns the CoreML version if it is 3 or higher; otherwise returns -1.
+int CoreMLVersion();
+
 // Get a temporary macOS/iOS temp file path
 std::string GetTemporaryFilePath();
 
+#if !defined(NDEBUG) && defined(__APPLE__)
+// Override the location the model is written to so that a) it is easily found and b) it is not automatically
+// deleted when the EP exits. Use this to debug the generated model.
+// See onnxruntime/core/providers/coreml/dump_mlprogram_model.py for a script to dump the ML Program.
+constexpr const char* kOverrideModelOutputDirectoryEnvVar = "ORT_COREML_EP_MODEL_DIR";
+#endif
 }  // namespace util
 }  // namespace coreml
 }  // namespace onnxruntime
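
A minimal sketch, using invented helper names, of how the version helpers and minimum-version macros above are meant to combine. CoreMLVersion() reports the Core ML version, so a value of 5 corresponds to macOS 12 / iOS 15 per the table above:

// Illustration only; CanUseCoreMLEP/CanCreateMLProgram are not part of this change.
#include "core/providers/coreml/model/host_utils.h"

namespace example {

bool CanUseCoreMLEP() {
  return onnxruntime::coreml::util::CoreMLVersion() >= MINIMUM_COREML_VERSION;
}

bool CanCreateMLProgram() {
  return onnxruntime::coreml::util::CoreMLVersion() >= MINIMUM_COREML_MLPROGRAM_VERSION;
}

}  // namespace example
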
diff --git a/onnxruntime/core/providers/coreml/model/host_utils.mm b/onnxruntime/core/providers/coreml/model/host_utils.mm
index 4c394386cd37a..5487ea35388f5 100644
--- a/onnxruntime/core/providers/coreml/model/host_utils.mm
+++ b/onnxruntime/core/providers/coreml/model/host_utils.mm
@@ -1,6 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#include "core/platform/env.h"
 #include "core/providers/coreml/model/host_utils.h"
 
 #import <Foundation/Foundation.h>
@@ -10,19 +11,42 @@
 namespace util {
 
 bool HasRequiredBaseOS() {
-  // This may look strange, but it is required "@available(macOS ....)" to safe-guard some code
-  // otherwise the compiler will spit -Wunsupported-availability-guard
-  if (HAS_VALID_BASE_OS_VERSION)
-    return true;
-  else
-    return false;
+  return CoreMLVersion() >= 3;
+}
+
+int32_t CoreMLVersion() {
+  if (HAS_COREML7_OR_LATER)
+    return 7;
+  if (HAS_COREML6_OR_LATER)
+    return 6;
+  if (HAS_COREML5_OR_LATER)
+    return 5;
+  if (HAS_COREML4_OR_LATER)
+    return 4;
+  if (HAS_COREML3_OR_LATER)
+    return 3;
+
+  return -1;
 }
 
 std::string GetTemporaryFilePath() {
-  // Get temporary directory.
+  // Get temporary directory for user.
   NSURL* temporary_directory_url = [NSURL fileURLWithPath:NSTemporaryDirectory() isDirectory:YES];
+
+#if !defined(NDEBUG)
+  std::string path_override = Env::Default().GetEnvironmentVar(kOverrideModelOutputDirectoryEnvVar);
+  if (!path_override.empty()) {
+    NSString* ns_path_override = [NSString stringWithUTF8String:path_override.c_str()];
+    temporary_directory_url = [NSURL fileURLWithPath:ns_path_override isDirectory:YES];
+  }
+#endif
+
   // Generate a Unique file name to use.
   NSString* temporary_filename = [[NSProcessInfo processInfo] globallyUniqueString];
+
+  // make it easy to see who generated it
+  temporary_filename = [@"onnxruntime-" stringByAppendingString:temporary_filename];
+
   // Create URL to that file.
   NSURL* temporary_file_url = [temporary_directory_url URLByAppendingPathComponent:temporary_filename];
 
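
A debug-build usage sketch for the override above; the directory path and function name are illustrative. Set the environment variable before creating the session, then inspect the retained model with dump_mlprogram_model.py:

// Sketch only. In a debug build, setting ORT_COREML_EP_MODEL_DIR (kOverrideModelOutputDirectoryEnvVar)
// keeps the generated CoreML model in a known directory instead of a temporary one.
#include <cstdlib>

void KeepGeneratedCoreMLModel() {
  // "/tmp/coreml_ep_models" is an arbitrary example directory.
  ::setenv("ORT_COREML_EP_MODEL_DIR", "/tmp/coreml_ep_models", /*overwrite*/ 1);
  // ... create the ORT session with the CoreML EP as usual ...
}
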
diff --git a/onnxruntime/core/providers/coreml/model/host_utils_stub.cc b/onnxruntime/core/providers/coreml/model/host_utils_stub.cc
new file mode 100644
index 0000000000000..5c383b0274e8c
--- /dev/null
+++ b/onnxruntime/core/providers/coreml/model/host_utils_stub.cc
@@ -0,0 +1,40 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <atomic>
+
+#include "core/platform/env.h"
+#include "core/providers/coreml/model/host_utils.h"
+
+namespace onnxruntime {
+namespace coreml {
+namespace util {
+
+bool HasRequiredBaseOS() {
+  return true;
+}
+
+int CoreMLVersion() {
+  return 7;  // CoreML 7 is the latest we support.
+}
+
+std::string GetTemporaryFilePath() {
+  static std::atomic<int> counter = 0;
+
+  // we want to avoid creating an endless number of directories/names while still avoiding clashes if tests run
+  // in parallel, so cycle through 20 potential output names.
+  auto dir_name = "coreml_ep_test_run." + std::to_string(counter++ % 20);
+
+  // To replicate the iOS/macOS host_utils.mm behavior, where the output is <user temporary directory>/<unique_name>,
+  // we want to return the name of something that does not exist. This is required for ML Package creation.
+  auto& env = Env::Default();
+  if (env.FolderExists(dir_name)) {
+    ORT_THROW_IF_ERROR(env.DeleteFolder(ToPathString(dir_name)));
+  }
+
+  return dir_name;
+}
+
+}  // namespace util
+}  // namespace coreml
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/coreml/model/model.h b/onnxruntime/core/providers/coreml/model/model.h
index 105b6a0333b15..e3cd43d786fc3 100644
--- a/onnxruntime/core/providers/coreml/model/model.h
+++ b/onnxruntime/core/providers/coreml/model/model.h
@@ -33,59 +33,62 @@ using GetOutputTensorMutableRawDataFn = std::function<void*(const std::string& n
                                                             gsl::span<const int64_t> static_shape)>;
 
 class Model {
-  friend class ModelBuilder;
-
  public:
+  Model(const std::string& path,
+        std::vector<std::string>&& model_input_names,
+        std::vector<std::string>&& model_output_names,
+        std::unordered_map<std::string, OnnxTensorInfo>&& input_output_info,
+        std::unordered_set<std::string>&& scalar_outputs,
+        std::unordered_set<std::string>&& int64_outputs,
+        const logging::Logger& logger, uint32_t coreml_flags);
+
   ~Model();
   ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Model);
 
+  Status LoadModel();
+
   Status Predict(const std::unordered_map<std::string, OnnxTensorData>& inputs,
                  const std::unordered_map<std::string, OnnxTensorInfo>& outputs,
                  const GetOutputTensorMutableRawDataFn& get_output_tensor_mutable_raw_data_fn);
 
-  bool IsScalarOutput(const std::string& output_name) const;
+  bool IsScalarOutput(const std::string& output_name) const {
+    return Contains(scalar_outputs_, output_name);
+  }
 
-  bool IsInt64Output(const std::string& output_name) const;
+  bool IsInt64Output(const std::string& output_name) const {
+    return Contains(int64_outputs_, output_name);
+  }
 
   // Mutex for exclusive lock to this model object
   OrtMutex& GetMutex() { return mutex_; }
 
-  // Input and output names in the onnx model's order
-  const std::vector<std::string>& GetOnnxInputs() const { return onnx_inputs_; }
-  void SetOnnxInputs(std::vector<std::string>&& inputs) { onnx_inputs_ = std::move(inputs); }
+  // Input and output names in the ORT fused node's order.
+  // Names may have been adjusted from the originals due to CoreML naming rules.
+  // Inputs/outputs are matched by order at the ONNX level, so the adjusted names do not matter.
+  const std::vector<std::string>& GetOrderedInputs() const { return model_input_names_; }
+  const std::vector<std::string>& GetOrderedOutputs() const { return model_output_names_; }
 
-  const std::vector<std::string>& GetOnnxOutputs() const { return onnx_outputs_; }
-  void SetOnnxOutputs(std::vector<std::string>&& outputs) { onnx_outputs_ = std::move(outputs); }
+  const OnnxTensorInfo* TryGetInputOutputInfo(const std::string& name) const {
+    const auto info_it = input_output_info_.find(name);
+    return info_it != input_output_info_.end() ? &info_it->second : nullptr;
+  }
 
-  const OnnxTensorInfo* TryGetInputOutputInfo(const std::string& name) const;
-  const OnnxTensorInfo& GetInputOutputInfo(const std::string& name) const;
+  const OnnxTensorInfo& GetInputOutputInfo(const std::string& name) const {
+    const auto* info = TryGetInputOutputInfo(name);
+    ORT_ENFORCE(info != nullptr, "Failed to get info for input/output: ", name);
+    return *info;
+  }
 
  private:
   std::unique_ptr<Execution> execution_;
-  std::unordered_set<std::string> scalar_outputs_;
-  std::unordered_set<std::string> int64_outputs_;
-
-  std::vector<std::string> onnx_inputs_;
-  std::vector<std::string> onnx_outputs_;
+  std::vector<std::string> model_input_names_;   // input names in the order of the ORT fused node's inputs
+  std::vector<std::string> model_output_names_;  // output names in the order of the ORT fused node's outputs
 
   std::unordered_map<std::string, OnnxTensorInfo> input_output_info_;
+  std::unordered_set<std::string> scalar_outputs_;
+  std::unordered_set<std::string> int64_outputs_;
 
   OrtMutex mutex_;
-
-  Model(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags);
-  Status LoadModel();
-
-  void SetInputOutputInfo(std::unordered_map<std::string, OnnxTensorInfo>&& input_output_info) {
-    input_output_info_ = std::move(input_output_info);
-  }
-
-  void SetScalarOutputs(std::unordered_set<std::string>&& scalar_outputs) {
-    scalar_outputs_ = std::move(scalar_outputs);
-  }
-
-  void SetInt64Outputs(std::unordered_set<std::string>&& int64_outputs) {
-    int64_outputs_ = std::move(int64_outputs);
-  }
 };
 
 }  // namespace coreml
diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm
index 155201ad4c39c..1434043e064f4 100644
--- a/onnxruntime/core/providers/coreml/model/model.mm
+++ b/onnxruntime/core/providers/coreml/model/model.mm
@@ -19,6 +19,7 @@
 #include "core/common/narrow.h"
 #include "core/common/span_utils.h"
 #include "core/graph/onnx_protobuf.h"
+#include "core/platform/env.h"
 #include "core/providers/coreml/builders/helper.h"
 #include "core/providers/coreml/coreml_provider_factory.h"
 #include "core/providers/coreml/model/host_utils.h"
@@ -252,14 +253,14 @@ - (instancetype)initWithPath:(const std::string&)path
                 coreml_flags:(uint32_t)coreml_flags;
 - (void)cleanup;
 - (void)dealloc;
-- (Status)loadModel API_AVAILABLE_OS_VERSIONS;
+- (Status)loadModel API_AVAILABLE_COREML3;
 - (Status)predict:(const std::unordered_map<std::string, OnnxTensorData>&)inputs
                   outputs:(const std::unordered_map<std::string, OnnxTensorInfo>&)outputs
     getOutputTensorDataFn:(const GetOutputTensorMutableRawDataFn&)
                               get_output_tensor_mutable_raw_data_fn
-    API_AVAILABLE_OS_VERSIONS;
+    API_AVAILABLE_COREML3;
 
-@property(nullable) MLModel* model API_AVAILABLE_OS_VERSIONS;
+@property(nullable) MLModel* model API_AVAILABLE_COREML3;
 
 @end
 
@@ -287,6 +288,14 @@ - (void)cleanup {
     compiled_model_path_ = nil;
   }
 
+#if !defined(NDEBUG)
+  std::string path_override = Env::Default().GetEnvironmentVar(util::kOverrideModelOutputDirectoryEnvVar);
+  if (!path_override.empty()) {
+    // don't cleanup
+    coreml_model_path_ = nil;
+  }
+#endif
+
   if (coreml_model_path_ != nil) {
     error = nil;
     [[NSFileManager defaultManager] removeItemAtPath:coreml_model_path_ error:&error];
@@ -308,6 +317,10 @@ - (Status)loadModel {
     return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create model URL from path");
   }
 
+  // TODO: Update this to the version with a completion handler, as the API used here is deprecated.
+  // https://developer.apple.com/documentation/coreml/mlmodel/3929553-compilemodelaturl
+  // As we call loadModel during EP Compile, there shouldn't be an issue letting the actual compile run in the
+  // background. We will have to check for completion in `predict` and block until it is done.
   NSError* error = nil;
   NSURL* compileUrl = [MLModel compileModelAtURL:modelUrl error:&error];
 
@@ -454,7 +467,7 @@ Status Predict(const std::unordered_map<std::string, OnnxTensorData>& inputs,
     return Status::OK();
   }
 
-  if (HAS_VALID_BASE_OS_VERSION) {
+  if (HAS_COREML3_OR_LATER) {
     Status status{};
     @autoreleasepool {
       status = [execution_ loadModel];
@@ -471,7 +484,7 @@ Status Predict(const std::unordered_map<std::string, OnnxTensorData>& inputs,
                           const GetOutputTensorMutableRawDataFn& get_output_tensor_mutable_raw_data_fn) {
   ORT_RETURN_IF_NOT(model_loaded, "Execution::Predict requires Execution::LoadModel");
 
-  if (HAS_VALID_BASE_OS_VERSION) {
+  if (HAS_COREML3_OR_LATER) {
     @autoreleasepool {
       return [execution_ predict:inputs
                          outputs:outputs
@@ -482,8 +495,20 @@ Status Predict(const std::unordered_map<std::string, OnnxTensorData>& inputs,
   return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Execution::Predict requires macos 10.15+ or ios 13+");
 }
 
-Model::Model(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags)
-    : execution_(std::make_unique<Execution>(path, logger, coreml_flags)) {
+Model::Model(const std::string& path,
+             std::vector<std::string>&& model_input_names,
+             std::vector<std::string>&& model_output_names,
+             std::unordered_map<std::string, OnnxTensorInfo>&& input_output_info,
+             std::unordered_set<std::string>&& scalar_outputs,
+             std::unordered_set<std::string>&& int64_outputs,
+             const logging::Logger& logger,
+             uint32_t coreml_flags)
+    : execution_(std::make_unique<Execution>(path, logger, coreml_flags)),
+      model_input_names_(std::move(model_input_names)),
+      model_output_names_(std::move(model_output_names)),
+      input_output_info_(std::move(input_output_info)),
+      scalar_outputs_(std::move(scalar_outputs)),
+      int64_outputs_(std::move(int64_outputs)) {
 }
 
 Model::~Model() {}
@@ -497,25 +522,5 @@ Status Predict(const std::unordered_map<std::string, OnnxTensorData>& inputs,
                       const GetOutputTensorMutableRawDataFn& get_output_tensor_mutable_raw_data_fn) {
   return execution_->Predict(inputs, outputs, get_output_tensor_mutable_raw_data_fn);
 }
-
-bool Model::IsScalarOutput(const std::string& output_name) const {
-  return Contains(scalar_outputs_, output_name);
-}
-
-bool Model::IsInt64Output(const std::string& output_name) const {
-  return Contains(int64_outputs_, output_name);
-}
-
-const OnnxTensorInfo* Model::TryGetInputOutputInfo(const std::string& name) const {
-  const auto info_it = input_output_info_.find(name);
-  return info_it != input_output_info_.end() ? &info_it->second : nullptr;
-}
-
-const OnnxTensorInfo& Model::GetInputOutputInfo(const std::string& name) const {
-  const auto* info = TryGetInputOutputInfo(name);
-  ORT_ENFORCE(info != nullptr, "Failed to get info for input/output: ", name);
-  return *info;
-}
-
 }  // namespace coreml
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/coreml/model/model_stub.cc b/onnxruntime/core/providers/coreml/model/model_stub.cc
new file mode 100644
index 0000000000000..c6f2e7401ea1e
--- /dev/null
+++ b/onnxruntime/core/providers/coreml/model/model_stub.cc
@@ -0,0 +1,42 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/coreml/model/model.h"
+
+namespace onnxruntime {
+namespace coreml {
+
+class Execution {};
+
+Model::Model(const std::string& /*path*/,
+             std::vector<std::string>&& model_input_names,
+             std::vector<std::string>&& model_output_names,
+             std::unordered_map<std::string, OnnxTensorInfo>&& input_output_info,
+             std::unordered_set<std::string>&& scalar_outputs,
+             std::unordered_set<std::string>&& int64_outputs,
+             const logging::Logger& /*logger*/,
+             uint32_t /*coreml_flags*/)
+    : execution_(std::make_unique<Execution>()),
+      model_input_names_(std::move(model_input_names)),
+      model_output_names_(std::move(model_output_names)),
+      input_output_info_(std::move(input_output_info)),
+      scalar_outputs_(std::move(scalar_outputs)),
+      int64_outputs_(std::move(int64_outputs)) {
+}
+
+Model::~Model() {
+}
+
+Status Model::LoadModel() {
+  // Return OK so that more of the CoreML EP code paths can be exercised (e.g. in tests on non-Apple platforms).
+  return Status::OK();
+}
+
+Status Model::Predict(const std::unordered_map<std::string, OnnxTensorData>& /*inputs*/,
+                      const std::unordered_map<std::string, OnnxTensorInfo>& /*outputs*/,
+                      const GetOutputTensorMutableRawDataFn& /*get_output_tensor_mutable_raw_data_fn*/) {
+  return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Executing a CoreML model is not supported on this platform.");
+}
+
+}  // namespace coreml
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
index 813fdc54ecd0d..c3d5a51b636ef 100644
--- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
+++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -143,9 +143,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Aco
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Atan);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, float, Gemm);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, double, Gemm);
-#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Gemm);
-#endif
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Hardmax);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, LogSoftmax);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, LogSoftmax);
@@ -335,9 +332,6 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Flatten);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float, Gemm);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, double, Gemm);
-#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, MLFloat16, Gemm);
-#endif
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, MatMul);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, MatMul);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int32_t, MatMul);
@@ -497,9 +491,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Sp
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, ScatterND);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, Gemm);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, Gemm);
-#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, Gemm);
-#endif
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint8_t, BitShift);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint32_t, BitShift);
@@ -606,9 +597,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, string, Expand);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Gemm);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Gemm);
-#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, Gemm);
-#endif
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, MatMul);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, MatMul);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, MatMul);
@@ -726,6 +714,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, float, IsNaN);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, double, IsNaN);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, MLFloat16, IsNaN);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, BFloat16, IsNaN);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool, NonZero);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, NonZero);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, NonZero);
@@ -1035,6 +1024,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, BFloat16, IsNaN);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Gelu);
 #if !defined(DISABLE_FLOAT8_TYPES)
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FNUZ, IsNaN);
@@ -2007,8 +1998,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
                                                                 Greater)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Less)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Less)>,
-    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, Less)>,
-    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, Less)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13,
+                                                                int32_t, Less)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13,
+                                                                int64_t, Less)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13,
                                                                           float, Add)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13,
@@ -2562,6 +2555,9 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16,
                                                                 IsNaN)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, BFloat16,
+                                                                IsNaN)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Gelu)>,
 #if !defined(DISABLE_FLOAT8_TYPES)
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN,
                                                                 IsNaN)>,
@@ -2613,15 +2609,6 @@ Status RegisterFp16Kernels(KernelRegistry& kernel_registry) {
                                                                             MLFloat16, LeakyRelu)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, MLFloat16,
                                                                   LeakyRelu)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8,
-                                                                            MLFloat16, Gemm)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10,
-                                                                            MLFloat16, Gemm)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12,
-                                                                            MLFloat16, Gemm)>,
-
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16,
-                                                                  Gemm)>,
   };
 
   for (auto& function_table_entry : function_table) {
diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc
index bf73c59fb78ca..c4a83efa01a91 100644
--- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc
+++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc
@@ -25,6 +25,7 @@
 #include "core/providers/cpu/tensor/tile.h"
 #include "core/providers/cpu/tensor/gather_elements.h"
 #include "core/providers/cpu/tensor/unsqueeze.h"
+#include "core/providers/cpu/tensor/upsamplebase.h"
 
 #ifndef DISABLE_CONTRIB_OPS
 #include "contrib_ops/cpu/bert/attention_base.h"
@@ -62,6 +63,7 @@
 #endif
 
 #include "cpu_provider_shared.h"
+#include <limits>
 
 namespace onnxruntime {
 // The suppressed warning is: "The type with a virtual function needs either public virtual or protected nonvirtual destructor."
@@ -292,6 +294,12 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
   Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) override { return p->contrib::transformers::Sampling::Compute(ctx); }
   Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) override { return p->contrib::transformers::Sampling::SetupSubgraphExecutionInfo(session_state, attribute_name, subgraph_session_state); }
 
+  void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims,
+                                              gsl::span<const int64_t> input_dims,
+                                              InlinedVector<float>& scales) const override {
+    p->AdjustOutputSizeAsPolicy(output_dims, input_dims, scales);
+  }
+
 #ifdef ENABLE_ATEN
   Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) override { return p->ATen::Compute(p_ctx); }
 #endif
diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h
index f33eec4b93e98..c0e674827e4d1 100644
--- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h
+++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h
@@ -24,6 +24,7 @@ class SliceOp__PrepareForComputeMetadata;  // Directly maps to SliceOp::PrepareF
 class UnsqueezeBase__Prepare;              // Directly maps to UnsqueezeBase::Prepare
 class contrib__AdamWOptimizerBase__Prepare;
 class contrib__SGDOptimizerV2Base__Prepare;
+class UpsampleBase;
 
 using PadsVector = InlinedVector<int64_t, kTensorShapeSmallBufferElementsSize * 2>;
 
@@ -202,6 +203,10 @@ struct ProviderHostCPU {
   virtual Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) = 0;
   virtual Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0;
 
+  virtual void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims,
+                                                      gsl::span<const int64_t> input_dims,
+                                                      InlinedVector<float>& scales) const = 0;
+
 #ifdef ENABLE_ATEN
   virtual Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) = 0;
 #endif
diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.cc b/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.cc
index 180b3153fbb34..e2981da3a6f25 100644
--- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.cc
+++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.cc
@@ -1,6 +1,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#if !defined(ORT_MINIMAL_BUILD)
+
 #include "core/providers/cpu/ml/tree_ensemble_helper.h"
 #include "core/common/common.h"
 #include "onnx/defs/tensor_proto_util.h"
@@ -64,3 +66,5 @@ Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name
 
 }  // namespace ml
 }  // namespace onnxruntime
+
+#endif  // !defined(ORT_MINIMAL_BUILD)
diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.h
index 3c8a5a840bc76..33172c343a88e 100644
--- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.h
+++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.h
@@ -2,6 +2,9 @@
 // Licensed under the MIT License.
 
 #pragma once
+
+#if !defined(ORT_MINIMAL_BUILD)
+
 #include "core/common/common.h"
 #include "core/framework/op_kernel.h"
 
@@ -13,3 +16,5 @@ Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name
 
 }  // namespace ml
 }  // namespace onnxruntime
+
+#endif  // !defined(ORT_MINIMAL_BUILD)
diff --git a/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h b/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h
index a5d46aff83b50..ccecbabfa3db3 100644
--- a/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h
+++ b/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h
@@ -25,6 +25,8 @@ class BatchNormHelper {
                                        const Tensor* var,
                                        bool is_spatial = true,
                                        bool is_nhwc = false) {
+    // Only the shape of X depends on NHWC vs NCHW.
+    // All other input shapes are assumed to follow the NCHW layout.
     const auto& x_dims = X->Shape().GetDims();
 
     // If x_dims size < 2, num_channels defaults to 1.
@@ -48,16 +50,22 @@ class BatchNormHelper {
     // validate 'scales' shape
     const auto& scale_dims = scale->Shape().GetDims();
     if (static_cast<int>(scale_dims.size()) != kNumInputScaleDimensions) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: NumDimensions() != ", kNumInputScaleDimensions);
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Invalid input scale: NumDimensions() != ", kNumInputScaleDimensions);
     }
     if (scale_dims[0] != num_channels) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: 0th dimension != ", num_channels);
     }
+    // N and C are not feature dimensions, so skip the leading N for NHWC (C is the last dimension)
+    // and the leading N and C for NCHW.
+    int feature_offset = is_nhwc ? 1 : 2;
+
     // in non-spatial cases - the other dims of 'scale' must be validated
     if (!is_spatial) {
       for (int feature = 0; feature < num_feature_dims; ++feature) {
-        if (scale_dims[1 + feature] != x_dims[2 + feature]) {
-          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: ", (1 + feature), " dimension != ", x_dims[2 + feature]);
+        if (scale_dims[1 + feature] != x_dims[feature_offset + feature]) {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: ", (1 + feature),
+                                 " dimension != ", x_dims[feature_offset + feature]);
         }
       }
     }
@@ -65,7 +73,8 @@ class BatchNormHelper {
     // validate 'B' shape
     const auto& B_dims = B->Shape().GetDims();
     if (static_cast<int>(B_dims.size()) != kNumInputBiasDimensions) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: NumDimensions() != ", kNumInputBiasDimensions);
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Invalid input B: NumDimensions() != ", kNumInputBiasDimensions);
     }
     if (B_dims[0] != num_channels) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: 0th dimension != ", num_channels);
@@ -73,8 +82,9 @@ class BatchNormHelper {
     // in non-spatial cases - the other dims of 'B' must be validated
     if (!is_spatial) {
       for (int feature = 0; feature < num_feature_dims; ++feature) {
-        if (B_dims[1 + feature] != x_dims[2 + feature]) {
-          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: ", (1 + feature), " dimension != ", x_dims[2 + feature]);
+        if (B_dims[1 + feature] != x_dims[feature_offset + feature]) {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: ", (1 + feature),
+                                 " dimension != ", x_dims[feature_offset + feature]);
         }
       }
     }
@@ -82,16 +92,19 @@ class BatchNormHelper {
     // validate 'mean' shape
     const auto& mean_dims = mean->Shape().GetDims();
     if (static_cast<int>(mean_dims.size()) != kNumInputMeanDimensions) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: NumDimensions() != ", kNumInputMeanDimensions);
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Invalid input mean: NumDimensions() != ", kNumInputMeanDimensions);
     }
     if (mean_dims[0] != num_channels) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: 0th dimension != ", num_channels);
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Invalid input mean: 0th dimension != ", num_channels);
     }
     // in non-spatial cases - the other dims of 'mean' must be validated
     if (!is_spatial) {
       for (int feature = 0; feature < num_feature_dims; ++feature) {
-        if (mean_dims[1 + feature] != x_dims[2 + feature]) {
-          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: ", (1 + feature), " dimension != ", x_dims[2 + feature]);
+        if (mean_dims[1 + feature] != x_dims[feature_offset + feature]) {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: ", (1 + feature),
+                                 " dimension != ", x_dims[feature_offset + feature]);
         }
       }
     }
@@ -99,7 +112,8 @@ class BatchNormHelper {
     // validate 'var' shape
     const auto& var_dims = var->Shape().GetDims();
     if (static_cast<int>(var_dims.size()) != kNumInputVarianceDimensions) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: NumDimensions() != ", kNumInputVarianceDimensions);
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Invalid input var: NumDimensions() != ", kNumInputVarianceDimensions);
     }
     if (var_dims[0] != num_channels) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: 0th dimension != ", num_channels);
@@ -107,8 +121,9 @@ class BatchNormHelper {
     // in non-spatial cases - the other dims of 'var' must be validated
     if (!is_spatial) {
       for (int feature = 0; feature < num_feature_dims; ++feature) {
-        if (var_dims[1 + feature] != x_dims[2 + feature]) {
-          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: ", (1 + feature), " dimension != ", x_dims[2 + feature]);
+        if (var_dims[1 + feature] != x_dims[feature_offset + feature]) {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: ", (1 + feature),
+                                 " dimension != ", x_dims[feature_offset + feature]);
         }
       }
     }
diff --git a/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc b/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc
index 8064bc0a58cb1..2913f4ac32b6e 100644
--- a/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc
+++ b/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc
@@ -453,7 +453,7 @@ Status SplitToSequence::ComputeImpl(OpKernelContext& context, const Tensor& inpu
   int num_remaining_splits = 0;
   InlinedVector<int64_t> split_sizes;
   const bool is_string_type = input.IsDataTypeString();
-  const size_t element_size = (is_string_type) ? 0U : input.DataType()->Size();
+  const size_t element_size = input.DataType()->Size();
 
   // figure out split_scalar or split_sizes
   if (p_split_input) {
diff --git a/onnxruntime/core/providers/cpu/signal/dft.cc b/onnxruntime/core/providers/cpu/signal/dft.cc
index 15bf633579e5f..50fe7d1344eaf 100644
--- a/onnxruntime/core/providers/cpu/signal/dft.cc
+++ b/onnxruntime/core/providers/cpu/signal/dft.cc
@@ -506,7 +506,7 @@ static Status short_time_fourier_transform(OpKernelContext* ctx, bool is_oneside
 
   // Calculate the window size with preference to the window input.
   const auto window_size = window ? window->Shape()[0] : frame_length;
-  ORT_ENFORCE(window_size < signal_size, "Ensure that the dft size is smaller than the signal.");
+  ORT_ENFORCE(window_size <= signal_size, "Ensure that the dft size is not larger than the signal.");
 
   // Calculate the number of dfts to run
   const auto n_dfts =
diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.cc b/onnxruntime/core/providers/cpu/tensor/gelu.cc
new file mode 100644
index 0000000000000..d55973eda180f
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/tensor/gelu.cc
@@ -0,0 +1,108 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/common/common.h"
+#include "core/common/narrow.h"
+#include "core/framework/op_kernel.h"
+#include "core/util/math_cpuonly.h"
+#include "core/mlas/inc/mlas.h"
+
+#include "core/platform/threadpool.h"
+#include <unsupported/Eigen/SpecialFunctions>
+#include "core/providers/cpu/element_wise_ranged_transform.h"
+#include "core/providers/cpu/tensor/gelu.h"
+
+using onnxruntime::narrow;
+using namespace onnxruntime::common;
+
+namespace onnxruntime {
+
+// May revisit the implementations to support in-place computation, if needed.
+
+ONNX_CPU_OPERATOR_KERNEL(
+    Gelu,
+    20,
+    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+    Gelu<float>);
+
+#ifndef DISABLE_CONTRIB_OPS
+namespace contrib {
+ONNX_OPERATOR_KERNEL_EX(
+    Gelu,
+    kMSDomain,
+    1,
+    kCpuExecutionProvider,
+    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+    Gelu<float>);
+}  // namespace contrib
+#endif
+
+template <typename T>
+Status Gelu<T>::Compute(OpKernelContext* context) const {
+  const Tensor* input = context->Input<Tensor>(0);
+  const T* input_data = input->Data<T>();
+
+  Tensor* output = context->Output(0, input->Shape());
+  T* output_data = output->MutableData<T>();
+
+  concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
+  int64_t elem_count = input->Shape().Size();
+  constexpr int64_t length_per_task = 4096;  // this number comes from FastGelu.
+  int64_t task_count = (elem_count + length_per_task - 1) / length_per_task;
+
+  if (approximation_algorithm_ == "tanh") {
+    // FastGelu allows an optional bias. Here we split the input data into chunks. Each chunk
+    // has N elements (except the last chunk), and the chunks are processed in parallel on the thread pool.
+    // N = 4096 is selected based on performance test results for input shape 1x128x768.
+    // FastGelu uses an approximation of Gelu: 0.5 * (1 + Tanh(x * (C * x * x + B))) * x.
+    static constexpr float B = 0.7978845608028654f;    // sqrt(2.0 / M_PI)
+    static constexpr float C = 0.035677408136300125f;  // 0.044715 * sqrt(2.0 / M_PI)
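+    // For reference, the exact Gelu is 0.5 * x * (1 + erf(x / sqrt(2))), and the tanh
+    // approximation expands to 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
+    // which is what the constants B and C above encode.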
+
+    concurrency::ThreadPool::TryBatchParallelFor(
+        tp, static_cast<int32_t>(task_count),
+        [&](ptrdiff_t task_idx) {
+          const auto start = task_idx * length_per_task;
+          const T* p_input = input_data + start;
+          T* p_output = output_data + start;
+          int64_t count = std::min(length_per_task, elem_count - start);
+
+          for (int64_t i = 0; i < count; i++) {
+            T value = p_input[i];
+            p_output[i] = value * (static_cast<T>(C) * value * value + static_cast<T>(B));
+          }
+
+          MlasComputeTanh(p_output, p_output, narrow<size_t>(count));
+
+          for (int64_t i = 0; i < count; i++) {
+            p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f);
+          }
+        },
+        0);
+    return Status::OK();
+  } else if (approximation_algorithm_ == "none") {
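+    // Exact Gelu (no approximation): gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))).
+    // Below, this is computed as x * M_SQRT1_2, then erf in place, then 0.5 * x * (erf + 1).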
+    concurrency::ThreadPool::TryBatchParallelFor(
+        tp, static_cast<int32_t>(task_count),
+        [&](ptrdiff_t task_idx) {
+          const auto start = task_idx * length_per_task;
+          const T* p_input = input_data + start;
+          T* p_output = output_data + start;
+          int64_t count = std::min(length_per_task, elem_count - start);
+
+          for (int64_t i = 0; i < count; i++) {
+            T value = p_input[i];
+            p_output[i] = value * static_cast<T>(M_SQRT1_2);
+          }
+
+          MlasComputeErf(p_output, p_output, narrow<size_t>(count));
+
+          for (int64_t i = 0; i < count; i++) {
+            p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f);
+          }
+        },
+        0);
+    return Status::OK();
+  }
+  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported approximation_algorithm: ", approximation_algorithm_);
+}
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.h b/onnxruntime/core/providers/cpu/tensor/gelu.h
new file mode 100644
index 0000000000000..13238028d878a
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/tensor/gelu.h
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <string>
+
+#include "core/framework/op_kernel.h"
+
+namespace onnxruntime {
+
+template <typename T>
+class Gelu final : public OpKernel {
+ public:
+  explicit Gelu(const OpKernelInfo& info) : OpKernel(info) {
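+    // Per the kernel implementation, supported values are "none" (default) and "tanh";
+    // any other value is rejected in Compute().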
+    approximation_algorithm_ = info.GetAttrOrDefault<std::string>("approximate", "none");
+  }
+  Status Compute(OpKernelContext* ctx) const override;
+
+ private:
+  std::string approximation_algorithm_;
+};
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/tensor/isinf.cc b/onnxruntime/core/providers/cpu/tensor/isinf.cc
index 1b449f46927a2..9d18d1fa62288 100644
--- a/onnxruntime/core/providers/cpu/tensor/isinf.cc
+++ b/onnxruntime/core/providers/cpu/tensor/isinf.cc
@@ -23,7 +23,9 @@ ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(
 using IsInfTypesOpset20 =
     TypeList<
         float,
-        double
+        double,
+        MLFloat16,
+        BFloat16
 #if !defined(DISABLE_FLOAT8_TYPES)
         ,
         Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ
@@ -76,10 +78,8 @@ ONNX_CPU_OPERATOR_KERNEL(
     IsInf);
 
 IsInf::IsInf(const OpKernelInfo& info) : OpKernel(info) {
-  Status status = info.GetAttr("detect_positive", &detect_positive_);
-  ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_positive");
-  status = info.GetAttr("detect_negative", &detect_negative_);
-  ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_negative");
+  detect_positive_ = info.GetAttrOrDefault<int64_t>("detect_positive", 1);
+  detect_negative_ = info.GetAttrOrDefault<int64_t>("detect_negative", 1);
   opset_ = info.node().SinceVersion();
 }
 
@@ -87,29 +87,67 @@ namespace isinf_internal {
 template <class T>
 struct ComputeDispatchTarget {
   void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const {
-    const auto total_items = X.Shape().Size();
+    auto input_data = X.DataAsSpan<T>();
     auto output_data = Y.MutableData<bool>();
 
     if (detect_positive && detect_negative) {
       EigenMap<bool>(Y) = EigenMap<T>(X).array().isInf();
     } else if (detect_positive) {
-      auto input_data = X.Data<T>();
-      auto end_data = input_data + total_items;
       std::transform(
-          input_data, end_data, output_data, [](T v) {
+          input_data.begin(), input_data.end(), output_data, [](T v) {
             return (v == std::numeric_limits<T>::infinity());
           });
 
     } else if (detect_negative) {
-      auto input_data = X.Data<T>();
-      auto end_data = input_data + total_items;
       std::transform(
-          input_data, end_data, output_data, [](T v) {
+          input_data.begin(), input_data.end(), output_data, [](T v) {
             return (v == -std::numeric_limits<T>::infinity());
           });
     } else {
       // all false
-      memset(output_data, false, onnxruntime::narrow<size_t>(total_items));
+      memset(output_data, false, input_data.size());
+    }
+  }
+};
+
+template <>
+struct ComputeDispatchTarget<MLFloat16> {
+  void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const {
+    auto output_data = Y.MutableData<bool>();
+    auto input_data = X.DataAsSpan<MLFloat16>();
+    if (detect_positive && detect_negative) {
+      std::transform(input_data.begin(), input_data.end(), output_data,
+                     [](MLFloat16 v) { return v.IsInfinity(); });
+    } else if (detect_positive) {
+      std::transform(input_data.begin(), input_data.end(), output_data,
+                     [](MLFloat16 v) { return v.IsPositiveInfinity(); });
+    } else if (detect_negative) {
+      std::transform(input_data.begin(), input_data.end(), output_data,
+                     [](MLFloat16 v) { return v.IsNegativeInfinity(); });
+    } else {
+      // all false
+      memset(output_data, false, input_data.size());
+    }
+  }
+};
+
+template <>
+struct ComputeDispatchTarget<BFloat16> {
+  void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const {
+    auto output_data = Y.MutableData<bool>();
+    auto input_data = X.DataAsSpan<BFloat16>();
+    if (detect_positive && detect_negative) {
+      std::transform(input_data.begin(), input_data.end(), output_data,
+                     [](BFloat16 v) { return v.IsInfinity(); });
+    } else if (detect_positive) {
+      std::transform(input_data.begin(), input_data.end(), output_data,
+                     [](BFloat16 v) { return v.IsPositiveInfinity(); });
+    } else if (detect_negative) {
+      std::transform(input_data.begin(), input_data.end(), output_data,
+                     [](BFloat16 v) { return v.IsNegativeInfinity(); });
+    } else {
+      // all false
+      memset(output_data, false, input_data.size());
     }
   }
 };
diff --git a/onnxruntime/core/providers/cpu/tensor/isnan.cc b/onnxruntime/core/providers/cpu/tensor/isnan.cc
index 34495e382278a..0e15c64b126f3 100644
--- a/onnxruntime/core/providers/cpu/tensor/isnan.cc
+++ b/onnxruntime/core/providers/cpu/tensor/isnan.cc
@@ -46,9 +46,11 @@ ADD_TYPED_ISNAN_OP_9(MLFloat16);
 ADD_TYPED_ISNAN_OP_13(float);
 ADD_TYPED_ISNAN_OP_13(double);
 ADD_TYPED_ISNAN_OP_13(MLFloat16);
+ADD_TYPED_ISNAN_OP_13(BFloat16);
 ADD_TYPED_ISNAN_OP(float);
 ADD_TYPED_ISNAN_OP(double);
 ADD_TYPED_ISNAN_OP(MLFloat16);
+ADD_TYPED_ISNAN_OP(BFloat16);
 
 #if !defined(DISABLE_FLOAT8_TYPES)
 ADD_TYPED_ISNAN_OP(Float8E4M3FN);
@@ -75,9 +77,7 @@ Status IsNaN<T>::Compute(OpKernelContext* context) const {
 template <>
 Status IsNaN<MLFloat16>::Compute(OpKernelContext* context) const {
   const auto* X_ptr = context->Input<Tensor>(0);
-  if (!X_ptr) {
-    return Status(common::ONNXRUNTIME, common::FAIL, "Null input ptr");
-  }
+
   auto X_data = X_ptr->Data<MLFloat16>();
   auto& dims = X_ptr->Shape();
   auto shape_size = dims.Size();
@@ -91,6 +91,19 @@ Status IsNaN<MLFloat16>::Compute(OpKernelContext* context) const {
   return Status::OK();
 }
 
+template <>
+Status IsNaN<BFloat16>::Compute(OpKernelContext* context) const {
+  const auto* X_ptr = context->Input<Tensor>(0);
+
+  auto X_data = X_ptr->DataAsSpan<BFloat16>();
+  auto& Y = *context->Output(0, X_ptr->Shape());
+
+  std::transform(X_data.begin(), X_data.end(), Y.MutableData<bool>(),
+                 [](BFloat16 x) { return x.IsNaN(); });
+
+  return Status::OK();
+}
+
 #if !defined(DISABLE_FLOAT8_TYPES)
 template <>
 Status IsNaN<Float8E4M3FN>::Compute(OpKernelContext* context) const {
diff --git a/onnxruntime/core/providers/cpu/tensor/reshape_helper.h b/onnxruntime/core/providers/cpu/tensor/reshape_helper.h
index 5961686674424..d7ceda16e61ea 100644
--- a/onnxruntime/core/providers/cpu/tensor/reshape_helper.h
+++ b/onnxruntime/core/providers/cpu/tensor/reshape_helper.h
@@ -37,12 +37,14 @@ class ReshapeHelper {
     if (unknown_dim != -1) {
       // calculate unknown dimension
       ORT_ENFORCE(size != 0 && (input_shape_size % size) == 0,
-                  "The input tensor cannot be reshaped to the requested shape. Input shape:", input_shape, ", requested shape:", TensorShape(requested_shape));
+                  "The input tensor cannot be reshaped to the requested shape. Input shape:", input_shape,
+                  ", requested shape:", TensorShape(requested_shape));
       requested_shape[unknown_dim] = input_shape_size / size;
     } else {
       // check if the output shape is valid.
       ORT_ENFORCE(input_shape_size == size,
-                  "The input tensor cannot be reshaped to the requested shape. Input shape:", input_shape, ", requested shape:", TensorShape(requested_shape));
+                  "The input tensor cannot be reshaped to the requested shape. Input shape:", input_shape,
+                  ", requested shape:", TensorShape(requested_shape));
     }
   }
 };
diff --git a/onnxruntime/core/providers/cpu/tensor/space_depth_ops.h b/onnxruntime/core/providers/cpu/tensor/space_depth_ops.h
index 7d117317ba172..3218c8952d6ec 100644
--- a/onnxruntime/core/providers/cpu/tensor/space_depth_ops.h
+++ b/onnxruntime/core/providers/cpu/tensor/space_depth_ops.h
@@ -14,6 +14,7 @@ class SpaceDepthBase {
                 "Attribute blocksize is not set.");
   }
 
+  template <bool IsNHWC = false>
   Status InputValidationsAndOutputDimsCalc(const Tensor& input,
                                            int64_t& batch,
                                            int64_t& input_depth, int64_t& input_height, int64_t& input_width,
@@ -27,9 +28,15 @@ class SpaceDepthBase {
     }
 
     batch = input_shape[0];
-    input_depth = input_shape[1];
-    input_height = input_shape[2];
-    input_width = input_shape[3];
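+    // NHWC stores the input as [N, H, W, C]; the default NCHW layout is [N, C, H, W].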
+    if constexpr (IsNHWC) {
+      input_depth = input_shape[3];
+      input_height = input_shape[1];
+      input_width = input_shape[2];
+    } else {
+      input_depth = input_shape[1];
+      input_height = input_shape[2];
+      input_width = input_shape[3];
+    }
 
     if (is_space_to_depth) {  // SpaceToDepth op
       if ((input_height % this->blocksize_) != 0) {
@@ -46,7 +53,8 @@ class SpaceDepthBase {
 
     } else {  // DepthToSpace op
       if ((input_depth % (blocksize_ * blocksize_) != 0)) {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "DepthToSpace requires input depth to be a multiple of (block_size * blok_size)");
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "DepthToSpace requires input depth to be a multiple of (block_size * block_size)");
       }
 
       output_depth = input_depth / blocksize_ / blocksize_;
diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.cc b/onnxruntime/core/providers/cpu/tensor/upsample.cc
index fa69e144be554..babbac0b7be17 100644
--- a/onnxruntime/core/providers/cpu/tensor/upsample.cc
+++ b/onnxruntime/core/providers/cpu/tensor/upsample.cc
@@ -1,10 +1,15 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#include "core/providers/cpu/tensor/upsample.h"
+
+#include <limits>
+
+#include "core/common/inlined_containers.h"
 #include "core/common/safeint.h"
 #include "core/platform/threadpool.h"
-#include "core/providers/cpu/tensor/upsample.h"
 #include "core/providers/cpu/tensor/upsample_antialias.h"
+
 using namespace onnxruntime::common;
 using namespace std;
 using onnxruntime::narrow;
@@ -30,6 +35,46 @@ REGISTER_VERSIONED_TYPED_KERNEL(int32_t, 9, 9);
 REGISTER_VERSIONED_TYPED_KERNEL(int8_t, 9, 9);
 REGISTER_VERSIONED_TYPED_KERNEL(uint8_t, 9, 9);
 
+void UpsampleBase::AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span<const int64_t> input_dims,
+                                            InlinedVector<float>& scales) const {
+  // AspectRatioPolicy::STRETCH is the default policy when opset < 18
+  if (keep_aspect_ratio_policy_ == AspectRatioPolicy::STRETCH) {
+    return;
+  }
+
+  InlinedHashSet<int64_t> axes_set(axes_.begin(), axes_.end());
+
+  float scale_in_policy = 0.0f;
+  if (keep_aspect_ratio_policy_ == AspectRatioPolicy::NOT_LARGER) {
+    scale_in_policy = std::numeric_limits<float>::max();
+
+    for (size_t i = 0; i < scales.size(); i++) {
+      if (axes_set.empty() || axes_set.count(i) > 0) {
+        scale_in_policy = std::min(scale_in_policy, scales[i]);
+      }
+    }
+  } else if (keep_aspect_ratio_policy_ == AspectRatioPolicy::NOT_SMALLER) {
+    scale_in_policy = std::numeric_limits<float>::min();
+
+    for (size_t i = 0; i < scales.size(); i++) {
+      if (axes_set.empty() || axes_set.count(i) > 0) {
+        scale_in_policy = std::max(scale_in_policy, scales[i]);
+      }
+    }
+  }
+
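+  // Illustrative example: with scales {2.0f, 1.5f}, NOT_LARGER selects 1.5f (the smallest scale)
+  // and NOT_SMALLER selects 2.0f (the largest); the selected scale is then applied to every
+  // axis covered by 'axes'.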
+  for (size_t i = 0; i < scales.size(); i++) {
+    // if axes is not specified (AKA axes_set.empty()), we apply the policy to all axes
+    if (axes_set.empty() || axes_set.count(i) > 0) {
+      scales[i] = scale_in_policy;
+      output_dims[i] = static_cast<int64_t>(std::round(scales[i] * input_dims[i]));
+    } else {
+      scales[i] = 1.0f;
+      output_dims[i] = input_dims[i];
+    }
+  }
+}
+
 template <typename T>
 void UpsampleNearest2x(int64_t batch_size,
                        int64_t num_channels,
@@ -94,8 +139,8 @@ UpsampleNearestSetupInputMappings(int64_t n_dim,
                                   const TensorShape& input_shape,
                                   const TensorShape& output_shape,
                                   const std::vector<int64_t>& input_dim_factor,
-                                  const vector<float>& scales,
-                                  const vector<float>& roi,
+                                  gsl::span<const float> scales,
+                                  gsl::span<const float> roi,
                                   bool extrapolation_enabled,
                                   const GetOriginalCoordinateFunc& get_original_coordinate,
                                   const GetNearestPixelFunc& get_nearest_pixel) {
@@ -141,8 +186,8 @@ static Status UpsampleNearestImpl(const T* input,
                                   T* output,
                                   const TensorShape& input_shape,
                                   const TensorShape& output_shape,
-                                  const vector<float>& scales,
-                                  const vector<float>& roi,
+                                  gsl::span<const float> scales,
+                                  gsl::span<const float> roi,
                                   bool extrapolation_enabled,
                                   const T extrapolation_value,
                                   const GetOriginalCoordinateFunc& get_original_coordinate,
@@ -285,8 +330,8 @@ static Status UpsampleNearest(const T* input,
                               T* output,
                               const TensorShape& input_shape,
                               const TensorShape& output_shape,
-                              const vector<float>& scales,
-                              const vector<float>& roi,
+                              gsl::span<const float> scales,
+                              gsl::span<const float> roi,
                               bool is_resize,
                               bool extrapolation_enabled,
                               T extrapolation_value,
@@ -412,7 +457,7 @@ BilinearParams SetupUpsampleBilinear(const int32_t input_height,
                                      const int32_t output_width,
                                      const float height_scale,
                                      const float width_scale,
-                                     const std::vector<float>& roi,
+                                     gsl::span<const float> roi,
                                      AllocatorPtr& alloc,
                                      const GetOriginalCoordinateFunc& get_original_coordinate,
                                      const bool is_nchw) {
@@ -518,7 +563,7 @@ BilinearParamsInteger SetupUpsampleBilinearInteger(const int32_t input_height,
                                                    const int32_t output_width,
                                                    const float height_scale,
                                                    const float width_scale,
-                                                   const std::vector<float>& roi,
+                                                   gsl::span<const float> roi,
                                                    AllocatorPtr& alloc,
                                                    const GetOriginalCoordinateFunc& get_original_coordinate,
                                                    const bool is_nchw) {
@@ -650,7 +695,7 @@ static TrilinearParams SetupUpsampleTrilinear(int64_t input_depth,
                                               float depth_scale,
                                               float height_scale,
                                               float width_scale,
-                                              const std::vector<float>& roi,
+                                              gsl::span<const float> roi,
                                               AllocatorPtr& alloc,
                                               const GetOriginalCoordinateFunc& get_original_coordinate) {
   TrilinearParams p;
@@ -796,7 +841,7 @@ void UpsampleTrilinear(int64_t batch_size,
                        float depth_scale,
                        float height_scale,
                        float width_scale,
-                       const std::vector<float>& roi,
+                       gsl::span<const float> roi,
                        bool use_extrapolation,
                        float extrapolation_value,
                        const T* XdataBase,
@@ -929,7 +974,7 @@ void ResizeBiCubic(int64_t batch_size,
                    bool use_extrapolation,
                    float extrapolation_value,
                    bool exclude_outside,
-                   const std::vector<float>& roi,
+                   gsl::span<const float> roi,
                    const T* Xdata,
                    T* Ydata,
                    const GetOriginalCoordinateFunc& get_original_coordinate) {
@@ -1067,9 +1112,9 @@ void ResizeBiCubic(int64_t batch_size,
 
 template <typename T>
 Status Upsample<T>::BaseCompute(OpKernelContext* context,
-                                const std::vector<float>& roi,
-                                const std::vector<float>& scales,
-                                const gsl::span<const int64_t>& output_dims) const {
+                                gsl::span<const float> roi,
+                                gsl::span<const float> scales,
+                                gsl::span<const int64_t> output_dims) const {
   const auto* X = context->Input<Tensor>(0);
   auto dims = X->Shape().GetDims();
   ORT_RETURN_IF_NOT(output_dims.size() == dims.size(), "Rank of input and output tensor should be same.");
@@ -1327,7 +1372,7 @@ Status Upsample<T>::Compute(OpKernelContext* context) const {
   // Initialize the roi array to all zeros as this will be the most common case
   // Roi data is needed only when coordinate transformation mode is set to tf_crop_and_resize
   // for all other cases we need a 0 initialized roi array
-  std::vector<float> roi_array(roi_);
+  InlinedVector<float> roi_array(roi_);
 
   if (!roi_cached_) {
     bool use_default_roi = true;
@@ -1353,7 +1398,7 @@ Status Upsample<T>::Compute(OpKernelContext* context) const {
 
   ComputeROIWithAxes(roi_array, input_dims.size());
   // Get scales data
-  std::vector<float> scales_array(input_dims.size());
+  InlinedVector<float> scales_array(input_dims.size());
 
   if (OpKernel::Node().InputDefs().size() == 1) {
     // Compute output shape from scales and input dims
diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.h b/onnxruntime/core/providers/cpu/tensor/upsample.h
index 3046ee4b8260d..8ff04781f6ad0 100644
--- a/onnxruntime/core/providers/cpu/tensor/upsample.h
+++ b/onnxruntime/core/providers/cpu/tensor/upsample.h
@@ -66,8 +66,8 @@ class Upsample : public UpsampleBase, public OpKernel {
 
   Status Compute(OpKernelContext* context) const override;
 
-  Status BaseCompute(OpKernelContext* context, const std::vector<float>& roi, const std::vector<float>& scales,
-                     const gsl::span<const int64_t>& output_dims) const;
+  Status BaseCompute(OpKernelContext* context, gsl::span<const float> roi, gsl::span<const float> scales,
+                     gsl::span<const int64_t> output_dims) const;
 };
 
 BilinearParams SetupUpsampleBilinear(const int32_t input_height,
@@ -76,7 +76,7 @@ BilinearParams SetupUpsampleBilinear(const int32_t input_height,
                                      const int32_t output_width,
                                      const float height_scale,
                                      const float width_scale,
-                                     const std::vector<float>& roi,
+                                     gsl::span<const float> roi,
                                      AllocatorPtr& alloc,
                                      const GetOriginalCoordinateFunc& get_original_coordinate,
                                      const bool is_nchw);
@@ -90,7 +90,7 @@ void UpsampleBilinear(const int32_t batch_size,
                       const int32_t output_width,
                       const float height_scale,
                       const float width_scale,
-                      const std::vector<float>& roi,
+                      gsl::span<const float> roi,
                       const bool use_extrapolation,
                       const float extrapolation_value,
                       const T* const XdataBase,
@@ -144,7 +144,7 @@ void NhwcUpsampleBilinear(const int32_t batch_size,
                           const int32_t output_width,
                           const float height_scale,
                           const float width_scale,
-                          const std::vector<float>& roi,
+                          gsl::span<const float> roi,
                           const float extrapolation_value,
                           const T* const XdataBase,
                           T* const YdataBase,
@@ -227,7 +227,7 @@ BilinearParamsInteger SetupUpsampleBilinearInteger(const int32_t input_height,
                                                    const int32_t output_width,
                                                    const float height_scale,
                                                    const float width_scale,
-                                                   const std::vector<float>& roi,
+                                                   gsl::span<const float> roi,
                                                    AllocatorPtr& alloc,
                                                    const GetOriginalCoordinateFunc& get_original_coordinate,
                                                    const bool is_nchw);
@@ -241,7 +241,7 @@ void NhwcUpsampleBilinearInteger(const int32_t batch_size,
                                  const int32_t output_width,
                                  const float height_scale,
                                  const float width_scale,
-                                 const std::vector<float>& roi,
+                                 gsl::span<const float> roi,
                                  const float extrapolation_value,
                                  const T* const XdataBase,
                                  T* const YdataBase,
diff --git a/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h b/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h
index e1dcaf500a325..1e32b7e874b1a 100644
--- a/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h
+++ b/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h
@@ -21,32 +21,6 @@
 
 namespace onnxruntime {
 
-namespace ConstValue {
-constexpr int32_t mag_factor = 1 << (22 - 1);
-}
-
-namespace {
-const uint8_t* GetLookupTableShared() {
-  // initialized once
-  static const auto* lookup_table = []() {
-    // if we have already initialized the lookup table, just return
-    // ideally we could have a global lookup table, but that account for too much space.
-    /* Handles values form -640 to 639. */
-    static uint8_t table[1280] = {0};
-
-    // taken from https://github.com/python-pillow/Pillow/blob/66add095a50d76c35c7f58643461f2edf78a3f05/src/libImaging/Resample.c#L94
-    //  we need to handle negative values
-    //  it's equivalent to :x = np.clip(x, 0, 255) where x \in [-640, 639]
-    // we will accept a negative x for (&table[640])[x] means table +640 -x
-    for (int i = 0; i < 1280; ++i) {
-      table[i] = static_cast<uint8_t>(std::min(std::max(i - 640, 0), 255));
-    }
-    return table;
-  }();
-  return lookup_table;
-}
-}  // namespace
-
 template <typename T>
 struct FilterParamsBaseAntiAlias {
   std::vector<int64_t> bound;
@@ -57,15 +31,15 @@ struct FilterParamsBaseAntiAlias {
 
 template <typename T>
 struct FilterParamsAntiAlias {
-  float support_size = 2.0f;
-  float cubic_coeff_a = -0.75f;
+  float support_size = antialias_constants::kSupportSize;
+  float cubic_coeff_a = antialias_constants::kCubicCoeffA;
 
   FilterParamsBaseAntiAlias<T> dim_x;
   FilterParamsBaseAntiAlias<T> dim_y;
   FilterParamsBaseAntiAlias<T> dim_z;
 
   const uint8_t* GetClip8LookupTable() const {
-    return GetLookupTableShared();
+    return UpsampleBase::GetLookupTableShared();
   }
   virtual ~FilterParamsAntiAlias() = default;
   virtual float Filter(float x) const = 0;
@@ -89,7 +63,7 @@ struct BilinearParamsAntiAlias : FilterParamsAntiAlias<T> {
 template <typename T>
 struct BiCubicParamsAntiAlias : FilterParamsAntiAlias<T> {
   BiCubicParamsAntiAlias() {
-    this->support_size = 4.0f;
+    this->support_size = antialias_constants::kBiCubicSupportSize;
   }
 
   // taken from
@@ -124,27 +98,6 @@ struct TriLinearParamsAntiAlias : FilterParamsAntiAlias<T> {
   }
 };
 
-template <typename T>
-struct AccumulateType {
-  using type = int32_t;
-  using Dtype = T;
-};
-
-template <>
-struct AccumulateType<int32_t> {
-  using type = float;
-};
-
-template <>
-struct AccumulateType<float> {
-  using type = float;
-};
-
-template <>
-struct AccumulateType<double> {
-  using type = double;
-};
-
 // The following method supports a 3/4/5-D input in 'Linear mode, cubic mode'
 // that amounts to 'Bilinear,TriLinear, Bicubic/Tricubic' Upsampling/Resizing in the sense that it assumes
 // A N-D tensor has
@@ -156,19 +109,20 @@ struct AccumulateType<double> {
 // - [N, H, W, C] and the scales are [1.0, height_scale, width_scale, 1.0]
 template <class T>
 void SetupUpsampleFilterAntiAlias(FilterParamsAntiAlias<T>& p,
-                                  const gsl::span<int64_t> input_h_w_c,
-                                  const gsl::span<int64_t> output_h_w_c,
-                                  const gsl::span<float> scale_h_w_c,
-                                  const std::vector<float>& roi,
+                                  gsl::span<const int64_t> input_h_w_c,
+                                  gsl::span<const int64_t> output_h_w_c,
+                                  gsl::span<const float> scale_h_w_c,
+                                  gsl::span<const float> roi,
                                   AllocatorPtr& alloc,
                                   const GetOriginalCoordinateFunc& get_original_coordinate,
                                   bool exclude_outside, const bool is_nchw) {
-  auto compute_weight_coefficients = [&alloc, &roi, &get_original_coordinate, exclude_outside](const FilterParamsAntiAlias<T>& p,
-                                                                                               const int64_t input_size,
-                                                                                               const int64_t output_size,
-                                                                                               size_t rindex,
-                                                                                               FilterParamsBaseAntiAlias<T>& param_base,
-                                                                                               const float rscale) -> int64_t {
+  auto compute_weight_coefficients = [&alloc, roi, &get_original_coordinate, exclude_outside](
+                                         const FilterParamsAntiAlias<T>& p,
+                                         const int64_t input_size,
+                                         const int64_t output_size,
+                                         size_t rindex,
+                                         FilterParamsBaseAntiAlias<T>& param_base,
+                                         const float rscale) -> int64_t {
     param_base.bound.reserve(static_cast<size_t>(output_size) * 2);
     param_base.out_of_bound_idx.reserve(static_cast<size_t>(output_size));
 
@@ -245,13 +199,14 @@ void SetupUpsampleFilterAntiAlias(FilterParamsAntiAlias<T>& p,
 
         // normalize the scale to 1 << 22 for int8/uint8
         if constexpr (std::is_same<T, int32_t>::value) {
-          scale_buffer_int[x] = static_cast<int32_t>(std::round(scale_buffer[x] * ConstValue::mag_factor * 2.f));
+          scale_buffer_int[x] = static_cast<int32_t>(std::round(scale_buffer[x] * ConstValue::mag_factor_x_2));
         }
       }
       /*for (; x < window_size; x++) {
         scale_buffer[x] = 0;
       }*/
     }
+
     return window_size;
   };
 
@@ -269,9 +224,6 @@ void SetupUpsampleFilterAntiAlias(FilterParamsAntiAlias<T>& p,
   }
 }
 
-template <class T>
-inline constexpr bool is_8bit_v = std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
-
 /**
  * @brief To compute interpolation along with the last axis.
  * For brief,we assume the input tensor has 3 dimensions and we all it CHW for each character represent a dim.
@@ -398,6 +350,7 @@ void ComputeInterpolationAtLevel2(int64_t num_channels, int64_t input_height, in
                 output += *Xdata_offset * (*weight_coeff_start++);
                 Xdata_offset += output_width;
               }
+
               if constexpr (is_8bit_v<InputType>) {
                 *Ydata_offset++ = static_cast<InputType>(clip8_lookups[output >> 22]);
               } else if constexpr (std::is_same<InputType, int32_t>::value) {
@@ -444,6 +397,7 @@ void ComputeInterpolationAtLevel2(int64_t num_channels, int64_t input_height, in
                 output += *Xdata_offset * (*weight_coeff_start++);
                 Xdata_offset += output_width;
               }
+
               if constexpr (is_8bit_v<InputType>) {
                 *Ydata_offset++ = static_cast<InputType>(clip8_lookups[output >> 22]);
               } else if constexpr (std::is_same<InputType, int32_t>::value) {
@@ -515,6 +469,7 @@ void UpsampleBaseAntiAlias(FilterParamsAntiAlias<T1>& p,
                                        narrow<size_t>(input_height * num_channels * input_width));
       auto ydata_span = gsl::make_span(image_temp_buffer.get(), narrow<size_t>(input_height * num_channels * output_width));
 
+      // This computes only the width direction; the height stays unchanged.
       ComputeInterpolationAtLevel1(num_channels, input_height, input_width, input_height, output_width,
                                    xdata_span, ydata_span, p, p.dim_x, tp);
     }
@@ -546,7 +501,7 @@ void UpsampleBilinearAntiAlias(const int64_t batch_size,
                                const int64_t output_width,
                                const float height_scale,
                                const float width_scale,
-                               const std::vector<float>& roi,
+                               gsl::span<const float> roi,
                                const bool use_extrapolation,
                                const float extrapolation_value,
                                bool exclude_outside,
@@ -575,7 +530,7 @@ void NhwcUpsampleBilinearAntiAlias(const int64_t batch_size,
                                    const int64_t output_width,
                                    const float height_scale,
                                    const float width_scale,
-                                   const std::vector<float>& roi,
+                                   gsl::span<const float> roi,
                                    const bool use_extrapolation,
                                    const float extrapolation_value,
                                    bool exclude_outside,
@@ -608,7 +563,7 @@ void NhwcResizeBiCubicAntiAlias(const int64_t batch_size,
                                 bool use_extrapolation,
                                 float extrapolation_value,
                                 bool exclude_outside,
-                                const std::vector<float>& roi,
+                                gsl::span<const float> roi,
                                 const Tensor* X,
                                 T* Ydata_base,
                                 AllocatorPtr& alloc,
@@ -688,7 +643,7 @@ void ResizeBiCubicAntiAlias(int64_t batch_size,
                             bool use_extrapolation,
                             float extrapolation_value,
                             bool exclude_outside,
-                            const std::vector<float>& roi,
+                            gsl::span<const float> roi,
                             const Tensor* X,
                             T* Ydata_base,
                             AllocatorPtr& alloc,
@@ -719,7 +674,7 @@ void UpsampleTrilinearAntiAlias(int64_t batch_size,
                                 float depth_scale,
                                 float height_scale,
                                 float width_scale,
-                                const std::vector<float>& roi,
+                                gsl::span<const float> roi,
                                 bool use_extrapolation,
                                 float extrapolation_value,
                                 bool exclude_outside,
diff --git a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h
index a0e7ca1084fef..b768fedd8513a 100644
--- a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h
+++ b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h
@@ -3,11 +3,13 @@
 
 #pragma once
 
+#include <algorithm>
+#include <iterator>
+#include <ostream>
 #include <string>
 #include <string_view>
 #include <unordered_map>
 #include <vector>
-#include <unordered_set>
+
+#include <core/common/inlined_containers_fwd.h>
 #include "core/common/status.h"
 #include <core/common/safeint.h>
 #include <core/common/narrow.h>
@@ -58,7 +60,73 @@ enum class AspectRatioPolicy {
   NOT_SMALLER,
 };
 
+// Antialias types
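+// AccumulateType<T> selects the accumulator used when summing weighted samples:
+// 8-bit integer inputs accumulate in int32_t; float and MLFloat16 in float; double in double.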
+template <typename T>
+struct AccumulateType {
+  using type = int32_t;
+  using Dtype = T;
+};
+
+template <>
+struct AccumulateType<int32_t> {
+  using type = float;
+};
+
+template <>
+struct AccumulateType<float> {
+  using type = float;
+};
+
+template <>
+struct AccumulateType<MLFloat16> {
+  using type = float;
+};
+
+template <>
+struct AccumulateType<double> {
+  using type = double;
+};
+
+namespace antialias_constants {
+constexpr float kCubicCoeffA = -0.75f;
+constexpr float kSupportSize = 2.0f;
+constexpr float kBiCubicSupportSize = 4.0f;
+}  // namespace antialias_constants
+
+namespace ConstValue {
+constexpr int32_t mag_factor = 1 << (22 - 1);
+// The code used to multiply mag_factor by 2; provide a constant that is already twice as big.
+constexpr int32_t mag_factor_x_2 = 1 << 22;
+}  // namespace ConstValue
+
+template <class T>
+inline constexpr bool is_8bit_v = std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
+
+template <typename T>
+void PrintAntiAliasBuffers(std::ostream& os, gsl::span<int64_t> bounds, gsl::span<int64_t> out_of_bounds,
+                           gsl::span<T> weight_coefficients) {
+  os << "#### Bounds: ";
+  std::copy(bounds.begin(), bounds.end(), std::ostream_iterator<int64_t>(os, " "));
+  os << std::endl;
+
+  os << "#### Out of Bounds: ";
+  std::copy(out_of_bounds.begin(), out_of_bounds.end(),
+            std::ostream_iterator<int64_t>(os, " "));
+  os << std::endl;
+
+  os << "#### Scale Buffer: ";
+  std::copy(weight_coefficients.begin(), weight_coefficients.end(),
+            std::ostream_iterator<T>(os, " "));
+  os << std::endl;
+}
+
 class UpsampleBase {
+ public:
+  // Make this available to other EPs via the provider bridge.
+  // It only takes effect when output_shape is specified.
+  void AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span<const int64_t> input_dims,
+                                InlinedVector<float>& scales) const;
+
  protected:
   explicit UpsampleBase(const OpKernelInfo& info)
       : scales_cached_(false), roi_cached_(false), use_extrapolation_(false) {
@@ -69,23 +137,32 @@ class UpsampleBase {
     std::string mode;
     ORT_ENFORCE(info.GetAttr<std::string>("mode", &mode).IsOK());
     mode_ = StringToUpsampleMode(mode);
-    antialias_ = info.GetAttrOrDefault<int64_t>("antialias", 0) == 0 ? false : true;
-    if (antialias_) {
-      ORT_ENFORCE((UpsampleMode::LINEAR == mode_ || UpsampleMode::CUBIC == mode_),
-                  "when anti-aliasing is set, Resize only supports mode `LINEAR` and `CUBIC`.");
-    }
 
     auto input_count = info.GetInputCount();
     if (input_count == 1) {  // opset < 10
-      ORT_THROW_IF_ERROR(info.GetAttrs<float>("scales", scales_));
-      ORT_THROW_IF_ERROR(ScalesValidation(scales_, mode_));
+      std::vector<float> scales;
+      ORT_THROW_IF_ERROR(info.GetAttrs<float>("scales", scales));
+      ORT_THROW_IF_ERROR(ScalesValidation(scales, mode_));
+      scales_.assign(scales.cbegin(), scales.cend());
       scales_cached_ = true;
     }
 
-    std::string keep_aspect_ratio_policy = info.GetAttrOrDefault<std::string>("keep_aspect_ratio_policy", "stretch");
-    keep_aspect_ratio_policy_ = StringToKeepAspectRatioPolicy(keep_aspect_ratio_policy);
+    if (opset >= 18) {
+      antialias_ = info.GetAttrOrDefault<int64_t>("antialias", 0) == 0 ? false : true;
+
+      if (antialias_) {
+        ORT_ENFORCE((UpsampleMode::LINEAR == mode_ || UpsampleMode::CUBIC == mode_),
+                    "when anti-aliasing is set, Resize only supports mode `LINEAR` and `CUBIC`.");
+      }
 
-    axes_ = info.GetAttrsOrDefault<int64_t>("axes");
+      // The attribute is absent in opset < 18, where the behavior defaults to "stretch".
+      std::string keep_aspect_ratio_policy = info.GetAttrOrDefault<std::string>("keep_aspect_ratio_policy", "stretch");
+      keep_aspect_ratio_policy_ = StringToKeepAspectRatioPolicy(keep_aspect_ratio_policy);
+
+      // guard against unit tests that can add an attribute
+      auto axes = info.GetAttrsOrDefault<int64_t>("axes");
+      axes_.assign(axes.cbegin(), axes.cend());
+    }
 
     extrapolation_value_ = info.GetAttrOrDefault<float>("extrapolation_value", 0.0f);
 
@@ -112,7 +189,7 @@ class UpsampleBase {
     nearest_mode_ = StringToNearestMode(nearest_mode_name);
     get_nearest_pixel_ = GetNearestPixelFromOriginal(nearest_mode_);
 
-    cubic_coeff_a_ = info.GetAttrOrDefault<float>("cubic_coeff_a", -0.75f);
+    cubic_coeff_a_ = info.GetAttrOrDefault<float>("cubic_coeff_a", antialias_constants::kCubicCoeffA);
     exclude_outside_ = info.GetAttrOrDefault<int64_t>("exclude_outside", 0) == 0 ? false : true;
 
     if ((exclude_outside_ == 1 && mode_ != CUBIC) && (antialias_ == false || mode_ != LINEAR)) {
@@ -166,7 +243,7 @@ class UpsampleBase {
   ResizeCoordinateTransformationMode coordinate_transform_mode_;
   GetOriginalCoordinateFunc get_original_coordinate_;
   ResizeNearestMode nearest_mode_;
-  AspectRatioPolicy keep_aspect_ratio_policy_;
+  AspectRatioPolicy keep_aspect_ratio_policy_{AspectRatioPolicy::STRETCH};
   GetNearestPixelFunc get_nearest_pixel_;
   float cubic_coeff_a_;
   bool exclude_outside_;
@@ -174,9 +251,9 @@ class UpsampleBase {
   float extrapolation_value_;
   bool use_nearest2x_optimization_ = false;
 
-  std::vector<float> scales_;
-  std::vector<float> roi_;
-  std::vector<int64_t> axes_;
+  InlinedVector<float> scales_;
+  InlinedVector<float> roi_;
+  TensorShapeVector axes_;
 
   bool scales_cached_;
   bool roi_cached_;
@@ -335,7 +412,7 @@ class UpsampleBase {
     }
   }
 
-  [[nodiscard]] Status ScalesValidation(const std::vector<float>& scales, const UpsampleMode mode) const {
+  [[nodiscard]] Status ScalesValidation(gsl::span<const float> scales, const UpsampleMode mode) const {
     if (!is_resize_) {
       for (auto& scale : scales) {
         ORT_RETURN_IF_NOT(scale >= 1, "Scale value should be greater than or equal to 1.");
@@ -372,7 +449,7 @@ class UpsampleBase {
   }
 
   [[nodiscard]] Status
-  ParseScalesData(const Tensor* scale, std::vector<float>& scales, int64_t rank) const {
+  ParseScalesData(const Tensor* scale, InlinedVector<float>& scales, int64_t rank) const {
     const auto* scale_data = scale->Data<float>();
     int64_t scales_size = scale->Shape().Size();
     ORT_RETURN_IF_NOT(scales_size > 0, "scales size should be greater than 0.");
@@ -387,19 +464,19 @@ class UpsampleBase {
     // in which case the other axes is ignored and use default scale of 1
     // scales_size == axes_.size() should be guaranteed if axes is not empty
     if (rank > 0 && (scales_size != rank || axes_.size())) {
-      std::vector<float> new_scales(size_t(rank), 1.0f);
+      InlinedVector<float> new_scales(size_t(rank), 1.0f);
       ORT_RETURN_IF_NOT(*std::max_element(axes_.begin(), axes_.end()) < rank && (int64_t(axes_.size()) == scales_size),
                         "all values in axes should be less than rank of the data");
 
       for (size_t i = 0; i < axes_.size(); i++) {
         new_scales[static_cast<size_t>(axes_[i])] = scales[i];
       }
-      scales = new_scales;
+      scales.swap(new_scales);
     }
     return ScalesValidation(scales, mode_);
   }
 
-  void ParseRoiData(const Tensor* roi, std::vector<float>& roi_array) const {
+  void ParseRoiData(const Tensor* roi, InlinedVector<float>& roi_array) const {
     int64_t roi_size = roi->Shape().Size();
     if (roi_size > 0) {
       roi_array.resize(onnxruntime::narrow<size_t>(roi_size));
@@ -429,52 +506,11 @@ class UpsampleBase {
     return Status::OK();
   }
 
-  // it works iff output_shape is specified
-  void AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span<const int64_t> input_dims,
-                                std::vector<float>& scales) const {
-    std::unordered_set<int64_t> axes_set(axes_.begin(), axes_.end());
-
-    // AspectRatioPolicy::STRETCH is default policy when opset < 18
-    if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::STRETCH) {
-      return;
-    }
-
-    float scale_in_policy = 0.0f;
-    if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_LARGER) {
-      scale_in_policy = std::numeric_limits<float>::max();
-
-      for (size_t i = 0; i < scales.size(); i++) {
-        if (axes_set.empty() || axes_set.count(i) > 0) {
-          scale_in_policy = std::min(scale_in_policy, scales[i]);
-        }
-      }
-    } else if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_SMALLER) {
-      scale_in_policy = std::numeric_limits<float>::min();
-
-      for (size_t i = 0; i < scales.size(); i++) {
-        if (axes_set.empty() || axes_set.count(i) > 0) {
-          scale_in_policy = std::max(scale_in_policy, scales[i]);
-        }
-      }
-    }
-
-    for (size_t i = 0; i < scales.size(); i++) {
-      // if axes is not specified (AKA axes_set.empty()), we apply the policy to all axes
-      if (axes_set.empty() || axes_set.count(i) > 0) {
-        scales[i] = scale_in_policy;
-        output_dims[i] = static_cast<int64_t>(std::round(scales[i] * input_dims[i]));
-      } else {
-        scales[i] = 1.0f;
-        output_dims[i] = input_dims[i];
-      }
-    }
-  }
-
   // It's different in Opset 18 and before.
   // we will modify output_shape by sorts of policy even if it's specified
   [[nodiscard]] Status ParseScalesDataAndAdjustOutputSize(TensorShapeVector& output_dims,
                                                           gsl::span<const int64_t> input_dims,
-                                                          std::vector<float>& scales) const {
+                                                          InlinedVector<float>& scales) const {
     for (size_t i = 0, end = input_dims.size(); i < end; ++i) {
       // Handle corner case to avoid dividing by zero in the next step
       if (input_dims[i] == 0) {
@@ -507,9 +543,9 @@ class UpsampleBase {
 
   // Roi is redefined in Opset-18, we have a concept of axes.
   // So we need to update it accordingly.
-  void ComputeROIWithAxes(std::vector<float>& roi_array, size_t rank) const {
+  void ComputeROIWithAxes(InlinedVector<float>& roi_array, size_t rank) const {
     if (axes_.size()) {
-      std::vector<float> roi_tmp(rank * 2, 0);
+      InlinedVector<float> roi_tmp(rank * 2, 0);
       for (size_t i = rank; i < rank * 2; ++i) {
         roi_tmp[i] = 1;
       }
@@ -518,9 +554,32 @@ class UpsampleBase {
         roi_tmp[v_in_axes] = (roi_array[i]);
         roi_tmp[rank + v_in_axes] = (roi_array[axes_.size() + i]);
       }
-      roi_array = roi_tmp;
+      roi_array.swap(roi_tmp);
     }
   }
+
+ public:
+  static constexpr size_t kLookupTableSize = 1280;
+
+  static const uint8_t* GetLookupTableShared() {
+    // initialized once
+    static const auto* lookup_table = []() {
+      // if we have already initialized the lookup table, just return
+      // ideally we could have a global lookup table, but that accounts for too much space.
+      /* Handles values from -640 to 639. */
+      static uint8_t table[kLookupTableSize] = {0};
+
+      // taken from https://github.com/python-pillow/Pillow/blob/66add095a50d76c35c7f58643461f2edf78a3f05/src/libImaging/Resample.c#L94
+      //  we need to handle negative values:
+      //  it's equivalent to x = np.clip(x, 0, 255) for x in [-640, 639]
+      // we accept a negative x because (&table[640])[x] means table[640 + x]
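+      // e.g. (&table[640])[-5] == 0, (&table[640])[100] == 100, (&table[640])[300] == 255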
+      for (int i = 0; i < static_cast<int>(kLookupTableSize); ++i) {
+        table[i] = static_cast<uint8_t>(std::min(std::max(i - 640, 0), 255));
+      }
+      return table;
+    }();
+    return lookup_table;
+  }
 };  // UpsampleBase
 
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh
index 0d9928baa86e0..052dd05574ab1 100644
--- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh
+++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh
@@ -5,7 +5,9 @@
 #include <stdint.h>
 #include <vector>
 #include <mutex>
+#include <limits>
 #include <assert.h>
+#include <math.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 #include "core/providers/cuda/cuda_common.h"
@@ -194,13 +196,13 @@ template <>
 __device__ __inline__ half _Ceil(half a) { return half(ceilf((float)a)); }
 
 template <typename T>
-__device__ __inline__ T _Floor(T a);
+__device__ __host__ __inline__ T _Floor(T a);
 
 template <>
-__device__ __inline__ float _Floor(float a) { return floorf(a); }
+__device__ __host__ __inline__ float _Floor(float a) { return floorf(a); }
 
 template <>
-__device__ __inline__ double _Floor(double a) { return floor(a); }
+__device__ __host__ __inline__ double _Floor(double a) { return floor(a); }
 
 template <>
 __device__ __inline__ half _Floor(half a) { return half(floorf((float)a)); }
@@ -230,13 +232,13 @@ template <>
 __device__ __inline__ half _Erf(half a) { return half(erff((float)a)); }
 
 template <typename T>
-__device__ __inline__ T _Round(T a);
+__device__ __host__ __inline__ T _Round(T a);
 
 template <>
-__device__ __inline__ float _Round(float a) { return rintf(a); }
+__device__ __host__ __inline__ float _Round(float a) { return rintf(a); }
 
 template <>
-__device__ __inline__ double _Round(double a) { return rint(a); }
+__device__ __host__ __inline__ double _Round(double a) { return rint(a); }
 
 template <>
 __device__ __inline__ half _Round(half a) {
@@ -345,9 +347,29 @@ __device__ __inline__ half _Pow(half a, half b) { return half(powf((float)a, (fl
 template <typename T>
 __device__ __inline__ T _Min(T a, T b) { return a < b ? a : b; }
 
+template <>
+__device__ __inline__ float _Min(float a, float b) {
+  return (isnan(a) || isnan(b)) ? std::numeric_limits<float>::quiet_NaN() : (a < b ? a : b);
+}
+
+template <>
+__device__ __inline__ double _Min(double a, double b) {
+  return (isnan(a) || isnan(b)) ? std::numeric_limits<double>::quiet_NaN() : (a < b ? a : b);
+}
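+// Unlike the generic template (whose result depends on operand order when a NaN is present),
+// these float/double specializations of _Min, and the _Max ones below, return NaN whenever
+// either operand is NaN.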
+
 template <typename T>
 __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; }
 
+template <>
+__device__ __inline__ float _Max(float a, float b) {
+  return (isnan(a) || isnan(b)) ? std::numeric_limits<float>::quiet_NaN() : (a > b ? a : b);
+}
+
+template <>
+__device__ __inline__ double _Max(double a, double b) {
+  return (isnan(a) || isnan(b)) ? std::numeric_limits<double>::quiet_NaN() : (a > b ? a : b);
+}
+
 template <typename T>
 __device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; }
 
@@ -438,6 +460,157 @@ __device__ __inline__ BFloat16 _Fmod(BFloat16 a, BFloat16 b) {
   return fmodf((float)a, (float)b);
 }
 
+namespace isinf_details {
+template <typename T>
+struct IsInfTyped {
+  static __device__ __inline__ bool IsInf(T a) {
+    // The cast is needed because on non-MSVC compilers isinf() returns int,
+    // and we want to avoid conversion warnings.
+    return static_cast<bool>(isinf(a));
+  }
+  static __device__ __inline__ bool IsInfPos(T a) {
+    return a == std::numeric_limits<T>::infinity();
+  }
+  static __device__ __inline__ bool IsInfNeg(T a) {
+    return a == -std::numeric_limits<T>::infinity();
+  }
+};
+
+template <>
+struct IsInfTyped<half> {
+  static __device__ __inline__ bool IsInf(half a) {
+    return MLFloat16::kPositiveInfinityBits ==
+           static_cast<uint16_t>(*reinterpret_cast<uint16_t*>(&a) & ~MLFloat16::kSignMask);
+  }
+  static __device__ __inline__ bool IsInfPos(half a) {
+    return MLFloat16::kPositiveInfinityBits == *reinterpret_cast<uint16_t*>(&a);
+  }
+  static __device__ __inline__ bool IsInfNeg(half a) {
+    return MLFloat16::kNegativeInfinityBits == *reinterpret_cast<uint16_t*>(&a);
+  }
+};
+
+template <>
+struct IsInfTyped<BFloat16> {
+  static __device__ __inline__ bool IsInf(BFloat16 a) {
+    return BFloat16::kPositiveInfinityBits ==
+           static_cast<uint16_t>(*reinterpret_cast<uint16_t*>(&a) & ~BFloat16::kSignMask);
+  }
+  static __device__ __inline__ bool IsInfPos(BFloat16 a) {
+    return BFloat16::kPositiveInfinityBits == *reinterpret_cast<uint16_t*>(&a);
+  }
+  static __device__ __inline__ bool IsInfNeg(BFloat16 a) {
+    return BFloat16::kNegativeInfinityBits == *reinterpret_cast<uint16_t*>(&a);
+  }
+};
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+
+template <typename T>
+struct ReturnFalse {
+  constexpr static bool __device__ __inline__ IsInf(T) { return false; }
+  constexpr static bool __device__ __inline__ IsInfPos(T) { return false; }
+  constexpr static bool __device__ __inline__ IsInfNeg(T) { return false; }
+};
+
+template <>
+struct IsInfTyped<Float8E4M3FN> : ReturnFalse<Float8E4M3FN> {};
+
+template <>
+struct IsInfTyped<Float8E4M3FNUZ> : ReturnFalse<Float8E4M3FNUZ> {};
+
+template <>
+struct IsInfTyped<Float8E5M2> {
+  static __device__ __inline__ bool IsInf(Float8E5M2 a) {
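+    // E5M2 encodes +inf as 0b0'11111'00 (0x7C) and -inf as 0b1'11111'00 (0xFC).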
+    return a.val == 0b01111100 || a.val == 0b11111100;
+  }
+  static __device__ __inline__ bool IsInfPos(Float8E5M2 a) {
+    return a.val == 0b01111100;
+  }
+  static __device__ __inline__ bool IsInfNeg(Float8E5M2 a) {
+    return a.val == 0b11111100;
+  }
+};
+
+template <>
+struct IsInfTyped<Float8E5M2FNUZ> : ReturnFalse<Float8E5M2FNUZ> {};
+
+#endif
+}  // namespace isinf_details
+
+template <typename T, bool detect_positive, bool detect_negative>
+struct _IsInf {
+  __device__ __inline__ bool operator()(T a) const {
+    if constexpr (detect_positive && detect_negative) {
+      return isinf_details::IsInfTyped<T>::IsInf(a);
+    } else if constexpr (detect_positive) {
+      return isinf_details::IsInfTyped<T>::IsInfPos(a);
+    } else if constexpr (detect_negative) {
+      return isinf_details::IsInfTyped<T>::IsInfNeg(a);
+    } else {
+      return false;
+    }
+  }
+};
+
+// float and double
+template <typename T>
+struct _IsNan {
+  __device__ __inline__ bool operator()(T a) const {
+    return isnan(a);
+  }
+};
+
+template <>
+struct _IsNan<half> {
+  __device__ __inline__ bool operator()(half a) const {
+    return static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&a) & ~MLFloat16::kSignMask)
+           > MLFloat16::kPositiveInfinityBits;
+  }
+};
+
+template <>
+struct _IsNan<BFloat16> {
+  __device__ __inline__ bool operator()(BFloat16 a) const {
+    return static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&a) & ~BFloat16::kSignMask)
+           > BFloat16::kPositiveInfinityBits;
+  }
+};
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+
+template <>
+struct _IsNan<Float8E4M3FN> {
+  __device__ __inline__ bool operator()(Float8E4M3FN a) const {
+    return (*reinterpret_cast<const uint8_t*>(&a) & 0x7f) == 0x7f;
+  }
+};
+
+template <>
+struct _IsNan<Float8E4M3FNUZ> {
+  __device__ __inline__ bool operator()(Float8E4M3FNUZ a) const {
+    return *reinterpret_cast<const uint8_t*>(&a) == 0x80;
+  }
+};
+
+template <>
+struct _IsNan<Float8E5M2> {
+  __device__ __inline__ bool operator()(Float8E5M2 a) const {
+    uint8_t c = *reinterpret_cast<const uint8_t*>(&a);
+    return ((c & 0x7c) == 0x7c) && ((c & 0x03) != 0x00);
+  }
+};
+
+template <>
+struct _IsNan<Float8E5M2FNUZ> {
+  __device__ __inline__ bool operator()(Float8E5M2FNUZ a) const {
+    return *reinterpret_cast<const uint8_t*>(&a) == 0x80;
+  }
+};
+
+#endif
+
 // We would like to use 64-bit integer to support large matrices. However, CUDA seems to support only 32-bit integer
 // For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type.
 #ifndef CUDA_LONG
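The half/BFloat16 IsInf and IsNaN checks above work directly on the raw bit pattern: clear the sign bit, then compare against the positive-infinity encoding (equal means +/-inf, greater means NaN). A standalone sketch for IEEE binary16, assuming the usual 0x8000 sign mask and 0x7C00 infinity encoding (the MLFloat16 constants are taken to match these values):

#include <cassert>
#include <cstdint>

constexpr uint16_t kFp16SignMask = 0x8000;               // assumed binary16 sign bit
constexpr uint16_t kFp16PositiveInfinityBits = 0x7C00;   // assumed binary16 +inf encoding

// Equal to +inf after clearing the sign bit => +inf or -inf.
bool IsInfFp16Bits(uint16_t bits) {
  return static_cast<uint16_t>(bits & ~kFp16SignMask) == kFp16PositiveInfinityBits;
}

// Greater than +inf after clearing the sign bit => exponent all ones, non-zero mantissa => NaN.
bool IsNanFp16Bits(uint16_t bits) {
  return static_cast<uint16_t>(bits & ~kFp16SignMask) > kFp16PositiveInfinityBits;
}

int main() {
  assert(IsInfFp16Bits(0x7C00) && IsInfFp16Bits(0xFC00));     // +inf, -inf
  assert(IsNanFp16Bits(0x7E00) && !IsInfFp16Bits(0x7E00));    // a quiet NaN
  assert(!IsInfFp16Bits(0x3C00) && !IsNanFp16Bits(0x3C00));   // 1.0
  return 0;
}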
diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h
index 41c999bacee13..61da125b40953 100644
--- a/onnxruntime/core/providers/cuda/cuda_common.h
+++ b/onnxruntime/core/providers/cuda/cuda_common.h
@@ -70,6 +70,15 @@ class ToCudaType<Float8E4M3FN> {
   }
 };
 
+template <>
+class ToCudaType<Float8E4M3FNUZ> {
+ public:
+  typedef Float8E4M3FNUZ MappedType;
+  static MappedType FromFloat(float f) {
+    return MappedType(f);
+  }
+};
+
 template <>
 class ToCudaType<Float8E5M2> {
  public:
@@ -79,6 +88,15 @@ class ToCudaType<Float8E5M2> {
   }
 };
 
+template <>
+class ToCudaType<Float8E5M2FNUZ> {
+ public:
+  typedef Float8E5M2FNUZ MappedType;
+  static MappedType FromFloat(float f) {
+    return MappedType(f);
+  }
+};
+
 #endif
 
 inline bool CalculateFdmStrides(gsl::span<fast_divmod> p, const std::vector<int64_t>& dims) {
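The two new ToCudaType specializations extend the existing pattern to the FNUZ float8 types: a trait that maps an ONNX element type to the type the CUDA kernels compute with, plus a FromFloat conversion. A self-contained stand-in (toy types, not ORT code) showing the shape of that trait:

#include <cstdint>
#include <cstdio>

// Toy stand-in for a float8 type such as Float8E4M3FNUZ; the encoding is elided,
// this is for illustration only.
struct ToyFloat8 {
  uint8_t val;
  explicit ToyFloat8(float) : val(0) {}
};

// Same shape as the ToCudaType trait the patch extends.
template <typename T>
struct ToCudaTypeLike {
  typedef T MappedType;
  static MappedType FromFloat(float f) { return MappedType(f); }
};

int main() {
  auto one = ToCudaTypeLike<ToyFloat8>::FromFloat(1.0f);  // mirrors ToCudaType<Float8E4M3FNUZ>::FromFloat
  std::printf("0x%02x\n", one.val);
  return 0;
}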
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index 77e682e05a2a4..05d9f3b5a1e8f 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -3,6 +3,7 @@
 // Licensed under the MIT License.
 
 #include "core/common/inlined_containers.h"
+#include "core/common/parse_string.h"
 #include "core/providers/shared_library/provider_api.h"
 #include "core/platform/env_var_utils.h"
 #include "core/providers/cuda/cuda_execution_provider.h"
@@ -11,6 +12,7 @@
 #include "core/providers/cuda/cuda_fwd.h"
 #include "core/providers/cuda/gpu_data_transfer.h"
 #include "core/providers/cuda/cuda_profiler.h"
+#include "core/session/onnxruntime_run_options_config_keys.h"
 
 #ifndef USE_CUDA_MINIMAL
 #ifndef DISABLE_CONTRIB_OPS
@@ -190,31 +192,60 @@ CUDAExecutionProvider::PerThreadContext::~PerThreadContext() {
 #endif
 }
 
-bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const {
-  return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_;
+bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptureAllowed(
+    CudaGraphAnnotation_t cuda_graph_annotation_id) const {
+  if (!IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id)) {
+    return false;
+  }
+  if (graph_id_to_run_count_.find(cuda_graph_annotation_id) == graph_id_to_run_count_.end()) {
+    return false;
+  }
+  return graph_id_to_run_count_.at(cuda_graph_annotation_id) >= min_num_runs_before_cuda_graph_capture_;
+}
+
+bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptureAllowedOnRun(
+    CudaGraphAnnotation_t cuda_graph_annotation_id) const {
+  return cuda_graph_.IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id);
+}
+
+CudaGraphAnnotation_t CUDAExecutionProvider::PerThreadContext::GetCudaGraphAnnotationId(
+    const onnxruntime::RunOptions& run_options) const {
+  auto graph_annotation_str =
+      run_options.GetConfigOptions().GetConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation);
+  // If no graph annotation is provided, fall back to the default one-CUDA-graph-per-session behavior.
+  CudaGraphAnnotation_t cuda_graph_annotation_id = 0;
+  if (graph_annotation_str.has_value()) {
+    ORT_ENFORCE(TryParseStringWithClassicLocale<int>(*graph_annotation_str, cuda_graph_annotation_id),
+                "Failed to parse the cuda graph annotation id: ",
+                *graph_annotation_str);
+  }
+
+  return cuda_graph_annotation_id;
 }
 
-void CUDAExecutionProvider::PerThreadContext::CaptureBegin() {
-  cuda_graph_.Reset();
-  cuda_graph_.CaptureBegin();
+void CUDAExecutionProvider::PerThreadContext::CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id) {
+  cuda_graph_.CaptureBegin(cuda_graph_annotation_id);
 }
 
-void CUDAExecutionProvider::PerThreadContext::CaptureEnd() {
-  cuda_graph_.CaptureEnd();
-  is_graph_captured_ = true;
+void CUDAExecutionProvider::PerThreadContext::CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id) {
+  cuda_graph_.CaptureEnd(cuda_graph_annotation_id);
 }
 
-bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptured() const {
-  return is_graph_captured_;
+bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptured(CudaGraphAnnotation_t graph_annotation_id) const {
+  return cuda_graph_.IsGraphCaptured(graph_annotation_id);
 }
 
-Status CUDAExecutionProvider::PerThreadContext::ReplayGraph() {
-  ORT_ENFORCE(IsGraphCaptured());
-  return cuda_graph_.Replay();
+Status CUDAExecutionProvider::PerThreadContext::ReplayGraph(CudaGraphAnnotation_t graph_annotation_id) {
+  return cuda_graph_.Replay(graph_annotation_id);
 }
 
-void CUDAExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() {
-  ++regular_run_count_before_graph_capture_;
+void CUDAExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture(
+    CudaGraphAnnotation_t cuda_graph_annotation_id) {
+  if (graph_id_to_run_count_.find(cuda_graph_annotation_id) == graph_id_to_run_count_.end()) {
+    graph_id_to_run_count_[cuda_graph_annotation_id] = 1;
+    return;
+  }
+  graph_id_to_run_count_[cuda_graph_annotation_id]++;
 }
 
 void OverrideTunableOpInfoByEnv(CUDAExecutionProviderInfo& info) {
@@ -386,25 +417,28 @@ Status CUDAExecutionProvider::Sync() const {
   return Status::OK();
 }
 
-Status CUDAExecutionProvider::OnRunStart() {
+Status CUDAExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) {
   // always set CUDA device when session::Run() in case it runs in a worker thread
   CUDA_RETURN_IF_ERROR(cudaSetDevice(GetDeviceId()));
-  if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) {
+  CudaGraphAnnotation_t cuda_graph_annotation_id = GetPerThreadContext().GetCudaGraphAnnotationId(run_options);
+  if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id) &&
+      GetPerThreadContext().IsGraphCaptureAllowed(cuda_graph_annotation_id)) {
     LOGS(*GetLogger(), INFO) << "Capturing the cuda graph for this model";
-    GetPerThreadContext().CaptureBegin();
+    GetPerThreadContext().CaptureBegin(cuda_graph_annotation_id);
   }
   return Status::OK();
 }
 
-Status CUDAExecutionProvider::OnRunEnd(bool sync_stream) {
-  if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) {
-    if (GetPerThreadContext().IsGraphCaptureAllowed()) {
-      GetPerThreadContext().CaptureEnd();
+Status CUDAExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) {
+  CudaGraphAnnotation_t cuda_graph_annotation_id = GetPerThreadContext().GetCudaGraphAnnotationId(run_options);
+  if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id)) {
+    if (GetPerThreadContext().IsGraphCaptureAllowed(cuda_graph_annotation_id)) {
+      GetPerThreadContext().CaptureEnd(cuda_graph_annotation_id);
       // CUDA work issued to a capturing stream doesn’t actually run on the GPU,
       // so run the captured graph here to actually execute the work.
-      ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph());
+      ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(cuda_graph_annotation_id));
     } else {
-      GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture();
+      GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture(cuda_graph_annotation_id);
     }
   }
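For callers, the annotation id travels through RunOptions, which OnRunStart/OnRunEnd read via GetCudaGraphAnnotationId. A hedged usage sketch with the public C++ API, assuming Ort::RunOptions::AddConfigEntry and the kOrtRunOptionsConfigCudaGraphAnnotation key declared in onnxruntime_run_options_config_keys.h (header names and the exact key string are not restated here):

// Sketch only: selects which captured CUDA graph a Run() call belongs to.
#include <onnxruntime_cxx_api.h>
#include <onnxruntime_run_options_config_keys.h>

void RunWithAnnotation(Ort::Session& session,
                       const char* const* input_names, const Ort::Value* inputs, size_t input_count,
                       const char* const* output_names, size_t output_count) {
  Ort::RunOptions run_options;
  // "1" groups this run under annotation id 1; "-1" would skip graph capture/replay for this run.
  run_options.AddConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation, "1");
  auto outputs = session.Run(run_options, input_names, inputs, input_count, output_names, output_count);
  (void)outputs;
}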
 
@@ -433,12 +467,12 @@ bool CUDAExecutionProvider::IsGraphCaptureEnabled() const {
   return info_.enable_cuda_graph;
 }
 
-bool CUDAExecutionProvider::IsGraphCaptured() const {
-  return GetPerThreadContext().IsGraphCaptured();
+bool CUDAExecutionProvider::IsGraphCaptured(int graph_annotation_id) const {
+  return GetPerThreadContext().IsGraphCaptured(graph_annotation_id);
 }
 
-Status CUDAExecutionProvider::ReplayGraph() {
-  return GetPerThreadContext().ReplayGraph();
+Status CUDAExecutionProvider::ReplayGraph(int graph_annotation_id) {
+  return GetPerThreadContext().ReplayGraph(graph_annotation_id);
 }
 
 namespace cuda {
@@ -722,6 +756,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, uint32_t, Cast);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, uint64_t, Cast);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, bool, Cast);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, IsNaN);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, 10, float, Pad);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, 10, double, Pad);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, 10, MLFloat16, Pad);
@@ -830,6 +865,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, MLFloat16, ThresholdedRelu);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, TopK);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, Mod);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 19, IsInf);
 
 // opset 11
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Compress);
@@ -913,7 +949,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom
 
 // OpSet 12
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Clip);
-
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, float, MaxPool);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, double, MaxPool);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, MLFloat16, MaxPool);
@@ -989,6 +1024,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Sqrt);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Sqrt);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Sqrt);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Sqrt);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Log);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Log);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Log);
@@ -1061,6 +1097,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, U
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Concat);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Gather);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, GatherElements);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 19, IsNaN);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, MatMul);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, MatMul);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, MatMul);
@@ -1108,11 +1145,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceSumSquare);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, GatherND);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Dropout);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Resize);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Resize);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Resize);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Resize);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, Resize);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Resize);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Resize);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Resize);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, int32_t, Resize);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, uint8_t, Resize);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, If);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, Loop);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Flatten);
@@ -1201,9 +1238,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, LSTM);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, LSTM);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM);
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization);
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization);
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
+    kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
+    kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
+    kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMin);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMin);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMin);
@@ -1255,6 +1295,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample);
 
 // Opset 17
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
@@ -1275,6 +1316,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Resize);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Resize);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Resize);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, Resize);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, uint8_t, Resize);
 
 // Opset 19
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Cast);
@@ -1328,6 +1374,13 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, S
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Shape);
 #endif
 
+// Opset 20
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, Gelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, Gelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsNaN);
+
 template <>
 KernelCreateInfo BuildKernelCreateInfo<void>() {
   return {};
@@ -1512,6 +1565,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, float, Erf)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, double, Erf)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, Erf)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, IsNaN)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, bool, Not)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, BatchNormalization)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, BatchNormalization)>,
@@ -1724,6 +1778,8 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, int8_t, DequantizeLinear)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, uint8_t, DequantizeLinear)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, Mod)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10,
+                                                                    19, IsInf)>,
 
     // opset 11
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMax)>,
@@ -1882,6 +1938,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Sqrt)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Sqrt)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Sqrt)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Sqrt)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Log)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Log)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Log)>,
@@ -1935,6 +1992,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint32_t, Cast)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint64_t, Cast)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, bool, Cast)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 19, IsNaN)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, Reshape)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Shape)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Size)>,
@@ -2001,11 +2059,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceSumSquare)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, GatherND)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Dropout)>,
-    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Resize)>,
-    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Resize)>,
-    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Resize)>,
-    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Resize)>,
-    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, Resize)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Resize)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Resize)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Resize)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, int32_t, Resize)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, uint8_t, Resize)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, If)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, Loop)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Flatten)>,
@@ -2094,9 +2152,12 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, LSTM)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 18, Reshape)>,
-    BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization)>,
-    BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization)>,
-    BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
+        kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
+        kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
+        kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Add)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Sub)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Mul)>,
@@ -2141,6 +2202,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample)>,
 
     // Opset 17
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization)>,
@@ -2167,6 +2229,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Resize)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Resize)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Resize)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, Resize)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, uint8_t, Resize)>,
 
     // Opset 19
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Cast)>,
@@ -2220,6 +2287,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
     BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Reshape)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Scan)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Shape)>,
+
+    // Opset 20
+    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, Gelu)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, Gelu)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsNaN)>,
 #endif
   };
 
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h
index 55f0b5570e0ee..f53779058a8af 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h
@@ -29,9 +29,9 @@ class CUDAExecutionProvider : public IExecutionProvider {
 
   Status Sync() const override;
 
-  Status OnRunStart() override;
+  Status OnRunStart(const onnxruntime::RunOptions& run_options) override;
 
-  Status OnRunEnd(bool sync_stream) override;
+  Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override;
 
   DataLayout GetPreferredLayout() const override;
 
@@ -92,8 +92,8 @@ class CUDAExecutionProvider : public IExecutionProvider {
   std::unique_ptr<profiling::EpProfiler> GetProfiler() override;
 
   bool IsGraphCaptureEnabled() const override;
-  bool IsGraphCaptured() const override;
-  Status ReplayGraph() override;
+  bool IsGraphCaptured(CudaGraphAnnotation_t graph_annotation_id) const override;
+  Status ReplayGraph(CudaGraphAnnotation_t graph_annotation_id) override;
   void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override;
   OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;
   std::vector<AllocatorPtr> CreatePreferredAllocators() override;
@@ -115,6 +115,7 @@ class CUDAExecutionProvider : public IExecutionProvider {
     PerThreadContext(OrtDevice::DeviceId device_id, cudaStream_t stream, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy,
                      CUDAExecutionProviderExternalAllocatorInfo external_alloc_info, OrtArenaCfg* arena_cfg);
     ~PerThreadContext();
+    ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PerThreadContext);
 
     cublasHandle_t CublasHandle() const {
       return cublas_handle_;
@@ -130,41 +131,33 @@ class CUDAExecutionProvider : public IExecutionProvider {
 
     template <typename T>
     const T* GetConstOnes(size_t count, cudaStream_t stream) {
-      constexpr bool is_float = std::is_same<T, float>::value;
-      constexpr bool is_double = std::is_same<T, double>::value;
-      constexpr bool is_half = std::is_same<T, half>::value;
-      constexpr bool is_BFloat16 = std::is_same<T, BFloat16>::value;
-#if !defined(DISABLE_FLOAT8_TYPES)
-      constexpr bool is_Float8E4M3FN = std::is_same<T, Float8E4M3FN>::value;
-      constexpr bool is_Float8E5M2 = std::is_same<T, Float8E5M2>::value;
-#endif
-      if (is_float) {
+      if constexpr (std::is_same<T, float>::value) {
         if (!constant_ones_float_) {
           constant_ones_float_ = cuda::CreateConstantOnes<float>();
         }
         return reinterpret_cast<const T*>(constant_ones_float_->GetBuffer(stream, count));
-      } else if (is_double) {
+      } else if constexpr (std::is_same<T, double>::value) {
         if (!constant_ones_double_) {
           constant_ones_double_ = cuda::CreateConstantOnes<double>();
         }
         return reinterpret_cast<const T*>(constant_ones_double_->GetBuffer(stream, count));
-      } else if (is_half) {
+      } else if constexpr (std::is_same<T, half>::value) {
         if (!constant_ones_half_) {
           constant_ones_half_ = cuda::CreateConstantOnes<half>();
         }
         return reinterpret_cast<const T*>(constant_ones_half_->GetBuffer(stream, count));
-      } else if (is_BFloat16) {
+      } else if constexpr (std::is_same<T, BFloat16>::value) {
         if (!constant_ones_bfloat16_) {
           constant_ones_bfloat16_ = cuda::CreateConstantOnes<BFloat16>();
         }
         return reinterpret_cast<const T*>(constant_ones_bfloat16_->GetBuffer(stream, count));
 #if !defined(DISABLE_FLOAT8_TYPES)
-      } else if (is_Float8E4M3FN) {
+      } else if constexpr (std::is_same<T, Float8E4M3FN>::value) {
         if (!constant_ones_float8e4m3fn_) {
           constant_ones_float8e4m3fn_ = cuda::CreateConstantOnes<Float8E4M3FN>();
         }
         return reinterpret_cast<const T*>(constant_ones_float8e4m3fn_->GetBuffer(stream, count));
-      } else if (is_Float8E5M2) {
+      } else if constexpr (std::is_same<T, Float8E5M2>::value) {
         if (!constant_ones_float8e5m2_) {
           constant_ones_float8e5m2_ = cuda::CreateConstantOnes<Float8E5M2>();
         }
@@ -175,12 +168,14 @@ class CUDAExecutionProvider : public IExecutionProvider {
       }
     }
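The GetConstOnes rewrite replaces runtime boolean checks with `if constexpr`, so only the branch matching T is instantiated and the others do not even have to compile for that T. A small self-contained illustration of the idiom (not ORT code):

#include <cstddef>
#include <iostream>
#include <string>
#include <type_traits>

// With a runtime `if`, value.size() would have to compile even when T is int, and fail;
// `if constexpr` discards the non-matching branch during template instantiation.
template <typename T>
std::size_t SizeOrZero(const T& value) {
  if constexpr (std::is_same_v<T, std::string>) {
    return value.size();
  } else {
    return 0;
  }
}

int main() {
  std::cout << SizeOrZero(std::string("abc")) << " " << SizeOrZero(42) << "\n";
  return 0;
}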
 
-    bool IsGraphCaptureAllowed() const;
-    void CaptureBegin();
-    void CaptureEnd();
-    bool IsGraphCaptured() const;
-    Status ReplayGraph();
-    void IncrementRegularRunCountBeforeGraphCapture();
+    bool IsGraphCaptureAllowed(CudaGraphAnnotation_t cuda_graph_annotation_id) const;
+    bool IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t cuda_graph_annotation_id) const;
+    void CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id);
+    void CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id);
+    bool IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const;
+    CudaGraphAnnotation_t GetCudaGraphAnnotationId(const onnxruntime::RunOptions& run_options) const;
+    Status ReplayGraph(CudaGraphAnnotation_t cuda_graph_annotation_id);
+    void IncrementRegularRunCountBeforeGraphCapture(CudaGraphAnnotation_t cuda_graph_annotation_id);
 
    private:
     cublasHandle_t cublas_handle_ = nullptr;
@@ -199,8 +194,8 @@ class CUDAExecutionProvider : public IExecutionProvider {
     // Cuda graph with multi threads will be supported in the future, so cuda_graph_
     // is put under PerThreadContext.
     CUDAGraph cuda_graph_;
-    bool is_graph_captured_ = false;
-    int regular_run_count_before_graph_capture_ = 0;
+    // Map of graph id to regular_run_count_before_graph_capture
+    std::unordered_map<CudaGraphAnnotation_t, int> graph_id_to_run_count_;
 
     // There is chance that the second regular run allocates GPU memory for causes like:
     // (1) memory pattern is enabled. (2) arena allocation for stream.
diff --git a/onnxruntime/core/providers/cuda/cuda_graph.cc b/onnxruntime/core/providers/cuda/cuda_graph.cc
index 230d664391611..8353c654681fc 100644
--- a/onnxruntime/core/providers/cuda/cuda_graph.cc
+++ b/onnxruntime/core/providers/cuda/cuda_graph.cc
@@ -9,17 +9,44 @@
 
 namespace onnxruntime {
 
-CUDAGraph::CUDAGraph(cudaStream_t stream) : stream_(stream) {
+CudaGraphSet::~CudaGraphSet() {
+  Clear();
 }
 
-void CUDAGraph::SetStream(cudaStream_t stream) {
+void CudaGraphSet::Clear() {
+  for (auto& it : cuda_graphs_) {
+    CUDA_CALL_THROW(cudaGraphExecDestroy(it.second));
+  }
+  cuda_graphs_.clear();
+}
+
+bool CudaGraphSet::Contains(CudaGraphAnnotation_t cuda_graph_annotation_id) const {
+  return cuda_graphs_.find(cuda_graph_annotation_id) != cuda_graphs_.end();
+}
+
+void CudaGraphSet::Put(CudaGraphAnnotation_t cuda_graph_annotation_id, cudaGraphExec_t graph_exec) {
+  ORT_ENFORCE(!Contains(cuda_graph_annotation_id));
+  cuda_graphs_.emplace(cuda_graph_annotation_id, graph_exec);
+}
+
+cudaGraphExec_t CudaGraphSet::Get(CudaGraphAnnotation_t cuda_graph_annotation_id) const {
+  ORT_ENFORCE(Contains(cuda_graph_annotation_id));
+  return cuda_graphs_.at(cuda_graph_annotation_id);
+}
+
+CUDAGraphManager::CUDAGraphManager(cudaStream_t stream) : stream_(stream) {
+}
+
+void CUDAGraphManager::SetStream(cudaStream_t stream) {
   stream_ = stream;
 }
 
-void CUDAGraph::CaptureBegin() {
-  ORT_ENFORCE(!has_graph_exec_,
-              "This cuda graph has already captured a graph. "
-              "Create a new instance to capture a new graph.");
+void CUDAGraphManager::CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id) {
+  ORT_ENFORCE(IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id));
+
+  ORT_ENFORCE(!cuda_graph_set_.Contains(cuda_graph_annotation_id),
+              "Trying to capture a graph with annotation id ", cuda_graph_annotation_id,
+              " that already used. Please use a different annotation id.");
 
   CUDA_CALL_THROW(cudaStreamSynchronize(stream_));
   // For now cuda graph can only work with a single thread. In the future, we
@@ -29,40 +56,48 @@ void CUDAGraph::CaptureBegin() {
   CUDA_CALL_THROW(cudaStreamBeginCapture(stream_, cudaStreamCaptureModeGlobal));
 }
 
-void CUDAGraph::CaptureEnd() {
-  CUDA_CALL_THROW(cudaStreamEndCapture(stream_, &graph_));
-  if (graph_ == NULL) {
+void CUDAGraphManager::CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id) {
+  cudaGraph_t graph = NULL;
+  CUDA_CALL_THROW(cudaStreamEndCapture(stream_, &graph));
+  if (graph == NULL) {
     ORT_THROW("CUDAGraph::CaptureEnd: graph_ is NULL");
   }
 
-  has_graph_ = true;
-  CUDA_CALL_THROW(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0));
-  has_graph_exec_ = true;
-  CUDA_CALL_THROW(cudaGraphDestroy(graph_));
-  has_graph_ = false;
+  cudaGraphExec_t graph_exec = NULL;
+  CUDA_CALL_THROW(cudaGraphInstantiate(&graph_exec, graph, NULL, NULL, 0));
+  CUDA_CALL_THROW(cudaGraphDestroy(graph));
+
+  // Currently all captured graphs are tied to the session's lifecycle.
+  // TODO(wy): Add an interface to free captured graphs.
+  cuda_graph_set_.Put(cuda_graph_annotation_id, graph_exec);
 }
 
-Status CUDAGraph::Replay() {
+Status CUDAGraphManager::Replay(CudaGraphAnnotation_t cuda_graph_annotation_id) {
   // Although this function is not thread safe, the lock is not needed here because
   // CUDA EP maintains a separate cuda graph per thread
-  LOGS_DEFAULT(INFO) << "Replaying CUDA graph on stream " << stream_;
-  CUDA_RETURN_IF_ERROR(cudaGraphLaunch(graph_exec_, stream_));
+  LOGS_DEFAULT(INFO) << "Replaying CUDA graph on stream " << stream_ << " with cuda_graph_annotation_id "
+                     << cuda_graph_annotation_id;
+
+  cudaGraphExec_t graph_exec = cuda_graph_set_.Get(cuda_graph_annotation_id);
+  CUDA_RETURN_IF_ERROR(cudaGraphLaunch(graph_exec, stream_));
+
   CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_));
   return Status::OK();
 }
 
-void CUDAGraph::Reset() {
-  if (has_graph_) {
-    CUDA_CALL_THROW(cudaGraphDestroy(graph_));
-    has_graph_ = false;
-  }
-  if (has_graph_exec_) {
-    CUDA_CALL_THROW(cudaGraphExecDestroy(graph_exec_));
-    has_graph_exec_ = false;
-  }
+bool CUDAGraphManager::IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t cuda_graph_annotation_id) const {
+  return cuda_graph_annotation_id != kCudaGraphAnnotationSkip;
+}
+
+bool CUDAGraphManager::IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const {
+  return cuda_graph_set_.Contains(cuda_graph_annotation_id);
+}
+
+void CUDAGraphManager::Reset() {
+  cuda_graph_set_.Clear();
 }
 
-CUDAGraph::~CUDAGraph() {
+CUDAGraphManager::~CUDAGraphManager() {
   Reset();
 }
 
diff --git a/onnxruntime/core/providers/cuda/cuda_graph.h b/onnxruntime/core/providers/cuda/cuda_graph.h
index 9bcefcc64ea77..064994c1f14ae 100644
--- a/onnxruntime/core/providers/cuda/cuda_graph.h
+++ b/onnxruntime/core/providers/cuda/cuda_graph.h
@@ -3,33 +3,55 @@
 
 #pragma once
 
+#include <unordered_map>
+
 #include "core/common/common.h"
 #include "core/platform/ort_mutex.h"
 #include "core/providers/cuda/cuda_pch.h"
 
 namespace onnxruntime {
 
-using CaptureId_t = unsigned long long;
+using CudaGraphAnnotation_t = int;
+using CudaGraphSet_t = std::unordered_map<CudaGraphAnnotation_t, cudaGraphExec_t>;
+
+constexpr CudaGraphAnnotation_t kCudaGraphAnnotationSkip = -1;
+constexpr CudaGraphAnnotation_t kCudaGraphAnnotationDefault = 0;
+
+struct CudaGraphSet {
+  CudaGraphSet(){};
+  ~CudaGraphSet();
 
-struct CUDAGraph {
-  CUDAGraph(){};
-  CUDAGraph(cudaStream_t stream);
-  ~CUDAGraph();
+  void Clear();
+  bool Contains(CudaGraphAnnotation_t cuda_graph_annotation_id) const;
+  void Put(CudaGraphAnnotation_t cuda_graph_annotation_id, cudaGraphExec_t graph_exec);
+  cudaGraphExec_t Get(CudaGraphAnnotation_t cuda_graph_annotation_id) const;
+
+ private:
+  CudaGraphSet_t cuda_graphs_;
+};
+
+struct CUDAGraphManager {
+  CUDAGraphManager(){};
+  CUDAGraphManager(cudaStream_t stream);
+  ~CUDAGraphManager();
 
   void SetStream(cudaStream_t stream);
-  void CaptureBegin();
-  void CaptureEnd();
-  Status Replay();
+  void CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id);
+  void CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id);
+  Status Replay(CudaGraphAnnotation_t cuda_graph_annotation_id);
+
   void Reset();
 
- private:
-  cudaGraph_t graph_ = NULL;
-  cudaGraphExec_t graph_exec_ = NULL;
+  bool IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t cuda_graph_annotation_id) const;
+  bool IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const;
 
-  bool has_graph_ = false;
-  bool has_graph_exec_ = false;
+ private:
+  CudaGraphSet cuda_graph_set_;
+  CudaGraphAnnotation_t cuda_graph_annotation_id_ = kCudaGraphAnnotationDefault;
 
   cudaStream_t stream_ = nullptr;  // Does not own the stream
 };
 
+using CUDAGraph = CUDAGraphManager;
+
 }  // namespace onnxruntime
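CUDAGraphManager now keys instantiated executables by annotation id instead of holding a single graph_exec_. A stripped-down CUDA sketch of the same pattern (capture once per id, replay on demand), using the same five-argument cudaGraphInstantiate form as the patch, which may need the flags-only overload on newer CUDA toolkits; error handling is reduced to asserts:

#include <cassert>
#include <unordered_map>
#include <cuda_runtime.h>

int main() {
  cudaStream_t stream;
  assert(cudaStreamCreate(&stream) == cudaSuccess);

  int* buf = nullptr;
  assert(cudaMalloc(&buf, 256 * sizeof(int)) == cudaSuccess);

  std::unordered_map<int, cudaGraphExec_t> graphs;  // annotation id -> instantiated graph
  const int annotation_id = 1;

  // Capture: work issued between Begin/EndCapture is recorded, not executed.
  assert(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal) == cudaSuccess);
  assert(cudaMemsetAsync(buf, 0, 256 * sizeof(int), stream) == cudaSuccess);
  cudaGraph_t graph = nullptr;
  assert(cudaStreamEndCapture(stream, &graph) == cudaSuccess);

  cudaGraphExec_t graph_exec = nullptr;
  assert(cudaGraphInstantiate(&graph_exec, graph, NULL, NULL, 0) == cudaSuccess);
  assert(cudaGraphDestroy(graph) == cudaSuccess);  // the executable keeps what it needs
  graphs.emplace(annotation_id, graph_exec);

  // Replay the graph associated with this annotation id.
  assert(cudaGraphLaunch(graphs.at(annotation_id), stream) == cudaSuccess);
  assert(cudaStreamSynchronize(stream) == cudaSuccess);

  for (auto& kv : graphs) cudaGraphExecDestroy(kv.second);  // what Clear() does per entry
  cudaFree(buf);
  cudaStreamDestroy(stream);
  return 0;
}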
diff --git a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc
index f416caecd115f..7afd2d430ec46 100644
--- a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc
+++ b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc
@@ -18,10 +18,14 @@ namespace onnxruntime::cuda {
 
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, float,
                                                       BatchNormalization);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, double,
+                                                      BatchNormalization);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, MLFloat16,
                                                       BatchNormalization);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, float,
                                                       BatchNormalization);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, double,
+                                                      BatchNormalization);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, MLFloat16,
                                                       BatchNormalization);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, float,
@@ -70,14 +74,34 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kM
                                                       MaxPool);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, float, MaxPool);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, MLFloat16, MaxPool);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, int8_t, MaxPool);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, uint8_t, MaxPool);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, float,
                                                       BatchNormalization);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, double,
+                                                      BatchNormalization);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, MLFloat16,
                                                       BatchNormalization);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, float,
                                             BatchNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, double,
+                                            BatchNormalization);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, MLFloat16,
                                             BatchNormalization);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, DepthToSpace);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, 12, DepthToSpace);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 13, DepthToSpace);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, SpaceToDepth);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 13, SpaceToDepth);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
+    kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, float, LRN);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
+    kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, double, LRN);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
+    kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, MLFloat16, LRN);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 13, float, LRN);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 13, double, LRN);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 13, MLFloat16, LRN);
 
 Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) {
   static const BuildKernelCreateInfoFn nhwc_function_table[] = {
@@ -86,18 +110,26 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) {
           kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, MLFloat16, BatchNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
           kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, float, BatchNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
+          kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, double, BatchNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
           kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, MLFloat16, BatchNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
           kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, float, BatchNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
+          kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, double, BatchNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
           kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, MLFloat16, BatchNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
           kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, float, BatchNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
+          kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, double, BatchNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15,
                                                                   MLFloat16, BatchNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15,
                                                                   float, BatchNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15,
+                                                                  double, BatchNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
           kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, MLFloat16, Conv)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider,
@@ -135,6 +167,7 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) {
           kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, float, MaxPool)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
           kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, MLFloat16, MaxPool)>,
+
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11,
                                                                   float, AveragePool)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11,
@@ -147,6 +180,10 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) {
                                                                   float, MaxPool)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12,
                                                                   MLFloat16, MaxPool)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12,
+                                                                  int8_t, MaxPool)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12,
+                                                                  uint8_t, MaxPool)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11,
                                                                   float, ConvTranspose)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11,
@@ -155,6 +192,29 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) {
           kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, float, ConvTranspose)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
           kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, MLFloat16, ConvTranspose)>,
+
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain,
+                                                                      1, 10, DepthToSpace)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain,
+                                                                      11, 12, DepthToSpace)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain,
+                                                            13, DepthToSpace)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain,
+                                                                      1, 12, SpaceToDepth)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain,
+                                                            13, SpaceToDepth)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
+          kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, float, LRN)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
+          kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, double, LRN)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
+          kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, MLFloat16, LRN)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
+          kCudaExecutionProvider, kMSInternalNHWCDomain, 13, float, LRN)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
+          kCudaExecutionProvider, kMSInternalNHWCDomain, 13, double, LRN)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
+          kCudaExecutionProvider, kMSInternalNHWCDomain, 13, MLFloat16, LRN)>,
   };
 
   for (auto& function_table_entry : nhwc_function_table) {
diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.h b/onnxruntime/core/providers/cuda/cuda_stream_handle.h
index b02c167e9e9ec..15e7a0553c84e 100644
--- a/onnxruntime/core/providers/cuda/cuda_stream_handle.h
+++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.h
@@ -11,6 +11,7 @@
 namespace onnxruntime {
 
 struct CudaStream;
+void WaitCudaNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
 
 struct DeferredCpuAllocator : public OrtAllocator {
   DeferredCpuAllocator(CudaStream&);
@@ -47,6 +48,8 @@ struct CudaStream : Stream {
 
   onnxruntime::IAllocator* GetCpuAllocator() const { return cpu_allocator_.get(); }
 
+  WaitNotificationFn GetWaitNotificationFn() const override { return WaitCudaNotificationOnDevice; }
+
  private:
   std::vector<void*> deferred_cpu_buffers_;
   AllocatorPtr cpu_allocator_;
@@ -64,5 +67,4 @@ void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis
                                cudnnHandle_t external_cudnn_handle,
                                cublasHandle_t external_cublass_handle,
                                const CUDAExecutionProviderInfo& ep_info);
-void WaitCudaNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc b/onnxruntime/core/providers/cuda/cudnn_common.cc
index c850f7b583bfc..9aa011c1d0ec4 100644
--- a/onnxruntime/core/providers/cuda/cudnn_common.cc
+++ b/onnxruntime/core/providers/cuda/cudnn_common.cc
@@ -37,13 +37,28 @@ Status CudnnTensor::Set(gsl::span<const int64_t> input_dims, cudnnDataType_t dat
   TensorPitches pitches(input_dims);
   InlinedVector<int, kTensorShapeSmallBufferElementsSize> dims(rank);
   InlinedVector<int, kTensorShapeSmallBufferElementsSize> strides(rank);
-  for (int i = 0; i < rank; i++) {
-    dims[i] = gsl::narrow_cast<int>(input_dims[i]);
-    strides[i] = gsl::narrow_cast<int>(pitches[i]);
-  }
-  if (is_nhwc) {
-    std::swap(dims[1], dims[rank - 1]);
-    std::swap(strides[1], strides[rank - 1]);
+
+  if (!is_nhwc) {
+    for (int i = 0; i < rank; i++) {
+      dims[i] = gsl::narrow_cast<int>(input_dims[i]);
+      strides[i] = gsl::narrow_cast<int>(pitches[i]);
+    }
+  } else {
+    // Remap from the NHWC input order (N, H, W, [D,] C) to the NCHW order (N, C, H, W, [D]) expected by cuDNN.
+
+    // N
+    dims[0] = gsl::narrow_cast<int>(input_dims[0]);
+    strides[0] = gsl::narrow_cast<int>(pitches[0]);
+
+    // HWD
+    for (int i = 1; i < rank - 1; i++) {
+      dims[i + 1] = gsl::narrow_cast<int>(input_dims[i]);
+      strides[i + 1] = gsl::narrow_cast<int>(pitches[i]);
+    }
+
+    // C
+    dims[1] = gsl::narrow_cast<int>(input_dims[rank - 1]);
+    strides[1] = gsl::narrow_cast<int>(pitches[rank - 1]);
   }
   CUDNN_RETURN_IF_ERROR(cudnnSetTensorNdDescriptor(tensor_, dataType, static_cast<int>(rank), dims.data(), strides.data()));
   return Status::OK();
@@ -160,7 +175,6 @@ cudnnDataType_t CudnnTensor::GetDataType<half>() {
 template <>
 cudnnDataType_t CudnnTensor::GetDataType<BFloat16>() {
   ORT_THROW("cuDNN doesn't support BFloat16.");
-  return CUDNN_DATA_FLOAT;
 }
 
 template <>
diff --git a/onnxruntime/core/providers/cuda/cudnn_common.h b/onnxruntime/core/providers/cuda/cudnn_common.h
index fdd14dedad47e..2cbeb13696270 100644
--- a/onnxruntime/core/providers/cuda/cudnn_common.h
+++ b/onnxruntime/core/providers/cuda/cudnn_common.h
@@ -24,12 +24,12 @@ class CudnnTensor final {
 
   operator cudnnTensorDescriptor_t() const { return tensor_; }
 
+  Status CreateTensorIfNeeded();
+
   template <typename T>
   static cudnnDataType_t GetDataType();
 
  private:
-  Status CreateTensorIfNeeded();
-
   cudnnTensorDescriptor_t tensor_;
 };
 
diff --git a/onnxruntime/core/providers/cuda/math/topk.cc b/onnxruntime/core/providers/cuda/math/topk.cc
index d516537e25949..cf26e0acfa557 100644
--- a/onnxruntime/core/providers/cuda/math/topk.cc
+++ b/onnxruntime/core/providers/cuda/math/topk.cc
@@ -56,7 +56,7 @@ TopK<inputk>::TopK(const OpKernelInfo& info) : CudaKernel(info) {
   info.GetAttrOrDefault<int64_t>("largest", &largest_, 1);
   info.GetAttrOrDefault<int64_t>("sorted", &sorted_, 1);
   if (!inputk) {
-    info.GetAttrOrDefault<int64_t>("k", &K_, 0);
+    info.GetAttrOrDefault<int64_t>("k", &attr_k_, 0);
   }
 }
 
@@ -67,7 +67,7 @@ TopK<inputk>::TopK(const OpKernelInfo& info) : CudaKernel(info) {
                                 static_cast<int64_t*>(tensor_I->MutableDataRaw()), \
                                 elem_nums_cuda,                                    \
                                 elem_nums.size(),                                  \
-                                axis, K_, largest_, sorted_, N, dimension)
+                                axis, k_value, largest_, sorted_, N, dimension)
 
 template <bool inputk>
 Status TopK<inputk>::ComputeInternal(OpKernelContext* ctx) const {
@@ -77,19 +77,29 @@ Status TopK<inputk>::ComputeInternal(OpKernelContext* ctx) const {
   int32_t axis = static_cast<int32_t>(axis_ < 0 ? rank + axis_ : axis_);
   ORT_ENFORCE(axis > -1 && axis < rank);
 
+  int64_t k_value = 0;
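+  // 'K' comes from the second input when inputk is true, otherwise from the 'k' attribute read in the constructor.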
   if (inputk) {
     auto tensor_K = ctx->Input<Tensor>(1);
     ORT_ENFORCE(nullptr != tensor_K);
-    K_ = *tensor_K->Data<int64_t>();
-    ORT_ENFORCE(K_ >= 0 && K_ <= tensor_X->Shape().GetDims()[axis]);
+    k_value = *tensor_K->Data<int64_t>();
+  } else {  // from attribute
+    k_value = attr_k_;
   }
 
-  auto output_shape = tensor_X->Shape();
-  output_shape[axis] = K_;
+  // Now that both the value of 'K' and the input shape are known,
+  // perform a final validation before calling into the implementation.
+  const auto& input_shape = tensor_X->Shape();
+  if ((k_value < 0) || (k_value > input_shape.GetDims()[axis])) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Value of K outside range. K value: ", k_value,
+                           ". Input shape: ", input_shape, " . Axis: ", axis);
+  }
+
+  auto output_shape = input_shape;
+  output_shape[axis] = k_value;
   auto tensor_V = ctx->Output(0, output_shape);
   auto tensor_I = ctx->Output(1, output_shape);
 
-  if (0 == K_) {
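+  // Checking the full output size (rather than just K) also covers inputs that have other zero-sized dimensions.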
+  if (output_shape.Size() == 0) {  // Bail out early if the output is going to be empty
     return Status::OK();
   }
 
diff --git a/onnxruntime/core/providers/cuda/math/topk.h b/onnxruntime/core/providers/cuda/math/topk.h
index 9dec13ad2a930..5731df3130c5a 100644
--- a/onnxruntime/core/providers/cuda/math/topk.h
+++ b/onnxruntime/core/providers/cuda/math/topk.h
@@ -17,7 +17,7 @@ class TopK final : public CudaKernel {
   int64_t axis_;
   int64_t largest_;
   int64_t sorted_;
-  mutable int64_t K_;
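+  // Value of the 'k' attribute; when K is supplied as an input it is read per call in ComputeInternal instead.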
+  int64_t attr_k_;
 };
 }  // namespace cuda
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
index 655877f425054..24593b255371c 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
@@ -71,6 +71,88 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa
     return Status::OK();                                                                          \
   }
 
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(
+    IsInf,
+    kOnnxDomain,
+    10,
+    19,
+    kCudaExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T1", BuildKernelDefConstraints<float, double>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
+    IsInf);
+
+ONNX_OPERATOR_KERNEL_EX(
+    IsInf,
+    kOnnxDomain,
+    20,
+    kCudaExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T1", BuildKernelDefConstraints<ISINF_OPSET20_ALL_FLOATS>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
+    IsInf);
+
+IsInf::IsInf(const OpKernelInfo& info) : UnaryElementwise(info) {
+  detect_positive_ = static_cast<bool>(info.GetAttrOrDefault<int64_t>("detect_positive", 1));
+  detect_negative_ = static_cast<bool>(info.GetAttrOrDefault<int64_t>("detect_negative", 1));
+  opset_ = info.node().SinceVersion();
+}
+
+Status IsInf::ComputeInternal(OpKernelContext* context) const {
+  UnaryElementwisePreparation p;
+  ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p));
+
+  Explicit_Impl_IsInf(Stream(context), opset_, detect_positive_, detect_negative_,
+                      p.input_tensor->GetElementType(), p.input_tensor->DataRaw(),
+                      p.output_tensor->MutableData<bool>(),
+                      p.input_tensor->Shape().Size());
+  return Status::OK();
+}
+
+// IsNan
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(
+    IsNaN,
+    kOnnxDomain,
+    9,
+    12,
+    kCudaExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T1", BuildKernelDefConstraints<ISNAN_OPSET9_FLOATS>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
+    IsNaN);
+
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(
+    IsNaN,
+    kOnnxDomain,
+    13,
+    19,
+    kCudaExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T1", BuildKernelDefConstraints<ISNAN_OPSET13_FLOATS>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
+    IsNaN);
+
+ONNX_OPERATOR_KERNEL_EX(
+    IsNaN,
+    kOnnxDomain,
+    20,
+    kCudaExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T1", BuildKernelDefConstraints<ISNAN_OPSET20_FLOATS>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
+    IsNaN);
+
+Status IsNaN::ComputeInternal(OpKernelContext* context) const {
+  UnaryElementwisePreparation p;
+  ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p));
+
+  Explicit_Impl_IsNan(Stream(context), p.input_tensor->GetElementType(), p.input_tensor->DataRaw(),
+                      p.output_tensor->MutableData<bool>(),
+                      p.input_tensor->Shape().Size());
+
+  return Status::OK();
+}
+
 #define UNARY_OP_VERSIONED_TYPED(name, startver, endver, T) \
   UNARY_ELEMENTWISE_REGISTER_VERSIONED_KERNEL(name, startver, endver, T)
 
@@ -160,7 +242,7 @@ UNARY_OP_CSILHFD(Neg, 13)
 UNARY_OP_HFD(Floor, 13)
 UNARY_OP_HFD(Ceil, 13)
 UNARY_OP_HFD(Reciprocal, 13)
-UNARY_OP_HFD(Sqrt, 13)
+UNARY_OP_HFDX(Sqrt, 13)
 UNARY_OP_HFD(Log, 13)
 UNARY_OP_HFD(Exp, 13)
 UNARY_OP_HFD(Erf, 13)
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h
index 775b78c43a736..95d68b5e1d534 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
 
 #pragma once
+
 #include "core/providers/cuda/cuda_kernel.h"
 
 namespace onnxruntime {
@@ -119,5 +120,22 @@ class Sign final : public UnaryElementwise {
   Status ComputeInternal(OpKernelContext* context) const override;
 };
 
+class IsInf final : public UnaryElementwise {
+ public:
+  explicit IsInf(const OpKernelInfo& info);
+  Status ComputeInternal(OpKernelContext* context) const override;
+
+ private:
+  bool detect_positive_{true};
+  bool detect_negative_{true};
+  int opset_;
+};
+
+class IsNaN : public UnaryElementwise {
+ public:
+  explicit IsNaN(const OpKernelInfo& info) : UnaryElementwise(info) {}
+  Status ComputeInternal(OpKernelContext* context) const override;
+};
+
 }  // namespace cuda
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
index 5c3db4a499972..2cdfcda5be26a 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
@@ -11,6 +11,7 @@
 #endif
 
 namespace onnxruntime {
+
 namespace cuda {
 
 #define OP(name, expr)                                     \
@@ -83,7 +84,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_CSILHFD(Neg)
 SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Floor)
 SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Ceil)
 SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Reciprocal)
-SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sqrt)
+SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Sqrt)
 SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Log)
 SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Exp)
 SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Erf)
@@ -126,9 +127,10 @@ struct OP_Cast {
     UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast<InT, OutT>(), count);                  \
   }
 
-#define IMPL_CAST_IMPL_THROW(InT, OutT)                                                                  \
-  void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \
-    ORT_THROW("Cast from " #InT " to " #OutT " must define saturate.");                                  \
+#define IMPL_CAST_IMPL_THROW(InT, OutT)                                                              \
+  void Explicit_Impl_Cast(cudaStream_t /*stream*/, const InT* /*input_data*/, OutT* /*output_data*/, \
+                          size_t /*count*/) {                                                        \
+    ORT_THROW("Cast from " #InT " to " #OutT " must define saturate.");                              \
   }
 
 #if !defined(DISABLE_FLOAT8_TYPES)
@@ -284,5 +286,62 @@ EXPLICIT_IMPL_CASTSAT(__nv_bfloat16, Float8E5M2)
 
 #endif
 
+namespace isinf_details {
+template <typename T>
+struct IsInf_DispFunc {
+  void operator()(cudaStream_t stream, const void* input_raw, bool* output_data,
+                  bool detect_positive, bool detect_negative, size_t count) const {
+    using CudaType = typename ToCudaType<T>::MappedType;
+    const auto* input_data = reinterpret_cast<const CudaType*>(input_raw);
+    if (detect_positive && detect_negative) {
+      UnaryElementWiseImpl(stream, input_data, output_data, _IsInf<CudaType, true, true>{}, count);
+    } else if (detect_positive) {
+      UnaryElementWiseImpl(stream, input_data, output_data, _IsInf<CudaType, true, false>{}, count);
+    } else if (detect_negative) {
+      UnaryElementWiseImpl(stream, input_data, output_data, _IsInf<CudaType, false, true>{}, count);
+    } else {
+      UnaryElementWiseImpl(stream, input_data, output_data, _IsInf<CudaType, false, false>{}, count);
+    }
+  }
+};
+
+}  // namespace isinf_details
+
+void Explicit_Impl_IsInf(cudaStream_t stream, int op_set,
+                         bool detect_positive, bool detect_negative,
+                         int32_t input_data_type,
+                         const void* input_raw, bool* output_data,
+                         size_t count) {
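+  // Opsets 10-19 constrain T1 to float/double; opset 20 adds the remaining float types (incl. float8 when enabled).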
+  if (op_set < 20) {
+    utils::MLTypeCallDispatcher<float, double> dispatcher{input_data_type};
+    dispatcher.Invoke<isinf_details::IsInf_DispFunc>(stream, input_raw, output_data,
+                                                     detect_positive, detect_negative, count);
+  } else {
+    utils::MLTypeCallDispatcher<ISINF_OPSET20_ALL_FLOATS> dispatcher{input_data_type};
+    dispatcher.Invoke<isinf_details::IsInf_DispFunc>(stream, input_raw, output_data,
+                                                     detect_positive, detect_negative, count);
+  }
+}
+
+// IsNan
+
+namespace isnan_details {
+template <typename T>
+struct IsNan_Disp {
+  void operator()(cudaStream_t stream, const void* input_raw, bool* output_data, size_t count) const {
+    using CudaType = typename ToCudaType<T>::MappedType;
+    const auto* input_data = reinterpret_cast<const CudaType*>(input_raw);
+    UnaryElementWiseImpl(stream, input_data, output_data, _IsNan<CudaType>{}, count);
+  }
+};
+}  // namespace isnan_details
+
+void Explicit_Impl_IsNan(cudaStream_t stream, int32_t input_data_type,
+                         const void* input_raw, bool* output_data, size_t count) {
+  // The KernelDef type constraints ensure that only a supported subset of data types reaches this dispatcher.
+  utils::MLTypeCallDispatcher<ISNAN_OPSET20_FLOATS> dispatcher{input_data_type};
+  dispatcher.Invoke<isnan_details::IsNan_Disp>(stream, input_raw, output_data, count);
+}
+
 }  // namespace cuda
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h
index 608a81a24cf4f..2588f56e32c12 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h
@@ -137,5 +137,34 @@ void Impl_CastSat(
 
 #endif
 
+// IsInf
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+#define ISINF_OPSET20_ALL_FLOATS float, double, MLFloat16, BFloat16, Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, \
+                                 Float8E5M2FNUZ
+#else
+#define ISINF_OPSET20_ALL_FLOATS float, double, MLFloat16, BFloat16
+#endif
+
+void Explicit_Impl_IsInf(cudaStream_t stream, int op_set,
+                         bool detect_positive, bool detect_negative,
+                         int32_t input_data_type,
+                         const void* input_raw, bool* output_data,
+                         size_t count);
+
+// IsNan
+#define ISNAN_OPSET9_FLOATS float, double, MLFloat16
+#define ISNAN_OPSET13_FLOATS float, double, MLFloat16, BFloat16
+#if !defined(DISABLE_FLOAT8_TYPES)
+#define ISNAN_OPSET20_FLOATS float, double, MLFloat16, BFloat16, Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, \
+                             Float8E5M2FNUZ
+#else
+#define ISNAN_OPSET20_FLOATS ISNAN_OPSET13_FLOATS
+#endif
+
+void Explicit_Impl_IsNan(cudaStream_t stream, int32_t input_data_type,
+                         const void* input_raw, bool* output_data, size_t count);
+
 }  // namespace cuda
+
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/nn/batch_norm.cc b/onnxruntime/core/providers/cuda/nn/batch_norm.cc
index c468971e1e426..02da1a2c99dfd 100644
--- a/onnxruntime/core/providers/cuda/nn/batch_norm.cc
+++ b/onnxruntime/core/providers/cuda/nn/batch_norm.cc
@@ -87,7 +87,7 @@ Status BatchNorm<T, NHWC>::ComputeInternal(OpKernelContext* p_op_kernel_context)
 
   CudnnTensor data_desc;
   vector<int64_t> new_dims;
-  BatchNormHelper::NormalizeDims(x_shape, new_dims);
+  BatchNormHelper::NormalizeDims(x_shape, new_dims, NHWC);
   ORT_RETURN_IF_ERROR(data_desc.Set(new_dims, CudnnTensor::GetDataType<CudaT>(), NHWC));
 
   // For half data type, the alpha, beta, scale, B, mean, var need to be float type
@@ -137,6 +137,12 @@ Status BatchNorm<T, NHWC>::ComputeInternal(OpKernelContext* p_op_kernel_context)
     auto saved_mean_data = reinterpret_cast<CudaT*>(saved_mean->MutableData<T>());
     auto saved_inv_var_data = reinterpret_cast<CudaT*>(saved_var->MutableData<T>());
 
+    auto stream = static_cast<cudaStream_t>(p_op_kernel_context->GetComputeStream()->GetHandle());
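+    // Seed the running mean/var outputs with the incoming running statistics; cuDNN updates them in place below.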
+    CUDA_RETURN_IF_ERROR(
+        cudaMemcpyAsync(running_mean_data, mean_data, mean->SizeInBytes(), cudaMemcpyDeviceToDevice, stream));
+    CUDA_RETURN_IF_ERROR(
+        cudaMemcpyAsync(running_var_data, var_data, var->SizeInBytes(), cudaMemcpyDeviceToDevice, stream));
+
     CUDNN_RETURN_IF_ERROR(BatchNormalizationForwardTrainingHelper(
         GetCudnnHandle(p_op_kernel_context),
         cudnn_batch_norm_mode_,
@@ -149,7 +155,7 @@ Status BatchNorm<T, NHWC>::ComputeInternal(OpKernelContext* p_op_kernel_context)
         bn_tensor_desc,
         scale_data,
         b_data,
-        momentum_,
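+        // cuDNN's exponentialAverageFactor weights the newly computed statistics, while the ONNX 'momentum'
+        // attribute weights the existing running statistics, hence the conversion to 1 - momentum.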
+        1.0 - momentum_,
         running_mean_data,
         running_var_data,
         epsilon_,
@@ -186,6 +192,7 @@ SPECIALIZED_COMPUTE(MLFloat16, kOnnxDomain, false)
 
 #ifdef ENABLE_CUDA_NHWC_OPS
 SPECIALIZED_COMPUTE(float, kMSInternalNHWCDomain, true)
+SPECIALIZED_COMPUTE(double, kMSInternalNHWCDomain, true)
 SPECIALIZED_COMPUTE(MLFloat16, kMSInternalNHWCDomain, true)
 #endif
 }  // namespace cuda
diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc
index 82f3503919237..e05786248cbcf 100644
--- a/onnxruntime/core/providers/cuda/nn/conv.cc
+++ b/onnxruntime/core/providers/cuda/nn/conv.cc
@@ -97,11 +97,11 @@ Status SliceOutUnwantedOutputSection(cudaStream_t stream,
 
 template <typename T, bool NHWC>
 Status Conv<T, NHWC>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
-                              bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) {
+                              bool& is_packed, PrePackedWeights* /*prepacked_weights*/) {
   is_packed = false;
   // only layout of weight input is adjusted via PrePack
-  if (NHWC && is_nhwc_domain_) {  // InputTensors::IN_W
-    if (input_idx == 1) {
+  if constexpr (NHWC) {
+    if (is_nhwc_domain_ && input_idx == 1) {  // InputTensors::IN_W
       // Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group}
       auto orig_shape = tensor.Shape();
 
@@ -123,6 +123,10 @@ Status Conv<T, NHWC>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr
       CUDA_CALL_THROW(cudaStreamSynchronize(DefaultCudaStream()));
       is_packed = true;
     }
+  } else {
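+    // The NHWC branch above is discarded at compile time for this instantiation;
+    // mark the parameters as used to avoid unused-parameter warnings.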
+    ORT_UNUSED_PARAMETER(tensor);
+    ORT_UNUSED_PARAMETER(input_idx);
+    ORT_UNUSED_PARAMETER(alloc);
   }
 
   return Status::OK();
@@ -149,8 +153,11 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
 
   // Make sure input and weight are 4D for NHWC since we set 4D descriptor for NHWC.
   constexpr bool channels_last = NHWC;
-  if (channels_last && (x_shape.NumDimensions() != 4 || w_shape.NumDimensions() != 4)) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of dimensions of X and W should be 4 for channels_last format (NHWC)");
+  if constexpr (channels_last) {
+    if (x_shape.NumDimensions() != 4 || w_shape.NumDimensions() != 4) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Number of dimensions of X and W should be 4 for channels_last format (NHWC)");
+    }
   }
 
   // set B
@@ -326,7 +333,8 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
 
     ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
                                          gsl::narrow_cast<int>(conv_attrs_.group),
-                                         CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType<CudaT>()));
+                                         CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType<CudaT>(),
+                                         UseTF32()));
 
     if (context->InputCount() >= 3) {
       const Tensor* B = context->Input<Tensor>(2);
@@ -351,8 +359,13 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
 
     if (!s_.cached_benchmark_results.contains(x_dims_cudnn)) {
       // set math type to tensor core before algorithm search
-      if constexpr (std::is_same<T, MLFloat16>::value)
+      if constexpr (std::is_same<T, MLFloat16>::value) {
         CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH));
+      } else if constexpr (std::is_same<T, float>::value) {
+        if (!UseTF32()) {
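+          // TF32 is disabled for this session, so keep FP32 on FMA-only math during algorithm search.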
+          CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH));
+        }
+      }
 
       cudnnConvolutionFwdAlgoPerf_t perf;
       int algo_count = 1;
@@ -397,8 +410,11 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
         default:
           perf.algo = kDefaultConvAlgo;
           CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(GetCudnnHandle(context), s_, perf.algo, &perf.memory));
-          if (std::is_same<T, MLFloat16>::value) {
+
+          if constexpr (std::is_same<T, MLFloat16>::value) {
             perf.mathType = CUDNN_TENSOR_OP_MATH;
+          } else if (std::is_same<T, float>::value && !UseTF32()) {
+            perf.mathType = CUDNN_FMA_MATH;
           } else {
             perf.mathType = CUDNN_DEFAULT_MATH;
           }
@@ -480,7 +496,8 @@ Status CudnnConvolutionDescriptor::Set(
     const gsl::span<const int64_t>& dilations,
     int groups,
     cudnnConvolutionMode_t mode,
-    cudnnDataType_t data_type) {
+    cudnnDataType_t data_type,
+    bool use_tf32) {
   if (!desc_)
     CUDNN_RETURN_IF_ERROR(cudnnCreateConvolutionDescriptor(&desc_));
 
@@ -513,6 +530,8 @@ Status CudnnConvolutionDescriptor::Set(
   CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_DEFAULT_MATH));
   if (data_type == CUDNN_DATA_HALF) {
     CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_TENSOR_OP_MATH));
+  } else if (data_type == CUDNN_DATA_FLOAT && !use_tf32) {
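+    // TF32 is disabled, so restrict FP32 convolutions to FMA math and keep cuDNN from selecting TF32 kernels.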
+    CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_FMA_MATH));
   }
 
   return Status::OK();
diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h
index bcaa4d855b81e..3aec654224e39 100644
--- a/onnxruntime/core/providers/cuda/nn/conv.h
+++ b/onnxruntime/core/providers/cuda/nn/conv.h
@@ -29,7 +29,8 @@ class CudnnConvolutionDescriptor final {
              const gsl::span<const int64_t>& dilations,
              int groups,
              cudnnConvolutionMode_t mode,
-             cudnnDataType_t data_type);
+             cudnnDataType_t data_type,
+             bool use_tf32);
 
   operator cudnnConvolutionDescriptor_t() const { return desc_; }
 
@@ -194,7 +195,7 @@ class Conv : public CudaKernel {
   }
 
   Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
-                 bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) override;
+                 bool& is_packed, PrePackedWeights* prepacked_weights) override;
 
   Status ComputeInternal(OpKernelContext* context) const override;
 
diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc
index 55dceaa2698e8..939b9959af818 100644
--- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc
+++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc
@@ -167,7 +167,8 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
       cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION;
       ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations,
                                            gsl::narrow_cast<int>(conv_transpose_attrs_.group), mode,
-                                           CudnnTensor::GetDataType<CudaT>()));
+                                           CudnnTensor::GetDataType<CudaT>(),
+                                           UseTF32()));
 
       if (has_bias) {
         const auto& b_shape = p.B->Shape();
@@ -187,8 +188,13 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
             GetScratchBuffer<void>(AlgoSearchWorkspaceSize, context->GetComputeStream());
 
         // set math type to tensor core before algorithm search
-        if constexpr (std::is_same<T, MLFloat16>::value)
+        if constexpr (std::is_same<T, MLFloat16>::value) {
           CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH));
+        } else if constexpr (std::is_same<T, float>::value) {
+          if (!UseTF32()) {
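+            // As in Conv, search with FMA-only math for FP32 when TF32 is disabled.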
+            CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH));
+          }
+        }
 
         cudnnConvolutionBwdDataAlgoPerf_t perf;
         int algo_count = 1;
diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm.h b/onnxruntime/core/providers/cuda/nn/layer_norm.h
index ff231f4f1ad5c..c021d3ffe63a2 100644
--- a/onnxruntime/core/providers/cuda/nn/layer_norm.h
+++ b/onnxruntime/core/providers/cuda/nn/layer_norm.h
@@ -7,8 +7,6 @@
 namespace onnxruntime {
 namespace cuda {
 
-using namespace onnxruntime::cuda;
-
 // NOTE: This was originally a contrib op with 3 type constraints. The ONNX spec merges 'T' and 'V'.
 // the kernel is templatized on all three for backwards compatibility, but in ONNX usage T == V.
 template <typename T, typename U, typename V, bool simplified>
diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu
index 679b8b6b78886..b9e8b45307079 100644
--- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu
+++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu
@@ -29,8 +29,6 @@
 namespace onnxruntime {
 namespace cuda {
 
-using namespace onnxruntime::cuda;
-
 template <typename U, bool simplified>
 __device__ void cuWelfordOnlineSum(
     const U curr,
diff --git a/onnxruntime/core/providers/cuda/nn/lrn.cc b/onnxruntime/core/providers/cuda/nn/lrn.cc
index 6fcdec74d84b5..788299b5eb8d6 100644
--- a/onnxruntime/core/providers/cuda/nn/lrn.cc
+++ b/onnxruntime/core/providers/cuda/nn/lrn.cc
@@ -6,37 +6,47 @@
 namespace onnxruntime {
 namespace cuda {
 
-#define REGISTER_KERNEL_VERSIONED_TYPED(START_VER, END_VER, T)                             \
+#define REGISTER_KERNEL_VERSIONED_TYPED(START_VER, END_VER, T, DOMAIN, LAYOUT)             \
   ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(                                                 \
       LRN,                                                                                 \
-      kOnnxDomain,                                                                         \
+      DOMAIN,                                                                              \
       START_VER,                                                                           \
       END_VER,                                                                             \
       T,                                                                                   \
       kCudaExecutionProvider,                                                              \
       (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
-      LRN<T>);
+      LRN<T, LAYOUT>);
 
-#define REGISTER_KERNEL_TYPED(VER, T)                                                      \
+#define REGISTER_KERNEL_TYPED(VER, T, DOMAIN, LAYOUT)                                      \
   ONNX_OPERATOR_TYPED_KERNEL_EX(                                                           \
       LRN,                                                                                 \
-      kOnnxDomain,                                                                         \
+      DOMAIN,                                                                              \
       VER,                                                                                 \
       T,                                                                                   \
       kCudaExecutionProvider,                                                              \
       (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
-      LRN<T>);
+      LRN<T, LAYOUT>);
 
-REGISTER_KERNEL_VERSIONED_TYPED(1, 12, float)
-REGISTER_KERNEL_VERSIONED_TYPED(1, 12, double)
-REGISTER_KERNEL_VERSIONED_TYPED(1, 12, MLFloat16)
+REGISTER_KERNEL_VERSIONED_TYPED(1, 12, float, kOnnxDomain, false)
+REGISTER_KERNEL_VERSIONED_TYPED(1, 12, double, kOnnxDomain, false)
+REGISTER_KERNEL_VERSIONED_TYPED(1, 12, MLFloat16, kOnnxDomain, false)
 
-REGISTER_KERNEL_TYPED(13, float)
-REGISTER_KERNEL_TYPED(13, double)
-REGISTER_KERNEL_TYPED(13, MLFloat16)
+REGISTER_KERNEL_TYPED(13, float, kOnnxDomain, false)
+REGISTER_KERNEL_TYPED(13, double, kOnnxDomain, false)
+REGISTER_KERNEL_TYPED(13, MLFloat16, kOnnxDomain, false)
 
-template <typename T>
-LRN<T>::LRN(const OpKernelInfo& info) : CudaKernel(info) {
+#ifdef ENABLE_CUDA_NHWC_OPS
+REGISTER_KERNEL_VERSIONED_TYPED(1, 12, float, kMSInternalNHWCDomain, true)
+REGISTER_KERNEL_VERSIONED_TYPED(1, 12, double, kMSInternalNHWCDomain, true)
+REGISTER_KERNEL_VERSIONED_TYPED(1, 12, MLFloat16, kMSInternalNHWCDomain, true)
+
+REGISTER_KERNEL_TYPED(13, float, kMSInternalNHWCDomain, true)
+REGISTER_KERNEL_TYPED(13, double, kMSInternalNHWCDomain, true)
+REGISTER_KERNEL_TYPED(13, MLFloat16, kMSInternalNHWCDomain, true)
+#endif
+
+template <typename T, bool Layout>
+LRN<T, Layout>::LRN(const OpKernelInfo& info) : CudaKernel(info) {
   int64_t size;
   ORT_ENFORCE(info.GetAttr<int64_t>("size", &size).IsOK());
   ORT_ENFORCE(size > 0);
@@ -58,8 +68,8 @@ LRN<T>::LRN(const OpKernelInfo& info) : CudaKernel(info) {
                   .IsOK());
 }
 
-template <typename T>
-Status LRN<T>::ComputeInternal(OpKernelContext* context) const {
+template <typename T, bool Layout>
+Status LRN<T, Layout>::ComputeInternal(OpKernelContext* context) const {
   typedef typename ToCudaType<T>::MappedType CudaT;
 
   const Tensor* X = context->Input<Tensor>(0);
@@ -71,7 +81,7 @@ Status LRN<T>::ComputeInternal(OpKernelContext* context) const {
   Tensor* Y = context->Output(0, X->Shape());
 
   CudnnTensor x_tensor;
-  ORT_RETURN_IF_ERROR(x_tensor.Set(X->Shape().GetDims(), CudnnTensor::GetDataType<CudaT>()));
+  ORT_RETURN_IF_ERROR(x_tensor.Set(X->Shape().GetDims(), CudnnTensor::GetDataType<CudaT>(), Layout == NHWC));
 
   const auto one = Consts<CudaT>::One;
   const auto zero = Consts<CudaT>::Zero;
diff --git a/onnxruntime/core/providers/cuda/nn/lrn.h b/onnxruntime/core/providers/cuda/nn/lrn.h
index 319e323c72a92..31b2819ccc52a 100644
--- a/onnxruntime/core/providers/cuda/nn/lrn.h
+++ b/onnxruntime/core/providers/cuda/nn/lrn.h
@@ -20,7 +20,7 @@ class CudnnLRNDescriptor final {
   cudnnLRNDescriptor_t desc_;
 };
 
-template <typename T>
+template <typename T, bool Layout>
 class LRN : public CudaKernel {
  public:
   LRN(const OpKernelInfo& info);
diff --git a/onnxruntime/core/providers/cuda/nn/max_pool_with_index.cu b/onnxruntime/core/providers/cuda/nn/max_pool_with_index.cu
index ef1155af127d1..9311f044f4ec5 100644
--- a/onnxruntime/core/providers/cuda/nn/max_pool_with_index.cu
+++ b/onnxruntime/core/providers/cuda/nn/max_pool_with_index.cu
@@ -7,10 +7,11 @@
 
 #include "core/providers/cuda/cu_inc/common.cuh"
 #include "core/providers/cuda/shared_inc/fast_divmod.h"
+#include "core/providers/cuda/shared_inc/cuda_utils.h"
 
 namespace onnxruntime {
 namespace cuda {
-template <typename T>
+template <typename T, bool Layout>
 __global__ void MaxPoolWithIndexKernel(
     int64_t batch,
     int64_t channels,
@@ -44,11 +45,27 @@ __global__ void MaxPoolWithIndexKernel(
   int id = blockIdx.x * blockDim.x + threadIdx.x;
   if (id >= output_size) return;
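+  // compute_offset maps (n, c, h, w, d) to a flat element index in the storage order selected by Layout.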
 
+  auto compute_offset =
+    [height, width, depth, channels](int n_index, int c_index, int h_index, int w_index, int d_index) -> int64_t {
+    if constexpr (Layout == LAYOUT_NCHW) {
+      return (((n_index * channels + c_index) * height + h_index) * width + w_index) * depth + d_index;
+    } else if constexpr (Layout == LAYOUT_NHWC) {
+      return (((n_index * height + h_index) * width + w_index) * depth + d_index) * channels + c_index;
+    }
+  };
+
   int d_index, w_index, h_index, c_index, n_index, id_tmp;
-  fdm_d.divmod(id, id_tmp, d_index);
-  fdm_w.divmod(id_tmp, id_tmp, w_index);
-  fdm_h.divmod(id_tmp, id_tmp, h_index);
-  fdm_c.divmod(id_tmp, n_index, c_index);
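+  // Decompose the flat output index from the innermost dimension outwards: d, w, h, c for NCHW; c, d, w, h for NHWC.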
+  if constexpr (Layout == LAYOUT_NCHW) {
+    fdm_d.divmod(id, id_tmp, d_index);
+    fdm_w.divmod(id_tmp, id_tmp, w_index);
+    fdm_h.divmod(id_tmp, id_tmp, h_index);
+    fdm_c.divmod(id_tmp, n_index, c_index);
+  } else if constexpr (Layout == LAYOUT_NHWC) {
+    fdm_c.divmod(id, id_tmp, c_index);
+    fdm_d.divmod(id_tmp, id_tmp, d_index);
+    fdm_w.divmod(id_tmp, id_tmp, w_index);
+    fdm_h.divmod(id_tmp, n_index, h_index);
+  }
 
   int64_t d_start = d_index * stride_d - pad_d;
   int64_t w_start = w_index * stride_w - pad_w;
@@ -64,29 +81,45 @@ __global__ void MaxPoolWithIndexKernel(
   int64_t d_index_max = -1;
   int64_t w_index_max = -1;
   int64_t h_index_max = -1;
-  int64_t offset = (n_index * channels + c_index) * height * width * depth;
+  int64_t offset = compute_offset(n_index, c_index, 0, 0, 0);
   const T* p_slice = p_input + offset;
-  T maxval = p_slice[h_start * width * depth + w_start * depth + d_start] - (T)1;
+  T maxval = p_slice[compute_offset(0, 0, h_start, w_start, d_start)] - (T)1;
   for (int64_t d = d_start; d < d_end; d += dilation_d) {
     for (int64_t w = w_start; w < w_end; w += dilation_w) {
       for (int64_t h = h_start; h < h_end; h += dilation_h) {
-        if (p_slice[h * width * depth + w * depth + d] > maxval) {
+        auto pool_offset = compute_offset(0, 0, h, w, d);
+        if (p_slice[pool_offset] > maxval) {
           h_index_max = h;
           w_index_max = w;
           d_index_max = d;
-          maxval = static_cast<float>(p_slice[h * width * depth + w * depth + d]);
+          maxval = static_cast<float>(p_slice[pool_offset]);
         }
       }
     }
   }
-  p_output[id] = p_input[offset + h_index_max * width * depth + w_index_max * depth + d_index_max];
+  p_output[id] = p_input[offset + compute_offset(0, 0, h_index_max, w_index_max, d_index_max)];
+
   if (p_indices) {
-    p_indices[id] = storage_order == 0 ? offset + h_index_max * width * depth + w_index_max * depth + d_index_max
-                                       : offset + h_index_max + w_index_max * height + d_index_max * width * height;
+    if constexpr (Layout == LAYOUT_NCHW) {
+      p_indices[id] = storage_order == 0 ? offset + h_index_max * width * depth + w_index_max * depth + d_index_max
+                                         : offset + h_index_max + w_index_max * height + d_index_max * width * height;
+    } else if constexpr (Layout == LAYOUT_NHWC) {
+      // The tests currently have to be provided in NHWC layout so that they do not fail. When converting between
+      // layouts, does it make sense to convert the indices as well?
+      // Storing indices in NHWC layout isn't critical, as they are meant to be consumed by Unpooling operations,
+      // which currently assume that the indices refer to tensors in NCHW layout.
+      int64_t id_nchw =
+        (((n_index * channels + c_index) * pooled_height + h_index) * pooled_width + w_index) * pooled_depth + d_index;
+      int64_t offset_nchw = (n_index * channels + c_index) * width * height * depth;
+
+      p_indices[id_nchw] = (storage_order == 0)
+                               ? offset_nchw + h_index_max * width * depth + w_index_max * depth + d_index_max
+                               : offset_nchw + h_index_max + w_index_max * height + d_index_max * width * height;
+    }
   }
 }
 
-template <typename T>
+template <typename T, bool Layout>
 void MaxPoolWithIndex(
     cudaStream_t stream,
     const TensorShape& input_shape,
@@ -99,14 +132,29 @@ void MaxPoolWithIndex(
     const T* p_input,
     T* p_output,
     int64_t* p_indices) {
-  int64_t batchs = input_shape[0];
-  int64_t channels = input_shape[1];
-  int64_t height = input_shape[2];
-  int64_t width = kernel_shape.size() > 1 ? input_shape[3] : 1;
-  int64_t depth = kernel_shape.size() > 2 ? input_shape[4] : 1;
-  int64_t pooled_height = output_shape[2];
-  int64_t pooled_width = kernel_shape.size() > 1 ? output_shape[3] : 1;
-  int64_t pooled_depth = kernel_shape.size() > 2 ? output_shape[4] : 1;
+  int64_t batchs, channels, height, width, depth;
+  int64_t pooled_height, pooled_width, pooled_depth;
+  if constexpr (Layout == LAYOUT_NCHW) {
+    batchs = input_shape[0];
+    channels = input_shape[1];
+    height = input_shape[2];
+    width = kernel_shape.size() > 1 ? input_shape[3] : 1;
+    depth = kernel_shape.size() > 2 ? input_shape[4] : 1;
+
+    pooled_height = output_shape[2];
+    pooled_width = kernel_shape.size() > 1 ? output_shape[3] : 1;
+    pooled_depth = kernel_shape.size() > 2 ? output_shape[4] : 1;
+  } else if constexpr (Layout == LAYOUT_NHWC) {
+    batchs = input_shape[0];
+    height = input_shape[1];
+    width = kernel_shape.size() > 1 ? input_shape[2] : 1;
+    depth = kernel_shape.size() > 2 ? input_shape[3] : 1;
+    channels = input_shape[input_shape.NumDimensions() - 1];
+
+    pooled_height = output_shape[1];
+    pooled_width = kernel_shape.size() > 1 ? output_shape[2] : 1;
+    pooled_depth = kernel_shape.size() > 2 ? output_shape[3] : 1;
+  }
   int64_t kernel_h = kernel_shape[0];
   int64_t kernel_w = kernel_shape.size() > 1 ? kernel_shape[1] : 1;
   int64_t kernel_d = kernel_shape.size() > 2 ? kernel_shape[2] : 1;
@@ -130,7 +178,7 @@ void MaxPoolWithIndex(
   fast_divmod fdm_d(static_cast<int>(pooled_depth));
 
   int blocksPerGrid = (int)((output_size + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock);
-  MaxPoolWithIndexKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+  MaxPoolWithIndexKernel<T, Layout><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
       batchs,
       channels,
       height,
@@ -162,8 +210,8 @@ void MaxPoolWithIndex(
       p_indices);
 }
 
-#define INSTANTIATEMAXPOOLWITHINDEX(T)              \
-  template void MaxPoolWithIndex<T>(                \
+#define INSTANTIATEMAXPOOLWITHINDEX(T, Layout)      \
+  template void MaxPoolWithIndex<T, Layout>(        \
       cudaStream_t stream,                          \
       const TensorShape& input_shape,               \
       const TensorShape& output_shape,              \
@@ -176,11 +224,19 @@ void MaxPoolWithIndex(
       T* p_output,                                  \
       int64_t* p_indices);
 
-INSTANTIATEMAXPOOLWITHINDEX(float)
-INSTANTIATEMAXPOOLWITHINDEX(double)
-INSTANTIATEMAXPOOLWITHINDEX(half)
-INSTANTIATEMAXPOOLWITHINDEX(int8_t)
-INSTANTIATEMAXPOOLWITHINDEX(uint8_t)
+INSTANTIATEMAXPOOLWITHINDEX(float, LAYOUT_NCHW)
+INSTANTIATEMAXPOOLWITHINDEX(double, LAYOUT_NCHW)
+INSTANTIATEMAXPOOLWITHINDEX(half, LAYOUT_NCHW)
+INSTANTIATEMAXPOOLWITHINDEX(int8_t, LAYOUT_NCHW)
+INSTANTIATEMAXPOOLWITHINDEX(uint8_t, LAYOUT_NCHW)
+
+#ifdef ENABLE_CUDA_NHWC_OPS
+INSTANTIATEMAXPOOLWITHINDEX(float, LAYOUT_NHWC)
+INSTANTIATEMAXPOOLWITHINDEX(double, LAYOUT_NHWC)
+INSTANTIATEMAXPOOLWITHINDEX(half, LAYOUT_NHWC)
+INSTANTIATEMAXPOOLWITHINDEX(int8_t, LAYOUT_NHWC)
+INSTANTIATEMAXPOOLWITHINDEX(uint8_t, LAYOUT_NHWC)
+#endif
 
 }  // namespace cuda
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/nn/max_pool_with_index.h b/onnxruntime/core/providers/cuda/nn/max_pool_with_index.h
index 27f5b241cc785..98f14c3f6a626 100644
--- a/onnxruntime/core/providers/cuda/nn/max_pool_with_index.h
+++ b/onnxruntime/core/providers/cuda/nn/max_pool_with_index.h
@@ -7,7 +7,7 @@
 
 namespace onnxruntime {
 namespace cuda {
-template <typename T>
+template <typename T, bool Layout>
 void MaxPoolWithIndex(
     cudaStream_t stream,
     const TensorShape& input_shape,
diff --git a/onnxruntime/core/providers/cuda/nn/pool.cc b/onnxruntime/core/providers/cuda/nn/pool.cc
index 8bc96958693bc..4acdcfcf35491 100644
--- a/onnxruntime/core/providers/cuda/nn/pool.cc
+++ b/onnxruntime/core/providers/cuda/nn/pool.cc
@@ -87,6 +87,8 @@ POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, float, MaxPool<8>, 11, 11, kMSInt
 POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 11, 11, kMSInternalNHWCDomain, true)
 POOLING_KERNEL_WITH_INDICES(MaxPool, float, MaxPool<8>, 12, kMSInternalNHWCDomain, true)
 POOLING_KERNEL_WITH_INDICES(MaxPool, MLFloat16, MaxPool<8>, 12, kMSInternalNHWCDomain, true)
+POOLING_KERNEL_WITH_INDICES(MaxPool, int8_t, MaxPool<8>, 12, kMSInternalNHWCDomain, true)
+POOLING_KERNEL_WITH_INDICES(MaxPool, uint8_t, MaxPool<8>, 12, kMSInternalNHWCDomain, true)
 
 POOLING_KERNEL(GlobalMaxPool, float, MaxPool<1>, 1, kMSInternalNHWCDomain, true)
 POOLING_KERNEL(GlobalMaxPool, MLFloat16, MaxPool<1>, 1, kMSInternalNHWCDomain, true)
@@ -145,8 +147,8 @@ class CudnnPoolingDescriptor final {
   cudnnPoolingDescriptor_t desc_;
 };
 
-template <typename T, typename PoolType, bool NHWC>
-Status Pool<T, PoolType, NHWC>::ComputeInternal(OpKernelContext* context) const {
+template <typename T, typename PoolType, bool Layout>
+Status Pool<T, PoolType, Layout>::ComputeInternal(OpKernelContext* context) const {
   typedef typename ToCudaType<T>::MappedType CudaT;
   const Tensor* X = context->Input<Tensor>(0);
   const TensorShape& x_shape = X->Shape();
@@ -157,16 +159,21 @@ Status Pool<T, PoolType, NHWC>::ComputeInternal(OpKernelContext* context) const
   }
 
   auto kernel_shape = pool_attrs_.kernel_shape;
-  auto pads = pool_attrs_.pads;
   auto strides = pool_attrs_.strides;
+  TensorShapeVector pads = pool_attrs_.pads;
 
   if (pool_attrs_.global_pooling) {
-    kernel_shape.assign(x_dims.begin() + 2, x_dims.end());
-    pads.assign(kernel_shape.size(), 0);
+    if constexpr (Layout == LAYOUT_NCHW) {
+      kernel_shape.assign(x_dims.begin() + 2, x_dims.end());
+    } else if constexpr (Layout == LAYOUT_NHWC) {
+      kernel_shape.assign(x_dims.begin() + 1, x_dims.end() - 1);
+    }
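+    // The pads vector holds a begin and an end value for every spatial dimension.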
+    pads.assign(2 * kernel_shape.size(), 0);
     strides.assign(kernel_shape.size(), 1);
   }
-  auto out_channel = NHWC ? x_shape[3] : x_shape[1];
-  auto y_dims = pool_attrs_.SetOutputSize(x_shape, out_channel, &pads, NHWC);
+  auto out_channel = (Layout == LAYOUT_NHWC) ? x_shape[x_dims.size() - 1] : x_shape[1];
+
+  auto y_dims = pool_attrs_.SetOutputSize(x_shape, out_channel, &pads, Layout == LAYOUT_NHWC);
   TensorShape y_shape(y_dims);
   Tensor* Y = context->Output(0, y_shape);
   // special case when there is a dim value of 0 in the shape.
@@ -178,20 +185,22 @@ Status Pool<T, PoolType, NHWC>::ComputeInternal(OpKernelContext* context) const
   TensorShapeVector x_dims_cudnn(x_dims.begin(), x_dims.end());
   TensorShapeVector y_dims_cudnn(y_dims);
   if (kernel_shape.size() < 2) {
-    // cudnn only takes 4D or 5D input, so pad dimensions if needed
-    if (NHWC) {
-      x_dims_cudnn.insert(x_dims_cudnn.begin() + 1, 1);
-      y_dims_cudnn.insert(y_dims_cudnn.begin() + 1, 1);
-      kernel_shape.insert(kernel_shape.begin() + 1, 1);
-      strides.insert(strides.begin() + 1, 1);
-    } else {
-      x_dims_cudnn.push_back(1);
-      y_dims_cudnn.push_back(1);
-      kernel_shape.push_back(1);
-      strides.push_back(1);
+    // cuDNN only takes 4D or 5D input, so pad dimensions if needed
+    if constexpr (Layout == LAYOUT_NHWC) {
+      x_dims_cudnn.insert(x_dims_cudnn.end() - 1, 1);
+      y_dims_cudnn.insert(y_dims_cudnn.end() - 1, 1);
+      pads.insert(pads.begin() + pads.size() / 2, 0);
+      pads.insert(pads.end(), 0);
+      kernel_shape.insert(kernel_shape.end(), 1);
+      strides.insert(strides.end(), 1);
+    } else {  // Layout == LAYOUT_NCHW
+      x_dims_cudnn.insert(x_dims_cudnn.end(), 1);
+      y_dims_cudnn.insert(y_dims_cudnn.end(), 1);
+      pads.insert(pads.begin() + pads.size() / 2, 0);
+      pads.insert(pads.end(), 0);
+      kernel_shape.insert(kernel_shape.end(), 1);
+      strides.insert(strides.end(), 1);
     }
-    pads.insert(pads.begin() + kernel_shape.size(), 0);
-    pads.insert(pads.end(), 0);
   }
 
   cudnnPoolingMode_t mode = CUDNN_POOLING_MAX;
@@ -208,8 +217,8 @@ Status Pool<T, PoolType, NHWC>::ComputeInternal(OpKernelContext* context) const
     const auto beta = Consts<float>::Zero;
     CudnnTensor x_tensor;
     CudnnTensor y_tensor;
-    ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType<float>(), NHWC));
-    ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType<float>(), NHWC));
+    ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType<float>(), Layout == LAYOUT_NHWC));
+    ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType<float>(), Layout == LAYOUT_NHWC));
 
     const auto input_count = x_shape.Size();
     const auto output_count = y_shape.Size();
@@ -225,8 +234,8 @@ Status Pool<T, PoolType, NHWC>::ComputeInternal(OpKernelContext* context) const
     const auto beta = Consts<CudaT>::Zero;
     CudnnTensor x_tensor;
     CudnnTensor y_tensor;
-    ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType<CudaT>(), NHWC));
-    ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType<CudaT>(), NHWC));
+    ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType<CudaT>(), Layout == LAYOUT_NHWC));
+    ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType<CudaT>(), Layout == LAYOUT_NHWC));
 
     CUDNN_RETURN_IF_ERROR(
         PoolingForwardHelper(GetCudnnHandle(context), pooling_desc, &alpha, x_tensor, x_data, &beta, y_tensor, y_data));
@@ -235,8 +244,8 @@ Status Pool<T, PoolType, NHWC>::ComputeInternal(OpKernelContext* context) const
   return Status::OK();
 }
 
-template <typename T, bool NHWC>
-Status Pool<T, MaxPool<8>, NHWC>::ComputeInternal(OpKernelContext* context) const {
+template <typename T, bool Layout>
+Status Pool<T, MaxPool<8>, Layout>::ComputeInternal(OpKernelContext* context) const {
   typedef typename ToCudaType<T>::MappedType CudaT;
   const Tensor* X = context->Input<Tensor>(0);
   const TensorShape& x_shape = X->Shape();
@@ -251,12 +260,16 @@ Status Pool<T, MaxPool<8>, NHWC>::ComputeInternal(OpKernelContext* context) cons
   auto strides = this->pool_attrs_.strides;
 
   if (this->pool_attrs_.global_pooling) {
-    kernel_shape.assign(x_dims.begin() + 2, x_dims.end());
-    pads.assign(kernel_shape.size(), 0);
+    if constexpr (Layout == LAYOUT_NCHW) {
+      kernel_shape.assign(x_dims.begin() + 2, x_dims.end());
+    } else if constexpr (Layout == LAYOUT_NHWC) {
+      kernel_shape.assign(x_dims.begin() + 1, x_dims.end() - 1);
+    }
+    pads.assign(2 * kernel_shape.size(), 0);  // x{i}_begin + x{i}_end
     strides.assign(kernel_shape.size(), 1);
   }
-  auto out_channel = NHWC ? x_shape[3] : x_shape[1];
-  auto y_dims = this->pool_attrs_.SetOutputSize(x_shape, out_channel, &pads, NHWC);
+  auto out_channel = Layout == LAYOUT_NHWC ? x_shape[x_shape.NumDimensions() - 1] : x_shape[1];
+  auto y_dims = this->pool_attrs_.SetOutputSize(x_shape, out_channel, &pads, Layout == LAYOUT_NHWC);
   Tensor* Y = context->Output(0, TensorShape(y_dims));
 
   // special case when there is a dim value of 0 in the shape.
@@ -265,13 +278,22 @@ Status Pool<T, MaxPool<8>, NHWC>::ComputeInternal(OpKernelContext* context) cons
   auto x_data = reinterpret_cast<const CudaT*>(X->Data<T>());
   auto y_data = reinterpret_cast<CudaT*>(Y->MutableData<T>());
 
-  Tensor* I = context->Output(1, TensorShape(y_dims));
+  // I is in NCHW format and the contained indices use NCHW math to compute the index
+  auto i_dims = y_dims;
+  if constexpr (Layout == LAYOUT_NHWC) {
+    // y_dims is in NHW(D)C order; i_dims has to be in NCHW(D) order.
+    i_dims.insert(i_dims.begin() + 1, i_dims.back());  // N, C, H, W, (D,) C
+    i_dims.pop_back();                                 // N, C, H, W, (D)
+  }
+
+  Tensor* I = context->Output(1, TensorShape(i_dims));
   if (nullptr != I || !this->pool_attrs_.default_dilations) {
     auto i_data = nullptr == I ? nullptr : I->MutableData<int64_t>();
-    MaxPoolWithIndex<CudaT>(this->Stream(context), x_shape, TensorShape(y_dims), kernel_shape, strides, pads,
-                            this->pool_attrs_.dilations, this->pool_attrs_.storage_order, x_data, y_data, i_data);
+    MaxPoolWithIndex<CudaT, Layout == LAYOUT_NHWC>(this->Stream(context), x_shape, TensorShape(y_dims), kernel_shape,
+                                                   strides, pads, this->pool_attrs_.dilations,
+                                                   this->pool_attrs_.storage_order, x_data, y_data, i_data);
   } else {
-    ORT_RETURN_IF_ERROR((Pool<T, MaxPool<1>, NHWC>::ComputeInternal(context)));
+    ORT_RETURN_IF_ERROR((Pool<T, MaxPool<1>, Layout == LAYOUT_NHWC>::ComputeInternal(context)));
   }
   return Status::OK();
 }
diff --git a/onnxruntime/core/providers/cuda/nn/pool.h b/onnxruntime/core/providers/cuda/nn/pool.h
index 8b5152a1565a9..97f7c8b8762d5 100644
--- a/onnxruntime/core/providers/cuda/nn/pool.h
+++ b/onnxruntime/core/providers/cuda/nn/pool.h
@@ -19,10 +19,10 @@ class Pool : public CudaKernel, public PoolBase {
   Status ComputeInternal(OpKernelContext* context) const override;
 };
 
-template <typename T, bool NHWC>
-class Pool<T, MaxPool<8>, NHWC> final : public Pool<T, MaxPool<1>, NHWC> {
+template <typename T, bool Layout>
+class Pool<T, MaxPool<8>, Layout> final : public Pool<T, MaxPool<1>, Layout> {
  public:
-  explicit Pool(const OpKernelInfo& info) : Pool<T, MaxPool<1>, NHWC>(info) {}
+  explicit Pool(const OpKernelInfo& info) : Pool<T, MaxPool<1>, Layout>(info) {}
 
   Status ComputeInternal(OpKernelContext* context) const override;
 };
diff --git a/onnxruntime/core/providers/cuda/nvtx_profile.cc b/onnxruntime/core/providers/cuda/nvtx_profile.cc
index 6c7c594066b86..867e7c1f24584 100644
--- a/onnxruntime/core/providers/cuda/nvtx_profile.cc
+++ b/onnxruntime/core/providers/cuda/nvtx_profile.cc
@@ -4,13 +4,8 @@
 #ifdef ENABLE_NVTX_PROFILE
 #include "nvtx_profile.h"
 #include "core/common/common.h"
-#if defined(_WIN32) || defined(WIN32) || defined(__CYGWIN__) || defined(__MINGW32__) || defined(__BORLANDC__)
 #include <nvtx3/nvToolsExt.h>
 #include <nvtx3/nvToolsExtCuda.h>
-#else
-#include <nvToolsExt.h>
-#include <nvToolsExtCuda.h>
-#endif
 
 namespace onnxruntime {
 namespace profile {
diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc
index 99c1f48e21c74..6476364a211fd 100644
--- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc
+++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc
@@ -9,40 +9,49 @@ namespace onnxruntime {
 namespace cuda {
 
 template <typename T>
-void CudnnRnnBase<T>::SetWeightBias(const cudnnHandle_t handle,
-                                    const cudnnRNNDescriptor_t rnn_desc,
-                                    const int pseudo_layer,
-                                    const cudnnTensorDescriptor_t x_desc,
-                                    const cudnnFilterDescriptor_t w_desc,
-                                    const cudnnFilterDescriptor_t filter_desc,
-                                    const void* reorganized_w_data,
-                                    const int lin_layer_id,
-                                    const T* pos,
-                                    int& offset,
-                                    bool is_matrix,
-                                    cudaStream_t cuda_stream) const {
+Status CudnnRnnBase<T>::SetWeightBias(const cudnnHandle_t handle,
+                                      const cudnnRNNDescriptor_t rnn_desc,
+                                      const int pseudo_layer,
+                                      size_t reorganized_w_data_size,
+                                      const void* reorganized_w_data,
+                                      const int lin_layer_id,
+                                      const T* pos,
+                                      int& offset,
+                                      bool is_matrix,
+                                      cudaStream_t cuda_stream) const {
   int numDims;
-  std::vector<int> matDims(3);
+  std::array<int, 3> matDims;
+  std::array<int, 3> strideA;
   cudnnDataType_t dt;
-  cudnnTensorFormat_t tf;
   T* mem_offset;
 
-  if (is_matrix) {
-    cudnnGetRNNLinLayerMatrixParams(handle, rnn_desc, pseudo_layer, x_desc, w_desc, reorganized_w_data, lin_layer_id, filter_desc, (void**)&mem_offset);
-  } else {
-    cudnnGetRNNLinLayerBiasParams(handle, rnn_desc, pseudo_layer, x_desc, w_desc, reorganized_w_data, lin_layer_id, filter_desc, (void**)&mem_offset);
-  }
+  CudnnTensor tensor_desc_matrix, tensor_desc_bias;
+  ORT_RETURN_IF_ERROR(tensor_desc_bias.CreateTensorIfNeeded());
+  ORT_RETURN_IF_ERROR(tensor_desc_matrix.CreateTensorIfNeeded());
 
-  cudnnGetFilterNdDescriptor(filter_desc, 3, &dt, &tf, &numDims, matDims.data());
+  T *mem_offset_matrix, *mem_offset_bias;
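+  // cudnnGetRNNWeightParams returns descriptors and device pointers for both the weight matrix and the bias
+  // of the requested linear layer; only the one being set in this call is used.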
+  CUDNN_RETURN_IF_ERROR(cudnnGetRNNWeightParams(
+      handle, rnn_desc, pseudo_layer, reorganized_w_data_size, reorganized_w_data,
+      lin_layer_id, tensor_desc_matrix, (void**)&mem_offset_matrix, tensor_desc_bias, (void**)&mem_offset_bias));
+  CUDNN_RETURN_IF_ERROR(cudnnGetTensorNdDescriptor(
+      is_matrix ? tensor_desc_matrix : tensor_desc_bias, 3, &dt, &numDims, matDims.data(), strideA.data()));
+
+  mem_offset = is_matrix ? mem_offset_matrix : mem_offset_bias;
   int count = matDims[0] * matDims[1] * matDims[2];
+
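+  // The cudaMemcpyAsync below assumes the parameter region is densely packed; reject non-contiguous layouts.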
+  if (strideA[0] != count) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::INVALID_ARGUMENT, "Stride is not packed");
+  }
   CUDA_CALL_THROW(cudaMemcpyAsync(mem_offset, pos + offset, count * sizeof(T), cudaMemcpyDeviceToDevice, cuda_stream));
+
   offset += count;
+
+  return Status::OK();
 }
 template <typename T>
 Status CudnnRnnBase<T>::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle,
                                               const cudnnRNNDescriptor_t rnn_desc,
-                                              const cudnnTensorDescriptor_t x_desc,
-                                              const cudnnFilterDescriptor_t w_desc,
+                                              size_t reorganized_w_data_size,
                                               void* reorganized_w_data,
                                               const T* W_data,
                                               const T* R_data,
@@ -51,18 +60,22 @@ Status CudnnRnnBase<T>::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle,
   int w_offset = 0;
   int r_offset = 0;
   int bias_offset = 0;
-  CudnnFilterDescriptor filter_desc;
   for (int layer = 0; layer < RNN_NUM_LAYERS * num_directions_; ++layer) {
     for (size_t idx = 0; idx < W_lin_layer_id_.size(); ++idx) {
-      SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, W_lin_layer_id_[idx], W_data, w_offset, true, cuda_stream);
+      ORT_RETURN_IF_ERROR(SetWeightBias(
+          cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data,
+          W_lin_layer_id_[idx], W_data, w_offset, true, cuda_stream));
       if (B_data != nullptr) {
-        SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, W_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream);
+        ORT_RETURN_IF_ERROR(SetWeightBias(cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data,
+                                          W_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream));
       }
     }
     for (size_t idx = 0; idx < R_lin_layer_id_.size(); ++idx) {
-      SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, R_lin_layer_id_[idx], R_data, r_offset, true, cuda_stream);
+      ORT_RETURN_IF_ERROR(SetWeightBias(cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data,
+                                        R_lin_layer_id_[idx], R_data, r_offset, true, cuda_stream));
       if (B_data != nullptr) {
-        SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, R_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream);
+        ORT_RETURN_IF_ERROR(SetWeightBias(cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data,
+                                          R_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream));
       }
     }
   }
@@ -72,6 +85,7 @@ Status CudnnRnnBase<T>::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle,
 
 template <typename T>
 Status CudnnRnnBase<T>::ReorganizeWeights(const Tensor* W, const Tensor* R, const Tensor* B,
+                                          size_t& reorganized_w_data_size_in_bytes,
                                           IAllocatorUniquePtr<void>& reorganized_w_data,
                                           CudnnFilterDescriptor& target_w_desc,
                                           CudnnRNN& rnn_desc, onnxruntime::Stream* ort_stream) const {
@@ -91,19 +105,16 @@ Status CudnnRnnBase<T>::ReorganizeWeights(const Tensor* W, const Tensor* R, cons
   TensorShapeVector dims_w({w_size, 1, 1});
   ORT_RETURN_IF_ERROR(target_w_desc.Set(dims_w, CudnnTensor::GetDataType<CudaT>()));
 
-  TensorShapeVector fake_dims_x({1, input_size, 1});
-  CudnnTensor fake_x_desc;
-  ORT_RETURN_IF_ERROR(fake_x_desc.Set(fake_dims_x, CudnnTensor::GetDataType<CudaT>()));
-
   // Prepare the weight data
-  reorganized_w_data = GetScratchBuffer<void>(w_size * sizeof(T), ort_stream);
+  reorganized_w_data_size_in_bytes = w_size * sizeof(T);
+  reorganized_w_data = GetScratchBuffer<void>(reorganized_w_data_size_in_bytes, ort_stream);
 
   // In many cases, this allocation is bigger than needed, leaving part of
-  // the buffer unintialized. non-zero garbage data leads to wrong result
+  // the buffer uninitialized. Non-zero garbage data leads to wrong results
   // in call to cudnnRNNForwardInference()
   // TODO! refine allocation size for each case.
   cudaStream_t cuda_stream = ort_stream ? static_cast<cudaStream_t>(ort_stream->GetHandle()) : nullptr;
-  cudaMemsetAsync(reorganized_w_data.get(), 0, w_size * sizeof(T), cuda_stream);
+  CUDA_RETURN_IF_ERROR(cudaMemsetAsync(reorganized_w_data.get(), 0, reorganized_w_data_size_in_bytes, cuda_stream));
 
   const T* W_data = W->Data<T>();
   const T* R_data = R->Data<T>();
@@ -111,8 +122,9 @@ Status CudnnRnnBase<T>::ReorganizeWeights(const Tensor* W, const Tensor* R, cons
 
   auto* ort_cuda_stream = dynamic_cast<CudaStream*>(ort_stream);
   cudnnHandle_t cudnn_handle = ort_cuda_stream ? ort_cuda_stream->cudnn_handle_ : DefaultCudnnHandle();
-  ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(cudnn_handle, rnn_desc, fake_x_desc, target_w_desc,
-                                            reorganized_w_data.get(), W_data, R_data, B_data, cuda_stream));
+  ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(cudnn_handle, rnn_desc,
+                                            reorganized_w_data_size_in_bytes, reorganized_w_data.get(),
+                                            W_data, R_data, B_data, cuda_stream));
 
   return Status::OK();
 }
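Editor's note: ReorganizeWeights now also reports the byte size of the weight space because the v8 inference call takes an explicit weightSpaceSize. As a hedged aside, cuDNN can report the size it expects for a configured descriptor, which makes an easy sanity check against the w_size derived from the W/R/B shapes (sketch only, not added by this change):

    size_t expected_weight_space = 0;
    CUDNN_RETURN_IF_ERROR(cudnnGetRNNWeightSpaceSize(cudnn_handle, rnn_desc, &expected_weight_space));
    // expected_weight_space should not exceed reorganized_w_data_size_in_bytes for the copies to be safe.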
@@ -128,22 +140,31 @@ Status CudnnRnnBase<T>::CacheCudnnRnnWeights(const OpKernelInfo& info) {
   bool get_R = info.TryGetConstantInput(RNN_Input_Index::R, &R);
   bool get_B = info.TryGetConstantInput(RNN_Input_Index::B, &B);
 
+  bool has_bias = B != nullptr;
+
   if (get_W && get_R) {
     CudnnRNN tmp_rnn_desc;
-    ORT_RETURN_IF_ERROR(tmp_rnn_desc.Set(DefaultCudnnHandle(),
+    auto proj_size = hidden_size_;
+    ORT_RETURN_IF_ERROR(tmp_rnn_desc.Set(W->Shape()[2],  // input_size
                                          hidden_size_,
+                                         proj_size,
                                          RNN_NUM_LAYERS,
                                          cudnn_dropout_desc_,
                                          cudnn_direction_mode_,
                                          rnn_mode_,
-                                         CudnnTensor::GetDataType<CudaT>(),
-                                         GetDeviceProp()));
+                                         has_bias,
+                                         CudnnTensor::GetDataType<CudaT>()));
     if (get_B) {
-      ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B, w_data_cache_, w_desc_cache_, tmp_rnn_desc, nullptr));
+      ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B,
+                                            w_data_cache_size_in_bytes_, w_data_cache_, w_desc_cache_,
+                                            tmp_rnn_desc, nullptr));
     } else {
-      ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, nullptr, w_data_cache_, w_desc_cache_, tmp_rnn_desc, nullptr));
+      ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, nullptr,
+                                            w_data_cache_size_in_bytes_, w_data_cache_, w_desc_cache_,
+                                            tmp_rnn_desc, nullptr));
     }
     cudaStreamSynchronize(nullptr);
+
     weight_cached_ = true;
   }
 
@@ -158,17 +179,72 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
   ORT_ENFORCE(nullptr != X);
 
   // optional inputs
-  const Tensor* sequence_lens = ctx->Input<Tensor>(RNN_Input_Index::sequence_lens);  // [batch_size]
-  const Tensor* initial_h = ctx->Input<Tensor>(RNN_Input_Index::initial_h);          // initial hidden. [num_directions_, batch_size, hidden_size_]
+  // [batch_size]
+  const Tensor* sequence_lens = ctx->Input<Tensor>(RNN_Input_Index::sequence_lens);
+  // initial hidden. [num_directions_, batch_size, hidden_size_]
+  const Tensor* initial_h = ctx->Input<Tensor>(RNN_Input_Index::initial_h);
   const Tensor* initial_c(nullptr);
   if (rnn_mode_ == CUDNN_LSTM) {
-    initial_c = ctx->Input<Tensor>(RNN_Input_Index::initial_c);  // initial cell. [num_directions_, batch_size, hidden_size_]
+    // initial cell. [num_directions_, batch_size, hidden_size_]
+    initial_c = ctx->Input<Tensor>(RNN_Input_Index::initial_c);
   }
 
+  size_t proj_size = hidden_size_;
   int64_t seq_length = X->Shape()[0];
   int64_t batch_size = X->Shape()[1];
   int64_t input_size = X->Shape()[2];
 
+  // When sequence_lens is not provided, treat every sequence as full length and expand to [batch_size].
+  std::vector<int32_t> sequence_lengths_temp;
+  if (!sequence_lens) {
+    sequence_lengths_temp.resize(batch_size, gsl::narrow_cast<int32_t>(seq_length));
+  }
+
+  const int32_t* sequence_lens_data = (sequence_lens == nullptr)
+                                          ? sequence_lengths_temp.data()
+                                          : sequence_lens->Data<int32_t>();
+
+  // cuDNN does not support zero-length sequences inside a batch. Track which entries are zero-length
+  // so their outputs can be zeroed out afterwards.
+  int64_t zero_seq_count = 0;
+  std::vector<int32_t> zero_seq_index_cache(batch_size, 0);
+
+  CudaAsyncBuffer<int32_t> sequence_lens_buffer(this, batch_size);
+  int32_t* seq_len_array = sequence_lens_buffer.CpuPtr();
+
+  // Replace zero-length sequences with length 1 in the buffer handed to cuDNN; they are masked
+  // back to 0 later with SetZeroSequences.
+  for (int i = 0; i < batch_size; ++i) {
+    if (0 == sequence_lens_data[i]) {
+      seq_len_array[i] = 1;
+      zero_seq_index_cache[zero_seq_count] = i;
+      ++zero_seq_count;
+    } else {
+      seq_len_array[i] = sequence_lens_data[i];
+    }
+  }
+
+  // For bidirectional runs, extend the zero-position cache to cover the reverse direction as well.
+  // The cache applies to Y_h/Y_c and the first step of Y; the remaining steps of Y need no extra
+  // handling because zero-length sequences were bumped to length 1 above.
+  if (zero_seq_count && num_directions_ > 1) {
+    zero_seq_index_cache.resize(zero_seq_count * num_directions_);
+    for (int64_t i = 0; i < zero_seq_count; ++i) {
+      zero_seq_index_cache[static_cast<size_t>(zero_seq_count) + i] =
+          static_cast<int32_t>(batch_size + zero_seq_index_cache[i]);
+    }
+    zero_seq_count *= num_directions_;
+  }
+
+  // Prior to cuDNN 8.9.1 the sequence-lengths buffer must be passed to cudnnRNNForward, so it always
+  // has to be copied to the GPU.
+  ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream()));
+  // Starting with cuDNN 8.9.1 cudnnRNNForward ignores the buffer, so it only needs to reach the GPU
+  // for the ReverseBySequence kernels:
+  // if (reverse_) {
+  //   ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream()));
+  // }
+
   // optional outputs
   TensorShapeVector dims_Y({seq_length, num_directions_, batch_size, hidden_size_});
   TensorShapeVector dims_hxy({RNN_NUM_LAYERS * num_directions_, batch_size, hidden_size_});
@@ -177,25 +253,6 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
   Tensor* Y_h = ctx->Output(Output_Index::Y_h, dims_hxy);
   Tensor* Y_c = ctx->Output(Output_Index::Y_c, dims_yc);
 
-  std::vector<int64_t> dims_x({batch_size, input_size, 1});
-  std::vector<int64_t> dims_y({batch_size, hidden_size_ * num_directions_, 1});
-
-  CudnnTensor x_desc_temp;
-  ORT_RETURN_IF_ERROR(x_desc_temp.Set(dims_x, CudnnTensor::GetDataType<CudaT>()));
-  CudnnTensor y_desc_temp;
-  ORT_RETURN_IF_ERROR(y_desc_temp.Set(dims_y, CudnnTensor::GetDataType<CudaT>()));
-  std::vector<cudnnTensorDescriptor_t> x_desc(seq_length, x_desc_temp);
-  std::vector<cudnnTensorDescriptor_t> y_desc(seq_length, y_desc_temp);
-
-  CudnnTensor hx_desc;
-  CudnnTensor cx_desc;
-  CudnnTensor y_h_desc;
-  CudnnTensor y_c_desc;
-  ORT_RETURN_IF_ERROR(hx_desc.Set(dims_hxy, CudnnTensor::GetDataType<CudaT>()));
-  ORT_RETURN_IF_ERROR(cx_desc.Set(dims_hxy, CudnnTensor::GetDataType<CudaT>()));
-  ORT_RETURN_IF_ERROR(y_h_desc.Set(dims_hxy, CudnnTensor::GetDataType<CudaT>()));
-  ORT_RETURN_IF_ERROR(y_c_desc.Set(dims_hxy, CudnnTensor::GetDataType<CudaT>()));
-
   IAllocatorUniquePtr<T> x_reversed_data;
   const T* x_data = X->Data<T>();
   if (reverse_) {
@@ -203,6 +260,7 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
     x_reversed_data = GetScratchBuffer<T>(seq_length * batch_size * input_size, ctx->GetComputeStream());
     ReverseBySequence(Stream(ctx),
                       gsl::narrow_cast<int32_t>(seq_length),
+                      sequence_lens_buffer.GpuPtr(),
                       gsl::narrow_cast<int32_t>(batch_size),
                       gsl::narrow_cast<int32_t>(input_size),
                       reinterpret_cast<const CudaT*>(x_data),
@@ -226,115 +284,81 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
     y_data = y_alloc_data.get();
   }
 
-  const int32_t* sequence_lens_data = (sequence_lens == nullptr) ? nullptr : sequence_lens->Data<int32_t>();
+  const Tensor* B = ctx->Input<Tensor>(RNN_Input_Index::B);
+  bool has_bias = B != nullptr;
 
   CudnnRNN rnn_desc;
-  ORT_RETURN_IF_ERROR(rnn_desc.Set(GetCudnnHandle(ctx),
+  ORT_RETURN_IF_ERROR(rnn_desc.Set(input_size,
                                    hidden_size_,
+                                   proj_size,
                                    RNN_NUM_LAYERS,
                                    cudnn_dropout_desc_,
                                    cudnn_direction_mode_,
                                    rnn_mode_,
-                                   CudnnTensor::GetDataType<CudaT>(),
-                                   GetDeviceProp()));
+                                   has_bias,
+                                   CudnnTensor::GetDataType<CudaT>()));
 
   // Prepare the weight data
+  size_t w_data_size_in_bytes = 0;
   IAllocatorUniquePtr<void> w_data;
   CudnnFilterDescriptor w_desc;
   if (!weight_cached_) {
     const Tensor& W = *ctx->Input<Tensor>(RNN_Input_Index::W);
     const Tensor& R = *ctx->Input<Tensor>(RNN_Input_Index::R);
-    const Tensor* B = ctx->Input<Tensor>(RNN_Input_Index::B);
-    ORT_RETURN_IF_ERROR(ReorganizeWeights(&W, &R, B, w_data, w_desc, rnn_desc, ctx->GetComputeStream()));
+    ORT_RETURN_IF_ERROR(ReorganizeWeights(&W, &R, B, w_data_size_in_bytes, w_data, w_desc,
+                                          rnn_desc, ctx->GetComputeStream()));
   }
 
-  // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that it will auto fill 0 for the shorter sequences
-  CUDNN_RETURN_IF_ERROR(cudnnSetRNNPaddingMode(rnn_desc, CUDNN_RNN_PADDED_IO_ENABLED));
+  CudnnDataTensor x_desc1;
+  ORT_RETURN_IF_ERROR(x_desc1.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size,
+                                  input_size, seq_len_array));
+  CudnnDataTensor y_desc1;
+  ORT_RETURN_IF_ERROR(y_desc1.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size,
+                                  ((rnn_mode_ == CUDNN_LSTM) ? proj_size : hidden_size_) * num_directions_,
+                                  seq_len_array));
 
-  size_t workspace_bytes;
-  CUDNN_RETURN_IF_ERROR(cudnnGetRNNWorkspaceSize(GetCudnnHandle(ctx), rnn_desc, gsl::narrow_cast<int>(seq_length), x_desc.data(), &workspace_bytes));
-  auto workspace_cuda = GetScratchBuffer<void>(workspace_bytes, ctx->GetComputeStream());
-  int64_t zero_seq_count = 0;
-  std::vector<int32_t> zero_seq_index_cache(batch_size, 0);
-  int64_t zero_seq_index_cache_size = 0;
-
-  if (CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_ || nullptr == sequence_lens_data) {
-    CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInference(GetCudnnHandle(ctx),
-                                                   rnn_desc,
-                                                   gsl::narrow_cast<int>(seq_length),
-                                                   x_desc.data(),
-                                                   x_data_input,
-                                                   hx_desc,
-                                                   hx_data,
-                                                   cx_desc,
-                                                   cx_data,
-                                                   weight_cached_ ? w_desc_cache_ : w_desc,
-                                                   weight_cached_ ? w_data_cache_.get() : w_data.get(),
-                                                   y_desc.data(),
-                                                   y_data,
-                                                   y_h_desc,
-                                                   y_h_data,
-                                                   y_c_desc,
-                                                   y_c_data,
-                                                   workspace_cuda.get(),
-                                                   workspace_bytes));
-  } else {
-    // cudnn doesn't support 0 sequence inside the batch, find the 0 sequence and set it to 1
-    // there's a ZeroMask kernel to reset the result to 0 for the 0 sequence
-    std::vector<int32_t> seq_len_array(sequence_lens_data, sequence_lens_data + batch_size);
-    for (int i = 0; i < batch_size; ++i) {
-      if (0 == seq_len_array[i]) {
-        seq_len_array[i] = 1;
-        zero_seq_index_cache[zero_seq_count] = i;
-        ++zero_seq_count;
-      }
-    }
+  CudnnTensor cx_desc;
+  ORT_RETURN_IF_ERROR(cx_desc.Set(dims_hxy, CudnnTensor::GetDataType<CudaT>()));
 
-    // Calculate the zero position cache for reverse direction if it's bidirectional
-    // The cache is for Y_h or Y_c, and the 1st sequence for Y, no need to do it for other sequence in Y since
-    // we hacked the 0 sequence to 1
-    if (zero_seq_count && num_directions_ > 1) {
-      zero_seq_index_cache_size = zero_seq_count * num_directions_;
-      zero_seq_index_cache.resize(zero_seq_index_cache_size);
-      for (int64_t i = 0; i < zero_seq_count; ++i) {
-        zero_seq_index_cache[static_cast<size_t>(zero_seq_count) + i] = static_cast<int32_t>(batch_size + zero_seq_index_cache[i]);
-      }
-    }
+  CudnnTensor hx_desc;
+  ORT_RETURN_IF_ERROR(hx_desc.Set(dims_hxy, CudnnTensor::GetDataType<CudaT>()));
+
+  // reserveSpaceSize is not needed by cudnnRNNForward in inference mode, but cudnnGetRNNTempSpaceSizes still reports it
+  size_t workspace_bytes, reservespace_bytes;
 
-    CudnnDataTensor x_desc1;
-    ORT_RETURN_IF_ERROR(x_desc1.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, input_size, seq_len_array.data()));
-    CudnnDataTensor y_desc1;
-    ORT_RETURN_IF_ERROR(y_desc1.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, hidden_size_ * num_directions_, seq_len_array.data()));
-
-    CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInferenceEx(GetCudnnHandle(ctx),
-                                                     rnn_desc,
-                                                     x_desc1,
-                                                     x_data_input,
-                                                     hx_desc,
-                                                     hx_data,
-                                                     cx_desc,
-                                                     cx_data,
-                                                     weight_cached_ ? w_desc_cache_ : w_desc,
-                                                     weight_cached_ ? w_data_cache_.get() : w_data.get(),
-                                                     y_desc1,
-                                                     y_data,
-                                                     y_h_desc,
-                                                     y_h_data,
-                                                     y_c_desc,
-                                                     y_c_data,
-                                                     nullptr, nullptr, nullptr, nullptr,
-                                                     nullptr, nullptr, nullptr, nullptr,
-                                                     workspace_cuda.get(),
-                                                     workspace_bytes));
-
-    // Early terminate for this case since Y data is not required, and Y_h is obtained correctly, no need the following code to retrive Y_h from Y data.
-    if (nullptr == Y) {
+  CUDNN_RETURN_IF_ERROR(cudnnGetRNNTempSpaceSizes(GetCudnnHandle(ctx), rnn_desc, CUDNN_FWD_MODE_INFERENCE,
+                                                  x_desc1, &workspace_bytes, &reservespace_bytes));
+  auto workspace_cuda = GetScratchBuffer<void>(workspace_bytes, ctx->GetComputeStream());
+  auto reservespace_cuda = GetScratchBuffer<void>(reservespace_bytes, ctx->GetComputeStream());
+
+  CUDNN_RETURN_IF_ERROR(cudnnRNNForward(GetCudnnHandle(ctx),
+                                        rnn_desc,
+                                        CUDNN_FWD_MODE_INFERENCE,
+                                        sequence_lens_buffer.GpuPtr(),  // unused (can be nullptr) starting with cuDNN 8.9.1
+                                        x_desc1,
+                                        x_data_input,
+                                        y_desc1,
+                                        y_data,  // output
+                                        hx_desc,
+                                        hx_data,   // input
+                                        y_h_data,  // output
+                                        cx_desc, cx_data, y_c_data,
+                                        weight_cached_ ? w_data_cache_size_in_bytes_ : w_data_size_in_bytes,
+                                        weight_cached_ ? w_data_cache_.get() : w_data.get(),
+                                        workspace_bytes,
+                                        workspace_cuda.get(),
+                                        reservespace_bytes,
+                                        reservespace_cuda.get()));
+
+  // Early return for this case: Y data is not requested and Y_h is already correct, so there is
+  // no need for the code below that recovers Y_h from the Y data.
+  if (nullptr == Y) {
+    if (zero_seq_count > 0) {
       // Mask on output for 0 sequence batches
-      if (zero_seq_count > 0) {
-        SetZeroSequences(zero_seq_index_cache_size, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream());
-      }
-      return Status::OK();
+      SetZeroSequences(zero_seq_count, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream());
     }
+    return Status::OK();
   }
 
   IAllocatorUniquePtr<T> y_reorganized_data;
@@ -345,6 +369,7 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
       // reverse output data
       ReverseBySequence(Stream(ctx),
                         gsl::narrow_cast<int32_t>(seq_length),
+                        sequence_lens_buffer.GpuPtr(),
                         gsl::narrow_cast<int32_t>(batch_size),
                         gsl::narrow_cast<int32_t>(hidden_size_),
                         reinterpret_cast<CudaT*>(y_data),
@@ -361,8 +386,9 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
     }
 
     if (Y != nullptr) {
-      // User specified this optional output, so need to copy the reversed data to orignial place
-      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(y_data, y_reorganized_data.get(), output_size * sizeof(T), cudaMemcpyDeviceToDevice, Stream(ctx)));
+      // User specified this optional output, so need to copy the reversed data to original place
+      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(y_data, y_reorganized_data.get(), output_size * sizeof(T),
+                                           cudaMemcpyDeviceToDevice, Stream(ctx)));
     } else {
       y_data = y_reorganized_data.get();
     }
@@ -370,23 +396,9 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
 
   // Mask on output for 0 sequence batches
   if (zero_seq_count > 0) {
-    SetZeroSequences(zero_seq_index_cache_size, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream());
+    SetZeroSequences(zero_seq_count, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream());
   }
 
-  if ((CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_) && sequence_lens_data != nullptr && y_h_data != nullptr && y_data != nullptr) {
-    CudaAsyncBuffer<int32_t> sequence_lens_buffer(this, batch_size);
-    memcpy(sequence_lens_buffer.CpuPtr(), sequence_lens_data, batch_size * sizeof(int32_t));
-    ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream()));
-    RnnMaskImpl(Stream(ctx),
-                gsl::narrow_cast<int32_t>(num_directions_),
-                gsl::narrow_cast<int32_t>(seq_length),
-                gsl::narrow_cast<int32_t>(batch_size),
-                gsl::narrow_cast<int32_t>(hidden_size_),
-                sequence_lens_buffer.GpuPtr(),
-                reinterpret_cast<CudaT*>(y_data),
-                reinterpret_cast<CudaT*>(y_h_data),
-                output_size);
-  }
   return Status::OK();
 }
 
@@ -399,7 +411,8 @@ void CudnnRnnBase<T>::SetZeroSequences(const int64_t zero_seq_index_cache_size,
                                        onnxruntime::Stream* ort_stream) const {
   typedef typename ToCudaType<T>::MappedType CudaT;
   CudaAsyncBuffer<int32_t> zero_seq_index_cache_async_buffer(this, zero_seq_index_cache_size);
-  memcpy(zero_seq_index_cache_async_buffer.CpuPtr(), zero_seq_index_cache.data(), zero_seq_index_cache_size * sizeof(int32_t));
+  memcpy(zero_seq_index_cache_async_buffer.CpuPtr(), zero_seq_index_cache.data(),
+         zero_seq_index_cache_size * sizeof(int32_t));
   ORT_THROW_IF_ERROR(zero_seq_index_cache_async_buffer.CopyToGpu(ort_stream));
   cudaStream_t cuda_stream = ort_stream ? static_cast<cudaStream_t>(ort_stream->GetHandle()) : nullptr;
   MaskZeroSequences(cuda_stream,
diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h
index 1c9483b2afd38..0fa01d3486e99 100644
--- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h
+++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h
@@ -38,26 +38,28 @@ class CudnnRNN {
     }
   }
 
-  Status Set(const cudnnHandle_t& cudnnHandle, int64_t hidden_size, int num_layers,
+  Status Set(int64_t input_size, int64_t hidden_size, int64_t proj_size, int num_layers,
              cudnnDropoutDescriptor_t cudnn_dropout_desc, cudnnDirectionMode_t cudnn_direction_model,
-             cudnnRNNMode_t rnn_mode, cudnnDataType_t dataType, const cudaDeviceProp& prop) {
+             cudnnRNNMode_t rnn_mode, bool has_bias, cudnnDataType_t dataType) {
     if (!cudnn_rnn_desc_)
       CUDNN_RETURN_IF_ERROR(cudnnCreateRNNDescriptor(&cudnn_rnn_desc_));
 
-    CUDNN_RETURN_IF_ERROR(cudnnSetRNNDescriptor_v6(cudnnHandle,
-                                                   cudnn_rnn_desc_,
+    CUDNN_RETURN_IF_ERROR(cudnnSetRNNDescriptor_v8(cudnn_rnn_desc_,
+                                                   CUDNN_RNN_ALGO_STANDARD,  // CUDNN_RNN_ALGO_PERSIST_STATIC, CUDNN_RNN_ALGO_PERSIST_DYNAMIC
+                                                   rnn_mode,
+                                                   has_bias ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS,
+                                                   cudnn_direction_model,
+                                                   CUDNN_LINEAR_INPUT,
+                                                   dataType,
+                                                   dataType,
+                                                   dataType == CUDNN_DATA_HALF ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH,
+                                                   gsl::narrow_cast<int>(input_size),
                                                    gsl::narrow_cast<int>(hidden_size),
+                                                   gsl::narrow_cast<int>(proj_size),  // projected size
                                                    num_layers,
                                                    cudnn_dropout_desc,
-                                                   CUDNN_LINEAR_INPUT,  // We can also skip the input matrix transformation
-                                                   cudnn_direction_model,
-                                                   rnn_mode,
-                                                   CUDNN_RNN_ALGO_STANDARD,  // CUDNN_RNN_ALGO_PERSIST_STATIC, CUDNN_RNN_ALGO_PERSIST_DYNAMIC
-                                                   dataType));
-
-    if (prop.major >= 7 && dataType == CUDNN_DATA_HALF) {
-      cudnnSetRNNMatrixMathType(cudnn_rnn_desc_, CUDNN_TENSOR_OP_MATH);
-    }
+                                                   // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that it will auto fill 0 for the shorter sequences
+                                                   CUDNN_RNN_PADDED_IO_ENABLED));
 
     return Status::OK();
   }
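Editor's note: the CUDNN_RNN_PADDED_IO_ENABLED flag set above goes hand in hand with variable-length RNN data descriptors, where per-batch sequence lengths travel in the descriptor rather than in packed inputs. A minimal raw-cuDNN sketch of such a descriptor (the .cc side goes through CudnnDataTensor::Set instead; the variable names below are illustrative):

    cudnnRNNDataDescriptor_t x_data_desc = nullptr;
    CUDNN_RETURN_IF_ERROR(cudnnCreateRNNDataDescriptor(&x_data_desc));
    float padding_fill = 0.0f;  // value written into padded positions of shorter sequences
    CUDNN_RETURN_IF_ERROR(cudnnSetRNNDataDescriptor(
        x_data_desc, CUDNN_DATA_FLOAT, CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED,
        max_seq_length, batch_size, input_size, seq_len_array /* one int32 per batch entry */,
        &padding_fill));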
@@ -119,8 +121,7 @@ class CudnnRnnBase : public CudaKernel {
  private:
   Status SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle,
                                const cudnnRNNDescriptor_t rnn_desc,
-                               const cudnnTensorDescriptor_t x_desc,
-                               const cudnnFilterDescriptor_t w_desc,
+                               size_t w_data_size,
                                void* w_data,
                                const T* W_data,
                                const T* R_data,
@@ -128,23 +129,22 @@ class CudnnRnnBase : public CudaKernel {
                                cudaStream_t cuda_stream) const;
 
   Status ReorganizeWeights(const Tensor* W, const Tensor* R, const Tensor* B,
+                           size_t& target_w_data_size_in_bytes,
                            IAllocatorUniquePtr<void>& target_w_data,
                            CudnnFilterDescriptor& target_w_desc,
                            CudnnRNN& rnn_desc,
                            onnxruntime::Stream* ort_stream) const;
 
-  void SetWeightBias(const cudnnHandle_t handle,
-                     const cudnnRNNDescriptor_t rnn_desc,
-                     const int pseudo_layer,
-                     const cudnnTensorDescriptor_t x_desc,
-                     const cudnnFilterDescriptor_t w_desc,
-                     const cudnnFilterDescriptor_t filter_desc,
-                     const void* w_data,
-                     const int lin_layer_id,
-                     const T* pos,
-                     int& offset,
-                     bool is_matrix,
-                     cudaStream_t cuda_stream) const;
+  Status SetWeightBias(const cudnnHandle_t handle,
+                       const cudnnRNNDescriptor_t rnn_desc,
+                       const int pseudo_layer,
+                       size_t w_data_size,
+                       const void* w_data,
+                       const int lin_layer_id,
+                       const T* pos,
+                       int& offset,
+                       bool is_matrix,
+                       cudaStream_t cuda_stream) const;
 
   void SetZeroSequences(const int64_t zero_seq_index_cache_size,
                         const std::vector<int32_t> zero_seq_index_cache,
@@ -167,6 +167,7 @@ class CudnnRnnBase : public CudaKernel {
   cudnnRNNMode_t rnn_mode_;
   // w_desc_cache_ & w_data_cache_ are changed in Constructor if we can get the weights as constant input
   CudnnFilterDescriptor w_desc_cache_;
+  size_t w_data_cache_size_in_bytes_;
   IAllocatorUniquePtr<void> w_data_cache_;
   bool weight_cached_;
   int64_t layout_;
diff --git a/onnxruntime/core/providers/cuda/rnn/rnn.cc b/onnxruntime/core/providers/cuda/rnn/rnn.cc
index 4bd22340ef2bb..ed8be63679707 100644
--- a/onnxruntime/core/providers/cuda/rnn/rnn.cc
+++ b/onnxruntime/core/providers/cuda/rnn/rnn.cc
@@ -1,8 +1,9 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#include "core/providers/shared_library/provider_api.h"
 #include "rnn.h"
+
+#include "core/providers/shared_library/provider_api.h"
 #include "rnn_impl.h"
 #include "core/providers/cuda/cudnn_common.h"
 
diff --git a/onnxruntime/core/providers/cuda/rnn/rnn.h b/onnxruntime/core/providers/cuda/rnn/rnn.h
index e4e50046b3725..6221afb003b22 100644
--- a/onnxruntime/core/providers/cuda/rnn/rnn.h
+++ b/onnxruntime/core/providers/cuda/rnn/rnn.h
@@ -4,6 +4,7 @@
 #pragma once
 
 #include "cudnn_rnn_base.h"
+
 #include "core/providers/cuda/cuda_common.h"
 #include <cudnn.h>
 
diff --git a/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu b/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu
index d485855ddb417..94c8036be6cdf 100644
--- a/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu
+++ b/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu
@@ -8,22 +8,32 @@ namespace onnxruntime {
 namespace cuda {
 
 template <typename T>
-__global__ void _ReverseBySequenceKernel(const int32_t seq_length,
+__global__ void _ReverseBySequenceKernel(const int32_t max_seq_length,
+                                         const int32_t* seq_lengths,
                                          const int32_t block_size,
                                          const fast_divmod div_batch_block,
+                                         const fast_divmod div_input_or_hidden_size,
                                          const T* data,
                                          T* reversed_data,
                                          const CUDA_LONG N) {
   CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
   int seq_id, offset;
   div_batch_block.divmod(id, seq_id, offset);
-  int org_id = (seq_length - seq_id - 1) * block_size + offset;
-  reversed_data[id] = data[org_id];
+  int batch, batch_offset;
+  div_input_or_hidden_size.divmod(offset, batch, batch_offset);
+  int seq_id_org = seq_lengths[batch] - seq_id - 1;
+  if (seq_id_org >= 0) {
+    int org_id = seq_id_org * block_size + offset;
+    reversed_data[id] = data[org_id];
+  } else {
+    reversed_data[id] = T{};
+  }
 }
 
 template <typename T>
 void ReverseBySequence(cudaStream_t stream,
-                       const int32_t seq_length,
+                       const int32_t max_seq_length,
+                       const int32_t *seq_lengths,
                        const int32_t batch_size,
                        const int32_t input_or_hidden_size,
                        const T* data,
@@ -32,9 +42,10 @@ void ReverseBySequence(cudaStream_t stream,
   // kernel
   int32_t block_size = batch_size * input_or_hidden_size;
   fast_divmod div_batch_block(block_size);
+  fast_divmod div_input_or_hidden_size(input_or_hidden_size);
   int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
   _ReverseBySequenceKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
-      seq_length, block_size, div_batch_block, data, reversed_data, (CUDA_LONG)N);
+      max_seq_length, seq_lengths, block_size, div_batch_block, div_input_or_hidden_size, data, reversed_data, (CUDA_LONG)N);
 }
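Editor's note: the index arithmetic performed by the kernel above can be read as the following host-side reference loop (illustrative only). Element id of the reversed buffer decomposes into (seq_id, batch, feature); the source is the mirrored time step of that batch entry, and positions beyond the entry's real length are zero-filled.

    template <typename T>
    void ReverseBySequenceReference(int max_seq_length, const int* seq_lengths,
                                    int batch_size, int input_or_hidden_size,
                                    const T* data, T* reversed) {
      const int block_size = batch_size * input_or_hidden_size;
      for (int id = 0; id < max_seq_length * block_size; ++id) {
        const int seq_id = id / block_size;                       // time step in the output
        const int offset = id % block_size;                       // position within the step
        const int batch = offset / input_or_hidden_size;          // which batch entry
        const int seq_id_org = seq_lengths[batch] - seq_id - 1;   // mirrored source step
        reversed[id] = (seq_id_org >= 0) ? data[seq_id_org * block_size + offset] : T{};
      }
    }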
 
 template <typename T>
@@ -82,60 +93,6 @@ void ReorderBidirectionalDataInSequence(cudaStream_t stream,
       data, reordered_data, (CUDA_LONG)N);
 }
 
-template <typename T>
-__global__ void _RnnMaskKernel(const int32_t seq_length,
-                               const int32_t batch_size,
-                               const int32_t hidden_size,
-                               const int32_t* sequence_lens,
-                               const fast_divmod div_seq_block,
-                               const fast_divmod div_dir_block,
-                               const fast_divmod div_batch_block,
-                               T* y_output_data,
-                               T* y_h_output_data,
-                               const CUDA_LONG N) {
-  CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
-
-  int seq_id, direction_id, batch_id, offset;
-  div_seq_block.divmod(id, seq_id, offset);
-  div_dir_block.divmod(offset, direction_id, offset);
-  div_batch_block.divmod(offset, batch_id, offset);
-  int32_t batch_seq_length = sequence_lens[batch_id];
-
-  if (batch_id >= batch_size || batch_seq_length == seq_length) {
-    return;
-  }
-
-  if (seq_id >= batch_seq_length) {
-    y_output_data[id] = 0;
-    return;
-  }
-
-  if ((y_h_output_data != nullptr) && 
-      ((direction_id == 0 && (seq_id + 1) == batch_seq_length) || (direction_id == 1 && seq_id == 0))) {
-    int hy_idx = direction_id * batch_size * hidden_size + batch_id * hidden_size + offset;
-    y_h_output_data[hy_idx] = y_output_data[id];
-  }
-}
-
-template <typename T>
-void RnnMaskImpl(cudaStream_t stream,
-                 const int32_t num_directions,
-                 const int32_t seq_length,
-                 const int32_t batch_size,
-                 const int32_t hidden_size,
-                 const int32_t* sequence_lens,
-                 T* y_output_data,
-                 T* y_h_output_data,
-                 const size_t N) {
-  fast_divmod div_seq_block(batch_size * hidden_size * num_directions);
-  fast_divmod div_dir_block(batch_size * hidden_size);
-  fast_divmod div_batch_block(hidden_size);
-  int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
-  _RnnMaskKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
-      seq_length, batch_size, hidden_size, sequence_lens, div_seq_block,
-      div_dir_block, div_batch_block, y_output_data, y_h_output_data, (CUDA_LONG)N);
-}
-
 template <typename T>
 __global__ void _MaskZeroSequences(const int32_t hidden_size,
                                    T* y_output_data,
@@ -180,17 +137,9 @@ void MaskZeroSequences(cudaStream_t stream,
 }
 
 #define SPECIALIZED_RNN_IMPL(T)                                                 \
-  template void RnnMaskImpl<T>(cudaStream_t stream,                       \
-                               const int32_t num_directions,                    \
-                               const int32_t seq_length,                        \
-                               const int32_t batch_size,                        \
-                               const int32_t hidden_size,                       \
-                               const int32_t* sequence_lens,                    \
-                               T* y_output_data,                                \
-                               T* y_h_output_data,                              \
-                               const size_t N);                                 \
-  template void ReverseBySequence<T>(cudaStream_t stream,                 \
-                                     const int32_t seq_length,                  \
+  template void ReverseBySequence<T>(cudaStream_t stream,                       \
+                                     const int32_t max_seq_length,              \
+                                     const int32_t* seq_lengths,                \
                                      const int32_t batch_size,                  \
                                      const int32_t hidden_size,                 \
                                      const T* data,                             \
@@ -203,7 +152,7 @@ void MaskZeroSequences(cudaStream_t stream,
                                                       const T* data,            \
                                                       T* reordered_data,        \
                                                      const size_t N);           \
-template void MaskZeroSequences<T>(cudaStream_t stream,                   \
+template void MaskZeroSequences<T>(cudaStream_t stream,                         \
                                    const int32_t hidden_size,                   \
                                    T* y_output_data,                            \
                                    T* y_h_output_data,                          \
diff --git a/onnxruntime/core/providers/cuda/rnn/rnn_impl.h b/onnxruntime/core/providers/cuda/rnn/rnn_impl.h
index 9844e04ff6ec5..ba876011f6b67 100644
--- a/onnxruntime/core/providers/cuda/rnn/rnn_impl.h
+++ b/onnxruntime/core/providers/cuda/rnn/rnn_impl.h
@@ -10,7 +10,8 @@ namespace cuda {
 
 template <typename T>
 void ReverseBySequence(cudaStream_t stream,
-                       const int32_t seq_length,
+                       const int32_t max_seq_length,
+                       const int32_t* seq_lengths,
                        const int32_t batch_size,
                        const int32_t input_or_hidden_size,
                        const T* data,
@@ -26,17 +27,6 @@ void ReorderBidirectionalDataInSequence(cudaStream_t stream,
                                         T* reordered_data,
                                         const size_t N);
 
-template <typename T>
-void RnnMaskImpl(cudaStream_t stream,
-                 const int32_t num_directions,
-                 const int32_t seq_length,
-                 const int32_t batch_size,
-                 const int32_t hidden_size,
-                 const int32_t* sequence_lens,
-                 T* y_output_data,
-                 T* y_h_output_data,
-                 const size_t N);
-
 template <typename T>
 void MaskZeroSequences(cudaStream_t stream,
                        const int32_t hidden_size,
diff --git a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h
index fa987866c002f..54c024793ff0b 100644
--- a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h
+++ b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h
@@ -168,5 +168,31 @@ struct NumericLimits<double> {
   }
 };
 
+// TODO: find a better home for these layout helpers; good candidates might be
+// core/framework/tensor_shape.h
+// core/util/matrix_layout.h
+
+constexpr bool LAYOUT_NCHW = false;
+constexpr bool LAYOUT_NHWC = true;
+
+template <bool IsNHWC>
+struct Channels;
+
+template <>
+struct Channels<LAYOUT_NHWC> {
+  static constexpr size_t N = 0;
+  static constexpr size_t H = 1;
+  static constexpr size_t W = 2;
+  static constexpr size_t C = 3;
+};
+
+template <>
+struct Channels<LAYOUT_NCHW> {
+  static constexpr size_t N = 0;
+  static constexpr size_t C = 1;
+  static constexpr size_t H = 2;
+  static constexpr size_t W = 3;
+};
+
 }  // namespace cuda
 }  // namespace onnxruntime
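Editor's note: a small usage illustration of the new Channels<> helper (the function below is hypothetical, not part of this change). It lets layout-templated code pick dimension indices without hard-coding NCHW or NHWC:

    template <bool IsNHWC>
    int64_t ChannelDim(const TensorShape& shape_4d) {
      return shape_4d[Channels<IsNHWC>::C];  // index 3 for NHWC, index 1 for NCHW
    }
    // ChannelDim<LAYOUT_NHWC>(x_shape) and ChannelDim<LAYOUT_NCHW>(x_shape) then read the right entry.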
diff --git a/onnxruntime/core/providers/cuda/tensor/gelu.cc b/onnxruntime/core/providers/cuda/tensor/gelu.cc
new file mode 100644
index 0000000000000..67b2fad373a7f
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/tensor/gelu.cc
@@ -0,0 +1,89 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cuda/cuda_common.h"
+#include "core/providers/cuda/cudnn_common.h"
+#include "core/providers/cuda/tensor/gelu.h"
+#include "core/providers/cuda/tensor/gelu_impl.h"
+
+namespace onnxruntime {
+namespace cuda {
+
+#define REGISTER_KERNEL_TYPED(T)                                 \
+  ONNX_OPERATOR_TYPED_KERNEL_EX(                                 \
+      Gelu,                                                      \
+      kOnnxDomain,                                               \
+      20,                                                        \
+      T,                                                         \
+      kCudaExecutionProvider,                                    \
+      (*KernelDefBuilder::Create())                              \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
+          .MayInplace(0, 0),                                     \
+      Gelu<T>);
+
+REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(MLFloat16)
+REGISTER_KERNEL_TYPED(double)
+
+template <typename T>
+Status Gelu<T>::ComputeInternal(OpKernelContext* context) const {
+  const Tensor* input = context->Input<Tensor>(0);
+  const auto& input_dims = input->Shape().GetDims();
+  if (input_dims.size() < 1) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Input 0 is expected to have 1 or more dimensions, got ", input_dims.size());
+  }
+
+  Tensor* output = context->Output(0, input->Shape());
+
+  int64_t input_length = input->Shape().Size();
+  if (input_length == 0) {
+    return Status::OK();
+  }
+
+  typedef typename ToCudaType<T>::MappedType CudaT;
+
+  if (approximation_algorithm_ == "tanh") {
+    return LaunchFastGeluKernel<CudaT>(GetDeviceProp(),
+                                       Stream(context),
+                                       static_cast<int>(input_length),
+                                       0 /* no bias */,
+                                       reinterpret_cast<const CudaT*>(input->Data<T>()),
+                                       nullptr /* no bias */,
+                                       reinterpret_cast<CudaT*>(output->MutableData<T>()),
+                                       use_half2_);
+  } else if (approximation_algorithm_ == "none") {
+    return LaunchGeluKernel<CudaT>(Stream(context),
+                                   reinterpret_cast<const CudaT*>(input->Data<T>()),
+                                   reinterpret_cast<CudaT*>(output->MutableData<T>()),
+                                   static_cast<size_t>(input_length));
+  }
+
+  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported approximation_algorithm: ", approximation_algorithm_);
+}
+
+}  // namespace cuda
+
+#ifndef DISABLE_CONTRIB_OPS
+namespace contrib::cuda {
+#define REGISTER_CONTRIB_KERNEL_TYPED(T)                         \
+  ONNX_OPERATOR_TYPED_KERNEL_EX(                                 \
+      Gelu,                                                      \
+      kMSDomain,                                                 \
+      1,                                                         \
+      T,                                                         \
+      kCudaExecutionProvider,                                    \
+      (*KernelDefBuilder::Create())                              \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
+          .MayInplace(0, 0),                                     \
+      onnxruntime::cuda::Gelu<T>);
+
+REGISTER_CONTRIB_KERNEL_TYPED(float)
+REGISTER_CONTRIB_KERNEL_TYPED(MLFloat16)
+REGISTER_CONTRIB_KERNEL_TYPED(double)
+
+#undef REGISTER_CONTRIB_KERNEL_TYPED
+}  // namespace contrib::cuda
+#endif
+
+}  // namespace onnxruntime
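Editor's note: the two attribute values handled above correspond to the standard Gelu definitions. A host-side reference of the semantics (sketch only, not the CUDA kernels):

    #include <cmath>
    inline float GeluExact(float x) {  // approximate == "none"
      return 0.5f * x * (1.0f + std::erf(x * 0.7071067811865475f));  // x / sqrt(2)
    }
    inline float GeluTanh(float x) {   // approximate == "tanh" (FastGelu path)
      const float k = 0.7978845608028654f;  // sqrt(2 / pi)
      return 0.5f * x * (1.0f + std::tanh(k * (x + 0.044715f * x * x * x)));
    }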
diff --git a/onnxruntime/core/providers/cuda/tensor/gelu.h b/onnxruntime/core/providers/cuda/tensor/gelu.h
new file mode 100644
index 0000000000000..1c8189ab24121
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/tensor/gelu.h
@@ -0,0 +1,28 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/providers/cuda/cuda_kernel.h"
+#include "core/providers/cuda/math/unary_elementwise_ops.h"
+
+namespace onnxruntime {
+namespace cuda {
+
+template <typename T>
+class Gelu final : public UnaryElementwise {
+ public:
+  Gelu(const OpKernelInfo& info) : UnaryElementwise(info) {
+    approximation_algorithm_ = info.GetAttrOrDefault<std::string>("approximate", "none");
+  }
+
+  Status ComputeInternal(OpKernelContext* ctx) const override;
+
+ private:
+  const bool use_half2_{true};
+
+  std::string approximation_algorithm_;
+};
+
+}  // namespace cuda
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu b/onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu
similarity index 84%
rename from onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu
rename to onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu
index c9498eb1bcd7b..7a27b7af33137 100644
--- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu
+++ b/onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu
@@ -24,12 +24,9 @@ limitations under the License.
 #include "core/providers/cuda/cuda_common.h"
 #include "core/providers/cuda/cu_inc/common.cuh"
 #include "core/providers/cuda/shared_inc/cuda_call.h"
-#include "contrib_ops/cuda/bert/fast_gelu_impl.h"
-
-using namespace onnxruntime::cuda;
+#include "core/providers/cuda/tensor/gelu_impl.h"
 
 namespace onnxruntime {
-namespace contrib {
 namespace cuda {
 
 // constants for approximating the normal cdf
@@ -65,7 +62,7 @@ __global__ void FastGeluKernel2(const half2 a, const half2 b, const half2 c, int
 }
 
 template <>
-Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length,
+Status LaunchFastGeluKernel(const cudaDeviceProp& /*prop*/, cudaStream_t stream, int input_length, int bias_length,
                             const float* input, const float* bias, float* output, bool /*use_half2*/) {
   constexpr int blockSize = 256;
   const int gridSize = (input_length + blockSize - 1) / blockSize;
@@ -75,6 +72,17 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int
   return CUDA_CALL(cudaGetLastError());
 }
 
+template <>
+Status LaunchFastGeluKernel(const cudaDeviceProp& /*prop*/, cudaStream_t stream, int input_length, int bias_length,
+                            const double* input, const double* bias, double* output, bool /*use_half2*/) {
+  constexpr int blockSize = 256;
+  const int gridSize = (input_length + blockSize - 1) / blockSize;
+  FastGeluKernel<double, blockSize><<<gridSize, blockSize, 0, stream>>>(A, B, C, input_length, bias_length,
+                                                                        input, bias, output);
+
+  return CUDA_CALL(cudaGetLastError());
+}
+
 template <>
 Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length,
                             const half* input, const half* bias, half* output, bool use_half2) {
@@ -100,7 +108,7 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int
 }
 
 template <>
-Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length,
+Status LaunchFastGeluKernel(const cudaDeviceProp& /*prop*/, cudaStream_t stream, int input_length, int bias_length,
                             const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) {
   constexpr int blockSize = 256;
 
@@ -114,5 +122,4 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int
 }
 
 }  // namespace cuda
-}  // namespace contrib
 }  // namespace onnxruntime
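Editor's note: the constants passed to these kernels (A, B, C above) presumably encode the usual folded form of the tanh approximation. The algebra, shown as a self-contained check using the conventional values (an assumption stated for illustration only, not a claim about this file's definitions):

    #include <cmath>
    inline bool FoldedFastGeluMatches(float x) {
      const float B = 0.7978845608028654f;  // sqrt(2 / pi)
      const float C = 0.044715f * B;
      const float folded = x * (0.5f + 0.5f * std::tanh(x * (C * x * x + B)));
      const float textbook = 0.5f * x * (1.0f + std::tanh(B * (x + 0.044715f * x * x * x)));
      return std::fabs(folded - textbook) < 1e-6f;
    }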
diff --git a/onnxruntime/core/providers/cuda/tensor/gelu_impl.cu b/onnxruntime/core/providers/cuda/tensor/gelu_impl.cu
new file mode 100644
index 0000000000000..3f96da38b37bb
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/tensor/gelu_impl.cu
@@ -0,0 +1,48 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <cuda_runtime.h>
+#include "core/providers/cuda/tensor/gelu_impl.h"
+#include "core/providers/cuda/cu_inc/common.cuh"
+#include "core/providers/cuda/cu_inc/unary_elementwise_impl.cuh"
+
+namespace onnxruntime {
+namespace cuda {
+
+template <typename T>
+struct OP_Gelu {
+  __device__ __inline__ T operator()(const T& a) const {
+    return _Gelu(a);
+  }
+};
+
+template <>
+struct OP_Gelu<half> {
+  __device__ __inline__ half operator()(const half& a) const {
+    return static_cast<half>(_Gelu(static_cast<float>(a)));
+  }
+};
+
+template <typename T>
+Status LaunchGeluKernel(
+    cudaStream_t stream,
+    const T* input_data,
+    T* output_data,
+    size_t count) {
+  UnaryElementWiseImpl(stream, input_data, output_data, OP_Gelu<T>(), count);
+
+  return CUDA_CALL(cudaGetLastError());
+}
+
+#define SPECIALIZED_GELU_IMPL(T)                                                                \
+  template Status LaunchGeluKernel<T>(cudaStream_t stream, const T* input_data, T* output_data, \
+                                      size_t count);
+
+SPECIALIZED_GELU_IMPL(float);
+SPECIALIZED_GELU_IMPL(half);
+SPECIALIZED_GELU_IMPL(double);
+
+#undef SPECIALIZED_GELU_IMPL
+
+}  // namespace cuda
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.h b/onnxruntime/core/providers/cuda/tensor/gelu_impl.h
similarity index 80%
rename from onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.h
rename to onnxruntime/core/providers/cuda/tensor/gelu_impl.h
index ba78310f5dfc2..2ea0d3441fda3 100644
--- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.h
+++ b/onnxruntime/core/providers/cuda/tensor/gelu_impl.h
@@ -1,17 +1,18 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
-
 #pragma once
+
 #include "core/common/common.h"
 
 namespace onnxruntime {
-namespace contrib {
 namespace cuda {
 
+template <typename T>
+Status LaunchGeluKernel(cudaStream_t stream, const T* input, T* output, size_t count);
+
 template <typename T>
 Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length,
                             const T* input, const T* bias, T* output, bool use_half2);
 
 }  // namespace cuda
-}  // namespace contrib
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/tensor/resize.cc b/onnxruntime/core/providers/cuda/tensor/resize.cc
index 764172a8d1fac..97d4eb71e970a 100644
--- a/onnxruntime/core/providers/cuda/tensor/resize.cc
+++ b/onnxruntime/core/providers/cuda/tensor/resize.cc
@@ -28,10 +28,22 @@ namespace cuda {
           .InputMemoryType(OrtMemTypeCPUInput, 3)                  \
           .TypeConstraint("T1", DataTypeImpl::GetTensorType<T>()), \
       Resize<T>);                                                  \
+  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(                         \
+      Resize,                                                      \
+      kOnnxDomain,                                                 \
+      13, 17,                                                      \
+      T,                                                           \
+      kCudaExecutionProvider,                                      \
+      (*KernelDefBuilder::Create())                                \
+          .InputMemoryType(OrtMemTypeCPUInput, 1)                  \
+          .InputMemoryType(OrtMemTypeCPUInput, 2)                  \
+          .InputMemoryType(OrtMemTypeCPUInput, 3)                  \
+          .TypeConstraint("T1", DataTypeImpl::GetTensorType<T>()), \
+      Resize<T>);                                                  \
   ONNX_OPERATOR_TYPED_KERNEL_EX(                                   \
       Resize,                                                      \
       kOnnxDomain,                                                 \
-      13,                                                          \
+      18,                                                          \
       T,                                                           \
       kCudaExecutionProvider,                                      \
       (*KernelDefBuilder::Create())                                \
diff --git a/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu
new file mode 100644
index 0000000000000..d56e4bc53874d
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu
@@ -0,0 +1,1179 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cuda/cu_inc/common.cuh"
+#include "core/providers/cuda/tensor/resize_impl.h"
+
+#define FUNC_DEF __device__
+
+namespace onnxruntime {
+namespace cuda {
+
+using onnxruntime::ResizeCoordinateTransformationMode;
+using onnxruntime::UpsampleMode;
+
+/// <summary>
+/// Compute the per-axis coefficient buffer sizes (and window sizes) for bilinear CUDA antialias resizing.
+/// </summary>
+static std::tuple<int64_t, int64_t> ComputeBilinearScaleBufferSize(
+    int64_t output_height, int64_t output_width,
+    float height_rscale, float width_rscale,
+    float support_value,
+    float& scaled_support_height, float& scaled_support_width,
+    int32_t& window_size_height, int32_t& window_size_width) {
+  scaled_support_height = ComputeScaledSupportValue(support_value, height_rscale);
+  scaled_support_width = ComputeScaledSupportValue(support_value, width_rscale);
+  window_size_height = ComputeWindowSize(scaled_support_height);
+  window_size_width = ComputeWindowSize(scaled_support_width);
+
+  auto height_buffer_size = ComputeWeightedCoeffBufferSize(output_height, window_size_height);
+  auto width_buffer_size = ComputeWeightedCoeffBufferSize(output_width, window_size_width);
+
+  return std::make_tuple(height_buffer_size, width_buffer_size);
+}
+
+/// <summary>
+/// Compute the per-axis coefficient buffer sizes (and window sizes) for trilinear CUDA antialias resizing.
+/// </summary>
+static std::tuple<int64_t, int64_t, int64_t> ComputeTrilinearScaleBufferSize(
+    int64_t output_depth, int64_t output_height, int64_t output_width,
+    float depth_rscale, float height_rscale, float width_rscale,
+    float support_value,
+    float& scaled_support_depth, float& scaled_support_height,
+    float& scaled_support_width, int32_t& window_size_depth,
+    int32_t& window_size_height, int32_t& window_size_width) {
+  scaled_support_depth = ComputeScaledSupportValue(support_value, depth_rscale);
+  window_size_depth = ComputeWindowSize(scaled_support_depth);
+  auto depth_buffer_size = ComputeWeightedCoeffBufferSize(output_depth, window_size_depth);
+
+  const auto [y_buffer_size, w_buffer_size] = ComputeBilinearScaleBufferSize(output_height,
+                                                                             output_width, height_rscale,
+                                                                             width_rscale, support_value,
+                                                                             scaled_support_height,
+                                                                             scaled_support_width,
+                                                                             window_size_height, window_size_width);
+  return std::make_tuple(depth_buffer_size, y_buffer_size, w_buffer_size);
+}
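Editor's note: the helpers above size the per-axis coefficient buffers from the window implied by the scaled filter support. A rough, generic sketch of the idea (this is not ORT's ComputeScaledSupportValue / ComputeWindowSize / ComputeWeightedCoeffBufferSize, which live in the shared resize helpers):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    // When downscaling by s = 1 / rscale > 1, the filter support widens by s; each output
    // sample then needs on the order of 2 * ceil(scaled_support) + 1 taps, and a per-axis
    // buffer of output_size * window_size weights.
    inline int64_t RoughCoeffBufferSize(int64_t output_size, float rscale, float support_value) {
      const float scale = 1.0f / rscale;
      const float scaled_support = support_value * std::max(1.0f, scale);
      const int window_size = 2 * static_cast<int>(std::ceil(scaled_support)) + 1;
      return output_size * window_size;
    }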
+
+// Antialiasing filters
+struct BilinearFilter {
+  __device__ __host__ float operator()(float x, float /* cubic_coeff_a */) const {
+    if (x < 0.0f) {
+      x = -x;
+    }
+    if (x < 1.0f) {
+      return 1.0f - x;
+    }
+    return 0.0f;
+  }
+};
+
+struct BiCubicFilter {
+  __device__ __host__ float operator()(float x, float cubic_coeff_a) const {
+    /* https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
+     */
+    if (x < 0.0f) {
+      x = -x;
+    }
+    if (x < 1.0f) {
+      return ((cubic_coeff_a + 2.0f) * x - (cubic_coeff_a + 3.0f)) * x * x + 1;
+    }
+    if (x < 2.0f) {
+      return (((x - 5.0f) * x + 8.f) * x - 4.f) * cubic_coeff_a;
+    }
+    return 0.0f;
+  }
+};
+
+struct TriLinearFilter {
+  __device__ __host__ float operator()(float x, float /* cubic_coeff_a */) const {
+    if (x < 0.0f) {
+      x = -x;
+    }
+    if (x < 1.0f) {
+      return 1.0f - x;
+    }
+    return 0.0f;
+  }
+};
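Editor's note: these functors are evaluated once per (output position, tap) pair to build normalized weights. A generic, PIL-style recipe for one output coordinate (illustrative only; it is not the code in this file, which builds the weights on the device):

    #include <algorithm>
    #include <cmath>
    #include <vector>
    // scale = input_size / output_size; support is 1.0 for the triangle filters above, 2.0 for bicubic.
    template <typename Filter>
    std::vector<float> WeightsForOutput(int64_t u, float scale, float support, int64_t input_size,
                                        Filter filter, float cubic_coeff_a,
                                        int64_t& xmin, int64_t& xmax) {
      const float filter_scale = std::max(scale, 1.0f);           // widen the filter when downscaling
      const float scaled_support = support * filter_scale;
      const float center = (u + 0.5f) * scale;
      xmin = std::max<int64_t>(0, static_cast<int64_t>(center - scaled_support + 0.5f));
      xmax = std::min<int64_t>(input_size, static_cast<int64_t>(center + scaled_support + 0.5f));
      std::vector<float> weights;
      float total = 0.0f;
      for (int64_t x = xmin; x < xmax; ++x) {
        const float w = filter((x - center + 0.5f) / filter_scale, cubic_coeff_a);
        weights.push_back(w);
        total += w;
      }
      for (float& w : weights) w /= (total != 0.0f ? total : 1.0f);  // normalize the window
      return weights;
    }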
+
+template <typename AccumType>
+struct AccumTypeCaster {
+  static __device__ __host__ AccumType* cast(AccumType* p) {
+    return p;
+  }
+};
+
+template <>
+struct AccumTypeCaster<int32_t> {
+  static __device__ __host__ float* cast(int32_t* p) {
+    return reinterpret_cast<float*>(p);
+  }
+};
+
+template <typename T, typename AccumType>
+__global__ void _ComputeInterpolationAtLevel1(
+    int64_t num_channels,
+    int64_t input_height, int64_t input_width,
+    int64_t output_height, int64_t output_width,
+    const fast_divmod div_output_width,
+    const fast_divmod div_output_image,
+    int32_t window_size,
+    const uint8_t* clip8_table,
+    const int64_t* bound_data,
+    std::tuple<int64_t*, int64_t*> outof_bounds_buffers,
+    const AccumType* weight_coefficients,
+    const T* Xdata, T* Ydata,
+    const int N) {
+  CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
+
+  // No scaling needed along this axis; copy the data through.
+  if (output_width == input_width) {
+    Ydata[id] = Xdata[id];
+    return;
+  }
+
+  int bxc, output_image_index;
+  div_output_image.divmod(id, bxc, output_image_index);
+
+  int output_y, output_x;
+  div_output_width.divmod(output_image_index, output_y, output_x);
+
+  CUDA_LONG input_index = static_cast<CUDA_LONG>(bxc * num_channels * input_height * input_width);
+  CUDA_LONG output_index = static_cast<CUDA_LONG>(bxc * num_channels * output_height * output_width);
+
+  auto* Ydata_offset = Ydata + output_index + output_width * output_y + output_x;
+  const auto* bound = bound_data;
+
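+  // For 8-bit types the accumulation is done in fixed point with 22 fractional bits;
+  // ConstValue::mag_factor is presumably 0.5 in that scale, so the final >> 22 rounds
+  // to nearest instead of truncating.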
+  AccumType output = onnxruntime::is_8bit_v<T> ? ConstValue::mag_factor : 0;
+
+  const auto* weight_coeff = weight_coefficients + window_size * output_x;
+  int64_t xmin = bound[static_cast<ptrdiff_t>(output_x) * 2];
+  int64_t xmax = bound[static_cast<ptrdiff_t>(output_x) * 2 + 1];
+
+  // Input window
+  const auto* Xdata_offset = Xdata + input_index + input_width * output_y + xmin;
+
+  for (; xmin < xmax; ++xmin) {
+    if constexpr (std::is_same<T, half>::value) {
+      // Cast to AccumType to resolve the ambiguous operator* for half on CUDA
+      output += static_cast<AccumType>((*Xdata_offset++)) * (*weight_coeff++);
+    } else {
+      output += (*Xdata_offset++) * (*weight_coeff++);
+    }
+  }
+
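+  // clip8_table is assumed to be a pre-offset saturation table: indexing at
+  // clip8_table[640 + v] clamps the de-scaled fixed-point result to [0, 255],
+  // including negative values of (output >> 22).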
+  if constexpr (onnxruntime::is_8bit_v<T>) {
+    const uint8_t* clip8_lookups = &clip8_table[640];
+    *Ydata_offset = static_cast<T>(clip8_lookups[output >> 22]);
+  } else if constexpr (std::is_same<T, int32_t>::value) {
+    *Ydata_offset = static_cast<int32_t>(std::round(output));
+  } else {
+    *Ydata_offset = static_cast<T>(output);
+  }
+}
+
+template <typename T, typename AccumType>
+__global__ void _ComputeInterpolationAtLevel2(
+    int64_t num_channels,
+    int64_t input_height, int64_t input_width,
+    int64_t output_height, int64_t output_width,
+    const fast_divmod div_output_height,
+    const fast_divmod div_output_width,
+    const fast_divmod div_output_image,
+    int32_t window_size,
+    bool use_extrapolation, float extrapolation_value,
+    const uint8_t* clip8_table,
+    const int64_t* bound_data,
+    std::tuple<int64_t*, int64_t*> outof_bounds_buffers,
+    const AccumType* weight_coefficients,
+    const T* Xdata, T* Ydata, int N) {
+  CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
+
+  // No scaling needed along this axis; copy the data through.
+  if (output_height == input_height) {
+    Ydata[id] = Xdata[id];
+    return;
+  }
+
+  int bxc, output_image_index;
+  div_output_image.divmod(id, bxc, output_image_index);
+
+  int output_z, output_y, output_x, temp;
+  div_output_height.divmod(output_image_index, output_z, temp);
+  div_output_width.divmod(temp, output_y, output_x);
+
+  CUDA_LONG input_index = static_cast<CUDA_LONG>(bxc * num_channels * input_height * input_width +
+                                                 output_z * input_height * input_width);
+  CUDA_LONG output_index = static_cast<CUDA_LONG>(bxc * num_channels * output_height * output_width +
+                                                  output_z * output_height * output_width);
+
+  auto* Ydata_offset = Ydata + output_index + output_width * output_y + output_x;
+
+  if (use_extrapolation) {
+    const auto* w_outof_bounds = std::get<1>(outof_bounds_buffers);
+    // Extrapolate along the w dimension
+    if (w_outof_bounds[static_cast<ptrdiff_t>(output_x)] != -1) {
+      *Ydata_offset = static_cast<T>(extrapolation_value);
+      return;
+    }
+
+    // Extrapolate along the y dimension
+    const auto* y_outof_bounds = std::get<0>(outof_bounds_buffers);
+    if (y_outof_bounds[static_cast<ptrdiff_t>(output_y)] != -1) {
+      *Ydata_offset = static_cast<T>(extrapolation_value);
+      return;
+    }
+  }
+
+  const auto* bound = bound_data;
+
+  AccumType output = onnxruntime::is_8bit_v<T> ? ConstValue::mag_factor : 0;
+
+  const auto* weight_coeff = weight_coefficients + window_size * output_y;
+  int64_t ymin = bound[static_cast<ptrdiff_t>(output_y) * 2];
+  int64_t ymax = bound[static_cast<ptrdiff_t>(output_y) * 2 + 1];
+
+  const auto* Xdata_offset = Xdata + input_index + ymin * output_width + output_x;
+
+  for (; ymin < ymax; ++ymin) {
+    if constexpr (std::is_same<T, half>::value) {
+      // We cast to AccumType to resolve ambiguous call to operator* for half in CUDA
+      output += static_cast<AccumType>((*Xdata_offset)) * (*weight_coeff++);
+    } else {
+      output += (*Xdata_offset) * (*weight_coeff++);
+    }
+    Xdata_offset += input_width;
+  }
+
+  if constexpr (onnxruntime::is_8bit_v<T>) {
+    const uint8_t* clip8_lookups = &clip8_table[640];
+    *Ydata_offset = static_cast<T>(clip8_lookups[output >> 22]);
+  } else if constexpr (std::is_same<T, int32_t>::value) {
+    *Ydata_offset = static_cast<int32_t>(std::round(output));
+  } else {
+    *Ydata_offset = output;
+  }
+}
+
+template <typename T, typename AccumType>
+__global__ void _ComputeInterpolationAtLevel3(
+    int64_t input_depth,
+    int64_t input_height, int64_t input_width,
+    int64_t output_depth,
+    int64_t output_height, int64_t output_width,
+    const fast_divmod div_output_height,
+    const fast_divmod div_output_width,
+    const fast_divmod div_output_image,
+    int32_t window_size,
+    bool use_extrapolation, float extrapolation_value,
+    const uint8_t* clip8_table,
+    const int64_t* bound_data,
+    std::tuple<int64_t*, int64_t*, int64_t*> outof_bounds_buffers,
+    const AccumType* weight_coefficients,
+    const T* Xdata, T* Ydata, int N) {
+  CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
+
+  // No scaling needed along this axis; copy the data through.
+  if (input_depth == output_depth) {
+    Ydata[id] = Xdata[id];
+    return;
+  }
+
+  int bxc, output_image_index;
+  div_output_image.divmod(id, bxc, output_image_index);
+
+  int output_z, output_y, output_x, temp;
+  div_output_height.divmod(output_image_index, output_z, temp);
+  div_output_width.divmod(temp, output_y, output_x);
+
+  CUDA_LONG input_index = static_cast<CUDA_LONG>(bxc * input_depth * input_height * input_width);
+
+  auto* Ydata_offset = Ydata + id;
+
+  if (use_extrapolation) {
+    const auto* w_outof_bounds = std::get<2>(outof_bounds_buffers);
+    // Extrapolate along the w dimension
+    if (w_outof_bounds[static_cast<ptrdiff_t>(output_x)] != -1) {
+      *Ydata_offset = static_cast<T>(extrapolation_value);
+      return;
+    }
+
+    // Extrapolate along the y dimension
+    const auto* y_outof_bounds = std::get<1>(outof_bounds_buffers);
+    if (y_outof_bounds[static_cast<ptrdiff_t>(output_y)] != -1) {
+      *Ydata_offset = static_cast<T>(extrapolation_value);
+      return;
+    }
+
+    // Extrapolate along the z dimension
+    const int64_t* z_outof_bounds = std::get<0>(outof_bounds_buffers);
+    if (z_outof_bounds != nullptr && z_outof_bounds[static_cast<ptrdiff_t>(output_z)] != -1) {
+      *Ydata_offset = static_cast<T>(extrapolation_value);
+      return;
+    }
+  }
+
+  const auto* bound = bound_data;
+
+  AccumType output = onnxruntime::is_8bit_v<T> ? ConstValue::mag_factor : 0;
+
+  const auto* weight_coeff = weight_coefficients + window_size * output_z;
+  int64_t zmin = bound[static_cast<ptrdiff_t>(output_z) * 2];
+  int64_t zmax = bound[static_cast<ptrdiff_t>(output_z) * 2 + 1];
+
+  const auto z_step = input_height * input_width;
+  const auto* Xdata_offset = Xdata + input_index + zmin * z_step + output_y * output_width + output_x;
+
+  for (; zmin < zmax; ++zmin) {
+    if constexpr (std::is_same<T, half>::value) {
+      // We cast to AccumType to resolve ambiguous call to operator* for half in CUDA
+      output += static_cast<AccumType>((*Xdata_offset)) * (*weight_coeff++);
+    } else {
+      output += (*Xdata_offset) * (*weight_coeff++);
+    }
+    Xdata_offset += z_step;
+  }
+
+  if constexpr (onnxruntime::is_8bit_v<T>) {
+    const uint8_t* clip8_lookups = &clip8_table[640];
+    *Ydata_offset = static_cast<T>(clip8_lookups[output >> 22]);
+  } else if constexpr (std::is_same<T, int32_t>::value) {
+    *Ydata_offset = static_cast<int32_t>(std::round(output));
+  } else {
+    *Ydata_offset = output;
+  }
+}
+
+/// <summary>
+/// Computes the antialias filter weights and bounds for a single dimension.
+/// This function expects the following buffers to be pre-allocated on device:
+/// 1. bounds: int64_t[output_size * 2]
+/// 2. out_of_bounds: int64_t[output_size]
+/// 3. scale_data: AccumType[output_size * window_size]
+///
+/// Template parameter AccumType is the accumulation type for the weights
+/// (float, or int32_t fixed point for 8-bit element types).
+/// </summary>
+template <typename AccumType, typename Filter, typename CudaFunctionOriginalCoordinate>
+FUNC_DEF void SetupUpsampleFilterAnitAliasImpl(
+    int64_t i,
+    int64_t input_size, int64_t output_size,
+    float rscale,
+    float roi_start, float roi_end,
+    float scaled_support, int32_t window_size, bool exclude_outside,
+    float cubic_coeff_a,
+    int64_t* bounds,
+    int64_t* out_of_bounds,
+    AccumType* scale_data) {
+  Filter filter{};
+  CudaFunctionOriginalCoordinate get_original_coordinate{};
+
+  const auto scale = 1.f / rscale;
+  const float inv_scale = (scale >= 1.0f) ? 1.0f / scale : 1.0f;
+
+  const float id = static_cast<float>(i);
+  float center = 0.5f;
+  if (scale == 1.0f) {
+    center += id;
+  } else {
+    center += get_original_coordinate(id, rscale,
+                                      static_cast<float>(output_size),
+                                      static_cast<float>(input_size),
+                                      roi_start, roi_end);
+  }
+
+  if (center - 0.5f < 0 || center - 0.5f > static_cast<float>(input_size - 1)) {
+    out_of_bounds[i] = i;
+  } else {
+    out_of_bounds[i] = -1;
+  }
+
+  float total_weight{0};
+
+  auto fmin = _Floor(center - scaled_support + 0.5f);
+  auto fmax = _Floor(center + scaled_support + 0.5f);
+
+  int64_t min_real = static_cast<int64_t>(fmin);
+  int64_t max_real = static_cast<int64_t>(fmax);
+  int64_t min_cut = std::max<int64_t>(min_real, 0);
+  int64_t max_cut = std::min(max_real, input_size);
+
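+  // bounds stores the clipped half-open window [min_cut, max_cut) of input indices that
+  // contribute to output element i. When exclude_outside is false, the raw window
+  // [min_real, max_real) may extend past the input; those weights are folded back in below.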
+  int64_t min_val = exclude_outside ? min_cut : min_real;
+  int64_t max_val = exclude_outside ? max_cut : max_real;
+  bounds[i * 2] = min_cut;
+  bounds[i * 2 + 1] = max_cut;
+
+  // For the int32_t case the weights are first computed in float (reinterpreting the
+  // int32_t storage) and converted to fixed point after normalization; all other types as-is.
+  auto* scale_buffer = AccumTypeCaster<AccumType>::cast(&scale_data[i * window_size]);
+
+  max_val -= min_val;
+  for (int64_t x = 0; x < max_val; x++) {
+    const float arg = (x + min_val - center + 0.5f) * inv_scale;
+    const auto w = filter(arg, cubic_coeff_a);
+    scale_buffer[x] = w;
+    total_weight += w;
+  }
+
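+  // Out-of-range taps keep their weight: fold weights that fall to the left of the input
+  // into the first in-bounds tap and weights to the right into the last in-bounds tap,
+  // then shift the window so that index 0 corresponds to the first in-bounds tap.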
+  if (!exclude_outside) {
+    int64_t neg_xsize = min_val < 0 ? -min_val : 0;
+    for (int64_t x = 0; x < neg_xsize; x++) {
+      scale_buffer[neg_xsize] += scale_buffer[x];
+    }
+
+    int64_t bound_size =
+        max_val + min_val > input_size ? max_val + min_val - input_size : 0;
+    for (int64_t x = max_val - bound_size; x < max_val; x++) {
+      scale_buffer[max_val - bound_size - 1] +=
+          scale_buffer[x];
+    }
+
+    for (int64_t x = 0; (neg_xsize | bound_size) > 0 && x < max_cut - min_cut; x++) {
+      scale_buffer[x] = scale_buffer[x + neg_xsize];
+    }
+  }
+
+  const float total_weight_inv = (total_weight == 0) ? 1.f : (1.f / total_weight);
+  if constexpr (std::is_same<AccumType, int32_t>::value) {
+    auto* scale_buffer_int = reinterpret_cast<int32_t*>(scale_buffer);
+    for (int64_t x = 0; x < max_cut - min_cut; x++) {
+      scale_buffer[x] *= total_weight_inv;
+      // normalize the scale to 1 << 22 for int8/uint8
+      scale_buffer_int[x] = static_cast<int32_t>(_Round(scale_buffer[x] * ConstValue::mag_factor_x_2));
+    }
+  } else {
+    for (int64_t x = 0; x < max_cut - min_cut; x++) {
+      scale_buffer[x] *= total_weight_inv;
+    }
+  }
+}
+
+/// <summary>
+/// This kernel computes the antialias filter for bilinear or bicubic upsampling.
+/// The function expects the following buffers to be pre-allocated on device:
+/// 1. bounds: int64_t[output_size * 2] for each of the two dimensions
+/// 2. out_of_bounds: int64_t[output_size] for each of the two dimensions
+/// 3. scale_data: AccumType[output_size * window_size] for each of the two dimensions
+/// Buffers layout: [h_data, w_data]
+/// </summary>
+template <typename AccumType, typename Filter, typename CudaFunctionOriginalCoordinate>
+__global__ void _SetupBilinearUpsampleFilterAntiAlias(
+    std::tuple<int64_t, int64_t> input_dims,       // h, w
+    std::tuple<int64_t, int64_t> output_dims,      // h, w
+    std::tuple<float, float> inv_scale_vals,       // h, w
+    std::tuple<float, float> roi_start_vals,       // h, w
+    std::tuple<float, float> roi_end_vals,         // h, w
+    std::tuple<float, float> dim_scaled_support,   // Pre-computed scaled support values h, w
+    std::tuple<int32_t, int32_t> dim_window_size,  // Pre-computed window sizes h, w
+    float cubic_coeff_a,
+    bool exclude_outside,
+    int64_t* bounds,
+    int64_t* out_of_bounds,
+    std::tuple<AccumType*, AccumType*> weighted_coefficients  // h (y) and w buffers
+) {
+  const auto N = std::get<0>(output_dims) + std::get<1>(output_dims);
+
+  CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
+
+  if (id < std::get<0>(output_dims)) {
+    // Setup for y
+    int64_t input_size = std::get<0>(input_dims);
+    int64_t output_size = std::get<0>(output_dims);
+    float inv_scale = std::get<0>(inv_scale_vals);
+    float roi_start = std::get<0>(roi_start_vals);
+    float roi_end = std::get<0>(roi_end_vals);
+    float scaled_support = std::get<0>(dim_scaled_support);
+    int32_t window_size = std::get<0>(dim_window_size);
+
+    SetupUpsampleFilterAnitAliasImpl<AccumType, Filter, CudaFunctionOriginalCoordinate>(
+        id,
+        input_size, output_size,
+        inv_scale,
+        roi_start, roi_end,
+        scaled_support, window_size,
+        exclude_outside,
+        cubic_coeff_a,
+        bounds,
+        out_of_bounds,
+        std::get<0>(weighted_coefficients));
+
+  } else {
+    // Setup for w
+    // w = id - output_height
+
+    int64_t input_size = std::get<1>(input_dims);
+    int64_t output_size = std::get<1>(output_dims);
+    float inv_scale = std::get<1>(inv_scale_vals);
+    float roi_start = std::get<1>(roi_start_vals);
+    float roi_end = std::get<1>(roi_end_vals);
+
+    float scaled_support = std::get<1>(dim_scaled_support);
+    int32_t window_size = std::get<1>(dim_window_size);
+
+    // Adjust buffer positions
+    const auto y_output_size = std::get<0>(output_dims);
+
+    auto i = id - y_output_size;
+    bounds += (y_output_size * 2);
+    out_of_bounds += y_output_size;
+
+    SetupUpsampleFilterAnitAliasImpl<AccumType, Filter, CudaFunctionOriginalCoordinate>(
+        i,
+        input_size, output_size,
+        inv_scale,
+        roi_start, roi_end,
+        scaled_support, window_size,
+        exclude_outside,
+        cubic_coeff_a,
+        bounds,
+        out_of_bounds,
+        std::get<1>(weighted_coefficients));
+  }
+}
+
+/// <summary>
+/// Computes the antialias filter for trilinear upsampling in a single kernel launch.
+/// The function expects the following buffers to be pre-allocated on device
+/// 1. bounds: int64_t[output_size * 2] for each of the three dimensions
+/// 2. out_of_bounds: int64_t[output_size] for each of the three dimensions
+/// 3. scale_data: AccumType[output_size * window_size] for each of the three dimensions
+/// Each kind of buffer contains data for all 3 dims.
+/// Buffers layout [d_data, h_data, w_data]
+/// </summary>
+template <typename AccumType, typename Filter, typename CudaFunctionOriginalCoordinate>
+__global__ void _SetupTrilinerarUpsampleFilterAntiAlias(
+    std::tuple<int64_t, int64_t, int64_t> input_dims,       // d, h, w
+    std::tuple<int64_t, int64_t, int64_t> output_dims,      // d, h, w
+    std::tuple<float, float, float> inv_scale_vals,         // d, h, w
+    std::tuple<float, float, float> roi_start_vals,         // d, h, w
+    std::tuple<float, float, float> roi_end_vals,           // d, h, w
+    std::tuple<float, float, float> dim_scaled_support,     // Pre-computed scaled support values d, h, w
+    std::tuple<int32_t, int32_t, int32_t> dim_window_size,  // Pre-computed window sizes d, h, w
+    bool exclude_outside,
+    int64_t* bounds,
+    int64_t* out_of_bounds,
+    std::tuple<AccumType*, AccumType*, AccumType*> weighted_coefficients) {
+  const auto N = std::get<0>(output_dims) + std::get<1>(output_dims) + std::get<2>(output_dims);
+
+  CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
+
+  if (id < std::get<0>(output_dims)) {
+    // Setup for d (id < output_depth)
+    int64_t input_size = std::get<0>(input_dims);
+    int64_t output_size = std::get<0>(output_dims);
+    float inv_scale = std::get<0>(inv_scale_vals);
+    float roi_start = std::get<0>(roi_start_vals);
+    float roi_end = std::get<0>(roi_end_vals);
+    float scaled_support = std::get<0>(dim_scaled_support);
+    int32_t window_size = std::get<0>(dim_window_size);
+
+    SetupUpsampleFilterAnitAliasImpl<AccumType, Filter, CudaFunctionOriginalCoordinate>(
+        id,
+        input_size, output_size,
+        inv_scale,
+        roi_start, roi_end,
+        scaled_support, window_size,
+        exclude_outside,
+        onnxruntime::antialias_constants::kCubicCoeffA,  // Default value for trilinear
+        bounds,
+        out_of_bounds,
+        std::get<0>(weighted_coefficients));
+
+  } else if (id >= std::get<0>(output_dims) && id < (std::get<0>(output_dims) + std::get<1>(output_dims))) {
+    int64_t input_size = std::get<1>(input_dims);
+    int64_t output_size = std::get<1>(output_dims);
+    float inv_scale = std::get<1>(inv_scale_vals);
+    float roi_start = std::get<1>(roi_start_vals);
+    float roi_end = std::get<1>(roi_end_vals);
+
+    float scaled_support = std::get<1>(dim_scaled_support);
+    int32_t window_size = std::get<1>(dim_window_size);
+
+    // Adjust buffer positions
+    const auto d_output_size = std::get<0>(output_dims);
+
+    auto i = id - d_output_size;
+    bounds += d_output_size * 2;
+    out_of_bounds += d_output_size;
+
+    SetupUpsampleFilterAnitAliasImpl<AccumType, Filter, CudaFunctionOriginalCoordinate>(
+        i,
+        input_size, output_size,
+        inv_scale,
+        roi_start, roi_end,
+        scaled_support, window_size,
+        exclude_outside,
+        onnxruntime::antialias_constants::kCubicCoeffA,  // Default value for trilinear
+        bounds,
+        out_of_bounds,
+        std::get<1>(weighted_coefficients));
+  } else {
+    int64_t input_size = std::get<2>(input_dims);
+    int64_t output_size = std::get<2>(output_dims);
+    float inv_scale = std::get<2>(inv_scale_vals);
+    float roi_start = std::get<2>(roi_start_vals);
+    float roi_end = std::get<2>(roi_end_vals);
+    float scaled_support = std::get<2>(dim_scaled_support);
+    int32_t window_size = std::get<2>(dim_window_size);
+
+    // Adjust buffer positions
+    const auto d_y_output_size = std::get<0>(output_dims) + std::get<1>(output_dims);
+
+    auto i = id - d_y_output_size;
+    bounds += (d_y_output_size * 2);
+    out_of_bounds += d_y_output_size;
+
+    SetupUpsampleFilterAnitAliasImpl<AccumType, Filter, CudaFunctionOriginalCoordinate>(
+        i,
+        input_size, output_size,
+        inv_scale,
+        roi_start, roi_end,
+        scaled_support, window_size,
+        exclude_outside,
+        onnxruntime::antialias_constants::kCubicCoeffA,  // Default value for trilinear
+        bounds,
+        out_of_bounds,
+        std::get<2>(weighted_coefficients));
+  }
+}
+
+#define CASEA_COORD_ANTIALIAS(coordinate_mode, TransformCoordType, ...) \
+  case coordinate_mode: {                                               \
+    using coord_t = TransformCoordType;                                 \
+    return __VA_ARGS__();                                               \
+    break;                                                              \
+  }
+
+#define DISPATCH_ANTIALIAS_FILTER_SETUP(coord_enum, ...)                              \
+  [&] {                                                                               \
+    const auto the_type = coord_enum;                                                 \
+    switch (the_type) {                                                               \
+      CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::HALF_PIXEL,           \
+                            TransformCoordinate_HALF_PIXEL, __VA_ARGS__)              \
+      CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::ASYMMETRIC,           \
+                            TransformCoordinate_ASYMMETRIC, __VA_ARGS__)              \
+      CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::PYTORCH_HALF_PIXEL,   \
+                            TransformCoordinate_PYTORCH_HALF_PIXEL, __VA_ARGS__)      \
+      CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::ALIGN_CORNERS,        \
+                            TransformCoordinate_ALIGN_CORNERS, __VA_ARGS__)           \
+      CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::TF_HALF_PIXEL_FOR_NN, \
+                            TransformCoordinate_TF_HALF_PIXEL_FOR_NN, __VA_ARGS__)    \
+      CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE,   \
+                            TransformCoordinate_TF_CROP_AND_RESIZE, __VA_ARGS__)      \
+      default:                                                                        \
+        ORT_THROW("unknown ResizeCoordinateTransformationMode");                      \
+    }                                                                                 \
+  }()
+
+namespace {
+template <typename T>
+IAllocatorUniquePtr<uint8_t> AllocateTyped(
+    const TempSpaceAllocateFunc& alloc,
+    size_t elements) {
+  return alloc(elements * sizeof(T));
+}
+
+template <typename T>
+T* GetTyped(IAllocatorUniquePtr<uint8_t>& bytes) {
+  return reinterpret_cast<T*>(bytes.get());
+}
+}  // namespace
+
+template <typename T>
+void ResizeTrilinearUpsample(
+    cudaStream_t stream,
+    int rank,
+    const UpsampleMode /*upsample_mode*/,
+    ResizeCoordinateTransformationMode coordinate_transform_mode,
+    gsl::span<const int64_t> /*input_shape*/,
+    gsl::span<const int64_t> /*output_shape*/,
+    int64_t batch_size, int64_t num_channels,
+    std::tuple<int64_t, int64_t, int64_t> inferred_input_dims,
+    std::tuple<int64_t, int64_t, int64_t> inferred_output_dims,
+    std::tuple<float, float, float> inferred_dim_rscales,
+    const TArray<fast_divmod>& output_div_pitches,
+    gsl::span<const float> roi_vals,
+    const std::optional<float>& extrapolation,
+    bool exclude_outside,
+    const TempSpaceAllocateFunc& allocate_temp_space,
+    const uint8_t* clip8_lookups,
+    const T* input_data,
+    T* output_data,
+    const size_t N) {
+  using AccumType = typename onnxruntime::AccumulateType<T>::type;
+
+  const bool use_extrapolation = extrapolation.has_value();
+  const float extrapolation_value = use_extrapolation ? *extrapolation : 0.f;
+
+  int64_t input_depth, input_height, input_width;
+  std::tie(input_depth, input_height, input_width) = inferred_input_dims;
+
+  int64_t output_depth, output_height, output_width;
+  std::tie(output_depth, output_height, output_width) = inferred_output_dims;
+
+  int blocksPerDimsMappingGrid =
+      narrow<int>(CeilDiv((output_depth + output_height + output_width), 32));
+
+  int blocksPerGrid = narrow<int>(CeilDiv(N, GridDim::maxThreadsPerBlock));
+
+  constexpr float support_value = antialias_constants::kSupportSize;
+  float z_scale, h_scale, w_scale;
+  std::tie(z_scale, h_scale, w_scale) = inferred_dim_rscales;
+
+  const auto& div_output_width = output_div_pitches[rank - 2];
+
+  SafeInt<int64_t> bounds_buffer_size = (SafeInt<int64_t>(output_depth) + output_height + output_width) * 2;
+  SafeInt<int64_t> out_of_bounds_buffer_size = (SafeInt<int64_t>(output_depth) + output_height + output_width);
+
+  auto bounds_buffer_ptr = AllocateTyped<int64_t>(allocate_temp_space, bounds_buffer_size);
+  auto out_of_bounds_buffer_ptr = AllocateTyped<int64_t>(allocate_temp_space, out_of_bounds_buffer_size);
+
+  int64_t* z_bounds_buffer = GetTyped<int64_t>(bounds_buffer_ptr);
+  int64_t* y_bounds_buffer = z_bounds_buffer + output_depth * 2;
+  int64_t* w_bounds_buffer = y_bounds_buffer + output_height * 2;
+
+  int64_t* z_outof_bounds_buffer = GetTyped<int64_t>(out_of_bounds_buffer_ptr);
+  int64_t* y_outof_bounds_buffer = z_outof_bounds_buffer + output_depth;
+  int64_t* w_outof_bounds_buffer = y_outof_bounds_buffer + output_height;
+
+  float z_scaled_support, h_scaled_support, w_scaled_support;
+  int32_t z_window_size, h_window_size, w_window_size;
+  const auto [z_buffer_size, y_buffer_size, w_buffer_size] = ComputeTrilinearScaleBufferSize(
+      output_depth, output_height, output_width,
+      z_scale, h_scale, w_scale, support_value,
+      z_scaled_support, h_scaled_support, w_scaled_support,
+      z_window_size, h_window_size, w_window_size);
+
+  const int64_t weighted_buffer_size = SafeInt<int64_t>(z_buffer_size) + y_buffer_size + w_buffer_size;
+
+  auto weighted_buffer_ptr = AllocateTyped<AccumType>(allocate_temp_space, weighted_buffer_size);
+  AccumType* z_weighted_buffer = GetTyped<AccumType>(weighted_buffer_ptr);
+  AccumType* y_weighted_buffer = z_weighted_buffer + z_buffer_size;
+  AccumType* w_weighted_buffer = y_weighted_buffer + y_buffer_size;
+
+  const auto h_w_interpolate_temp_buf_size = SafeInt<int64_t>(batch_size) * num_channels *
+                                             input_depth * input_height * output_width;
+  auto h_w_interpolate_temp_buffer_ptr = AllocateTyped<T>(allocate_temp_space,
+                                                          narrow<size_t>(h_w_interpolate_temp_buf_size));
+
+  const auto h_w_interpolate_result_buffer_size = SafeInt<int64_t>(batch_size) * num_channels *
+                                                  input_depth * output_height * output_width;
+  auto h_w_interpolate_result_buffer_ptr = AllocateTyped<T>(allocate_temp_space, h_w_interpolate_result_buffer_size);
+
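+  // The resize is separable: after the filter weights/bounds are computed on the GPU,
+  // level 1 interpolates along w into a (N, C, D, H_in, W_out) temp buffer, level 2 along h
+  // into a (N, C, D, H_out, W_out) buffer, and level 3 along d into the final output.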
+  // clang-format off
+  DISPATCH_ANTIALIAS_FILTER_SETUP(coordinate_transform_mode, [&]() {
+    _SetupTrilinerarUpsampleFilterAntiAlias<AccumType,
+                                            TriLinearFilter,
+                                            coord_t><<<blocksPerDimsMappingGrid, 32, 0, stream>>>(
+        inferred_input_dims,
+        inferred_output_dims,
+        inferred_dim_rscales,
+        std::make_tuple(roi_vals[rank - 3], roi_vals[rank - 2], roi_vals[rank - 1]),  // roi starts d, h, w
+        std::make_tuple(roi_vals[rank - 3 + rank], roi_vals[rank - 2 + rank],         // roi ends d, h, w
+                        roi_vals[rank - 1 + rank]),
+        std::make_tuple(z_scaled_support, h_scaled_support, w_scaled_support),
+        std::make_tuple(z_window_size, h_window_size, w_window_size),
+        exclude_outside,
+        GetTyped<int64_t>(bounds_buffer_ptr),
+        GetTyped<int64_t>(out_of_bounds_buffer_ptr),
+        std::make_tuple(z_weighted_buffer, y_weighted_buffer, w_weighted_buffer));
+  });
+
+  // clang-format on
+  const fast_divmod div_w_image(narrow<int>(num_channels * input_depth * input_height * output_width));
+  // clang-format off
+  _ComputeInterpolationAtLevel1<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+      num_channels * input_depth, input_height, input_width, input_height, output_width,
+      div_output_width,
+      div_w_image,
+      w_window_size,
+      clip8_lookups,
+      w_bounds_buffer,
+      std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer),
+      w_weighted_buffer, input_data,
+      GetTyped<T>(h_w_interpolate_temp_buffer_ptr),
+      narrow<int>(h_w_interpolate_temp_buf_size));
+
+  // clang-format on
+  const fast_divmod div_output_height{narrow<int>(output_height * output_width)};
+  const fast_divmod div_h_w_image(narrow<int>(num_channels * input_depth * output_height * output_width));
+  // clang-format off
+  _ComputeInterpolationAtLevel2<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+      num_channels * input_depth, input_height, output_width, output_height, output_width,
+      div_output_height,
+      div_output_width,
+      div_h_w_image,
+      h_window_size,
+      false, 0.f,  // No extrapolation
+      clip8_lookups,
+      y_bounds_buffer,
+      std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer),
+      y_weighted_buffer, GetTyped<T>(h_w_interpolate_temp_buffer_ptr),
+      GetTyped<T>(h_w_interpolate_result_buffer_ptr),
+      narrow<int>(h_w_interpolate_result_buffer_size));
+
+  // clang-format on
+  const fast_divmod div_z_h_w_image(narrow<int>(input_depth * output_height * output_width));
+  // clang-format off
+  _ComputeInterpolationAtLevel3<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+      input_depth, output_height, output_width,
+      output_depth, output_height, output_width,
+      div_output_height,
+      div_output_width,
+      div_z_h_w_image,
+      z_window_size,
+      use_extrapolation, extrapolation_value,
+      clip8_lookups,
+      z_bounds_buffer,
+      std::make_tuple(z_outof_bounds_buffer, y_outof_bounds_buffer, w_outof_bounds_buffer),
+      z_weighted_buffer, GetTyped<T>(h_w_interpolate_result_buffer_ptr),
+      output_data,
+      narrow<int>(N));
+  // clang-format on
+}
+
+template <class T>
+void ResizeBiLinearUpsample(cudaStream_t stream,
+                            int rank,
+                            const UpsampleMode /*upsample_mode*/,
+                            ResizeCoordinateTransformationMode coordinate_transform_mode,
+                            gsl::span<const int64_t> /*input_shape*/,
+                            gsl::span<const int64_t> /*output_shape*/,
+                            int64_t /*batch_size*/, int64_t num_channels,
+                            std::tuple<int64_t, int64_t, int64_t> inferred_input_dims,
+                            std::tuple<int64_t, int64_t, int64_t> inferred_output_dims,
+                            std::tuple<float, float, float> inferred_dim_rscales,
+                            const TArray<fast_divmod>& output_div_pitches,
+                            gsl::span<const float> roi_vals,
+                            const std::optional<float>& extrapolation,
+                            bool exclude_outside,
+                            const TempSpaceAllocateFunc& allocate_temp_space,
+                            const uint8_t* clip8_lookups,
+                            const T* input_data,
+                            T* output_data,
+                            const size_t N) {
+  using AccumType = typename onnxruntime::AccumulateType<T>::type;
+
+  const bool use_extrapolation = extrapolation.has_value();
+  const float extrapolation_value = use_extrapolation ? *extrapolation : 0.f;
+
+  int64_t input_depth, input_height, input_width;
+  std::tie(input_depth, input_height, input_width) = inferred_input_dims;
+
+  int64_t output_depth, output_height, output_width;
+  std::tie(output_depth, output_height, output_width) = inferred_output_dims;
+
+  int blocksPerDimsMappingGrid =
+      narrow<int>(CeilDiv((output_depth + output_height + output_width), 32));
+
+  // rank 2 or 4
+  const fast_divmod div_output_image = (rank > 2) ? output_div_pitches[rank - 4]
+                                                  : fast_divmod(gsl::narrow_cast<int>(N));
+  const fast_divmod& div_output_width = output_div_pitches[rank - 2];
+
+  constexpr float support_value = antialias_constants::kSupportSize;
+
+  float h_scale, w_scale;
+  std::tie(std::ignore, h_scale, w_scale) = inferred_dim_rscales;
+
+  int blocksPerGrid = narrow<int>(CeilDiv(N, GridDim::maxThreadsPerBlock));
+
+  SafeInt<int64_t> bounds_buffer_size = (SafeInt<int64_t>(output_height) + output_width) * 2;
+  SafeInt<int64_t> out_of_bounds_buffer_size = (SafeInt<int64_t>(output_height) + output_width);
+
+  float h_scaled_support, w_scaled_support;
+  int32_t h_window_size, w_window_size;
+  const auto [weighted_y_size, weighted_w_size] =
+      ComputeBilinearScaleBufferSize(output_height, output_width,
+                                     h_scale, w_scale, support_value,
+                                     h_scaled_support, w_scaled_support, h_window_size, w_window_size);
+
+  auto bounds_buffer_ptr = AllocateTyped<int64_t>(allocate_temp_space, bounds_buffer_size);
+  auto out_of_bounds_buffer_ptr = AllocateTyped<int64_t>(allocate_temp_space, out_of_bounds_buffer_size);
+
+  int64_t* y_bounds_buffer = GetTyped<int64_t>(bounds_buffer_ptr);
+  int64_t* w_bounds_buffer = y_bounds_buffer + output_height * 2;
+
+  int64_t* y_outof_bounds_buffer = GetTyped<int64_t>(out_of_bounds_buffer_ptr);
+  int64_t* w_outof_bounds_buffer = y_outof_bounds_buffer + output_height;
+
+  const int64_t weighted_buffer_size = SafeInt<int64_t>(weighted_y_size) + weighted_w_size;
+  auto weighted_buffer_ptr = AllocateTyped<AccumType>(allocate_temp_space, narrow<size_t>(weighted_buffer_size));
+
+  AccumType* y_weighted_buffer = GetTyped<AccumType>(weighted_buffer_ptr);
+  AccumType* w_weighted_buffer = y_weighted_buffer + weighted_y_size;
+
+  const auto temp_buf_size = num_channels * input_height * output_width;
+  auto image_temp_buffer = AllocateTyped<T>(allocate_temp_space, narrow<size_t>(temp_buf_size));
+
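+  // Two separable passes: level 1 interpolates along w into a (num_channels, H_in, W_out)
+  // temp buffer, level 2 interpolates along h (with optional extrapolation) into the output.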
+  // clang-format off
+  DISPATCH_ANTIALIAS_FILTER_SETUP(coordinate_transform_mode, [&]() {
+    // The tuples below hold (h, w) values
+
+    _SetupBilinearUpsampleFilterAntiAlias<AccumType,
+                                          BilinearFilter,
+                                          coord_t><<<blocksPerDimsMappingGrid, 32, 0, stream>>>(
+        std::make_tuple(input_height, input_width),
+        std::make_tuple(output_height, output_width),
+        std::make_tuple(h_scale, w_scale),
+        std::make_tuple(roi_vals[rank - 2], roi_vals[rank - 1]),                // roi starts h, w
+        std::make_tuple(roi_vals[rank - 2 + rank], roi_vals[rank - 1 + rank]),  // roi ends h, w
+        std::make_tuple(h_scaled_support, w_scaled_support),
+        std::make_tuple(h_window_size, w_window_size),
+        onnxruntime::antialias_constants::kCubicCoeffA, exclude_outside,
+        GetTyped<int64_t>(bounds_buffer_ptr),
+        GetTyped<int64_t>(out_of_bounds_buffer_ptr),
+        std::make_tuple(y_weighted_buffer, w_weighted_buffer));
+  });
+
+  // clang-format on
+  const fast_divmod div_step_image{narrow<int>(num_channels * input_height * output_width)};
+  // clang-format off
+  _ComputeInterpolationAtLevel1<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+      num_channels, input_height, input_width, input_height, output_width,
+      div_output_width,
+      div_step_image,
+      w_window_size,
+      clip8_lookups,
+      w_bounds_buffer,
+      std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer),
+      w_weighted_buffer, input_data, GetTyped<T>(image_temp_buffer),
+      narrow<int>(temp_buf_size));
+
+  // clang-format on
+  const fast_divmod div_output_height{narrow<int>(output_height * output_width)};
+  // clang-format off
+  _ComputeInterpolationAtLevel2<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+      num_channels, input_height, output_width, output_height, output_width,
+      div_output_height,
+      div_output_width,
+      div_output_image,
+      h_window_size,
+      use_extrapolation, extrapolation_value,
+      clip8_lookups,
+      y_bounds_buffer,
+      std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer),
+      y_weighted_buffer, GetTyped<T>(image_temp_buffer), output_data,
+      narrow<int>(N));
+
+  // clang-format on
+}
+
+template <typename T>
+void ResizeBicubicUpsample(cudaStream_t stream,
+                           int rank,
+                           const UpsampleMode /*upsample_mode*/,
+                           ResizeCoordinateTransformationMode coordinate_transform_mode,
+                           gsl::span<const int64_t> /*input_shape*/,
+                           gsl::span<const int64_t> /*output_shape*/,
+                           int64_t batch_size, int64_t num_channels,
+                           std::tuple<int64_t, int64_t, int64_t> inferred_input_dims,
+                           std::tuple<int64_t, int64_t, int64_t> inferred_output_dims,
+                           std::tuple<float, float, float> inferred_dim_rscales,
+                           // const TArray<int64_t>& input_strides,
+                           const TArray<fast_divmod>& output_div_pitches,
+                           gsl::span<const float> roi_vals,
+                           const std::optional<float>& extrapolation,
+                           bool exclude_outside,
+                           const TempSpaceAllocateFunc& allocate_temp_space,
+                           const uint8_t* clip8_lookups,
+                           const T* input_data,
+                           T* output_data,
+                           const size_t N) {
+  using AccumType = typename onnxruntime::AccumulateType<T>::type;
+
+  const bool use_extrapolation = extrapolation.has_value();
+  const float extrapolation_value = use_extrapolation ? *extrapolation : 0.f;
+
+  int blocksPerGrid = narrow<int>(CeilDiv(N, GridDim::maxThreadsPerBlock));
+  const fast_divmod div_output_image = (rank > 2) ? output_div_pitches[rank - 4]
+                                                  : fast_divmod(gsl::narrow_cast<int>(N));
+  const fast_divmod& div_output_width = output_div_pitches[rank - 2];
+
+  constexpr float support_value = antialias_constants::kBiCubicSupportSize;
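+  // Bicubic uses a wider filter support than bilinear; otherwise the same separable
+  // two-pass scheme as ResizeBiLinearUpsample applies, with BiCubicFilter weights.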
+
+  int64_t input_depth, input_height, input_width;
+  std::tie(input_depth, input_height, input_width) = inferred_input_dims;
+
+  int64_t output_depth, output_height, output_width;
+  std::tie(output_depth, output_height, output_width) = inferred_output_dims;
+
+  int blocksPerDimsMappingGrid =
+      narrow<int>(CeilDiv((output_depth + output_height + output_width), 32));
+
+  float h_scale, w_scale;
+  std::tie(std::ignore, h_scale, w_scale) = inferred_dim_rscales;
+
+  SafeInt<int64_t> bounds_buffer_size = (SafeInt<int64_t>(output_height) + output_width) * 2;
+  SafeInt<int64_t> out_of_bounds_buffer_size = (SafeInt<int64_t>(output_height) + output_width);
+
+  float h_scaled_support, w_scaled_support;
+  int32_t h_window_size, w_window_size;
+  const auto [weighted_y_size, weighted_w_size] =
+      ComputeBilinearScaleBufferSize(output_height, output_width,
+                                     h_scale, w_scale, support_value,
+                                     h_scaled_support, w_scaled_support, h_window_size, w_window_size);
+
+  auto bounds_buffer_ptr = AllocateTyped<int64_t>(allocate_temp_space, bounds_buffer_size);
+  auto out_of_bounds_buffer_ptr = AllocateTyped<int64_t>(allocate_temp_space, out_of_bounds_buffer_size);
+
+  int64_t* y_bounds_buffer = GetTyped<int64_t>(bounds_buffer_ptr);
+  int64_t* w_bounds_buffer = y_bounds_buffer + output_height * 2;
+
+  int64_t* y_outof_bounds_buffer = GetTyped<int64_t>(out_of_bounds_buffer_ptr);
+  int64_t* w_outof_bounds_buffer = y_outof_bounds_buffer + output_height;
+
+  const int64_t weighted_buffer_size = SafeInt<int64_t>(weighted_y_size) +
+                                       weighted_w_size;
+  auto weighted_buffer_ptr = AllocateTyped<AccumType>(allocate_temp_space, weighted_buffer_size);
+
+  AccumType* y_weighted_buffer = GetTyped<AccumType>(weighted_buffer_ptr);
+  AccumType* w_weighted_buffer = y_weighted_buffer + weighted_y_size;
+
+  const auto temp_buf_size = SafeInt<int64_t>(batch_size) * num_channels * input_height * output_width;
+  auto image_temp_buffer = AllocateTyped<T>(allocate_temp_space, narrow<size_t>(temp_buf_size));
+
+  // clang-format off
+  DISPATCH_ANTIALIAS_FILTER_SETUP(coordinate_transform_mode, [&]() {
+    _SetupBilinearUpsampleFilterAntiAlias<AccumType,
+                                          BiCubicFilter,
+                                          coord_t><<<blocksPerDimsMappingGrid, 32, 0, stream>>>(
+        std::make_tuple(input_height, input_width),
+        std::make_tuple(output_height, output_width),
+        std::make_tuple(h_scale, w_scale),
+        std::make_tuple(roi_vals[rank - 2], roi_vals[rank - 1]),                // roi starts h, w
+        std::make_tuple(roi_vals[rank - 2 + rank], roi_vals[rank - 1 + rank]),  // roi ends h, w
+        std::make_tuple(h_scaled_support, w_scaled_support),
+        std::make_tuple(h_window_size, w_window_size),
+        onnxruntime::antialias_constants::kCubicCoeffA, exclude_outside,
+        GetTyped<int64_t>(bounds_buffer_ptr),
+        GetTyped<int64_t>(out_of_bounds_buffer_ptr),
+        std::make_tuple(y_weighted_buffer, w_weighted_buffer));
+  });
+  // clang-format on
+  const fast_divmod div_step_image(narrow<int>(num_channels * input_height * output_width));
+  // clang-format off
+  _ComputeInterpolationAtLevel1<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+      num_channels, input_height, input_width, input_height, output_width,
+      div_output_width,
+      div_step_image,
+      w_window_size,
+      clip8_lookups,
+      w_bounds_buffer,
+      std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer),
+      w_weighted_buffer, input_data, GetTyped<T>(image_temp_buffer),
+      narrow<int>(temp_buf_size));
+  // clang-format on
+
+  const fast_divmod div_output_height{narrow<int>(output_height * output_width)};
+  // clang-format off
+  _ComputeInterpolationAtLevel2<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+      num_channels, input_height, output_width, output_height, output_width,
+      div_output_height,
+      div_output_width,
+      div_output_image,
+      h_window_size,
+      use_extrapolation, extrapolation_value,
+      clip8_lookups,
+      y_bounds_buffer,
+      std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer),
+      y_weighted_buffer, GetTyped<T>(image_temp_buffer), output_data,
+      narrow<int>(N));
+  // clang-format on
+}
+
+template <class T>
+void ResizeAntiAliasImpl(
+    cudaStream_t stream,
+    int rank,
+    const UpsampleMode upsample_mode,
+    ResizeCoordinateTransformationMode coordinate_transform_mode,
+    gsl::span<const int64_t> input_shape,
+    gsl::span<const int64_t> output_shape,
+    int64_t batch_size, int64_t num_channels,
+    std::tuple<int64_t, int64_t, int64_t> inferred_input_dims,
+    std::tuple<int64_t, int64_t, int64_t> inferred_output_dims,
+    std::tuple<float, float, float> inferred_dim_rscales,
+    const TArray<fast_divmod>& output_div_pitches,
+    gsl::span<const float> roi_vals,
+    const std::optional<float>& extrapolation,
+    bool exclude_outside,
+    TempSpaceAllocateFunc allocate_temp_space,
+    const uint8_t* clip8_lookups,
+    const T* input_data,
+    T* output_data,
+    const size_t N) {
+  // We support a special case of bilinear or bicubic if the input data is 2-D, or 4-D with the
+  // outer 2 scales being 1.0. The outer scale values have been validated by the time execution reaches this.
+  const bool is_2D = (rank == 2 || rank == 4);
+
+  // We support a special case of trilinear if the input data is 3-D, or 5-D with the outer 2 scales
+  // being 1.0. The outer scale values have been validated by the time execution reaches this.
+  const bool is_3D = (rank == 3 || rank == 5);
+
+  // Should not hit this as we have already validated input rank/scales and we provide verbose error messages
+  // to the user.
+  ORT_ENFORCE(is_2D || is_3D, "Only bilinear/trilinear and bicubic modes are supported in Resize anti-alias mode");
+
+  switch (upsample_mode) {
+    case UpsampleMode::LINEAR: {
+      if (is_2D) {
+        ResizeBiLinearUpsample<T>(stream, rank, upsample_mode, coordinate_transform_mode,
+                                  input_shape, output_shape, batch_size, num_channels,
+                                  inferred_input_dims, inferred_output_dims, inferred_dim_rscales,
+                                  output_div_pitches, roi_vals, extrapolation, exclude_outside,
+                                  allocate_temp_space, clip8_lookups, input_data, output_data, N);
+      } else if (is_3D) {
+        ResizeTrilinearUpsample<T>(stream, rank, upsample_mode, coordinate_transform_mode,
+                                   input_shape, output_shape, batch_size, num_channels,
+                                   inferred_input_dims, inferred_output_dims, inferred_dim_rscales,
+                                   output_div_pitches, roi_vals, extrapolation, exclude_outside,
+                                   allocate_temp_space, clip8_lookups, input_data, output_data, N);
+      } else {
+        ORT_NOT_IMPLEMENTED("Resize supports only 2-D or 3-D in LINEAR mode.");
+      }
+    } break;
+    case UpsampleMode::CUBIC: {
+      if (is_2D) {
+        ResizeBicubicUpsample<T>(stream, rank, upsample_mode, coordinate_transform_mode,
+                                 input_shape, output_shape, batch_size, num_channels,
+                                 inferred_input_dims, inferred_output_dims, inferred_dim_rscales,
+                                 output_div_pitches, roi_vals, extrapolation, exclude_outside,
+                                 allocate_temp_space, clip8_lookups, input_data, output_data, N);
+      } else {
+        ORT_NOT_IMPLEMENTED("Resize supports only 2-D in CUBIC mode.");
+      }
+    } break;
+    default:
+      ORT_NOT_IMPLEMENTED("Only bilinear/trilinear and bicubic modes are supported in Resize anti-alias mode");
+      break;
+  }
+}
+
+#define SPECIALIZED_ANTIALIAS_IMPL(T)                               \
+  template void ResizeAntiAliasImpl<T>(                             \
+      cudaStream_t stream,                                          \
+      int rank,                                                     \
+      const UpsampleMode upsample_mode,                             \
+      ResizeCoordinateTransformationMode coordinate_transform_mode, \
+      gsl::span<const int64_t> input_shape,                         \
+      gsl::span<const int64_t> output_shape,                        \
+      int64_t batch_size, int64_t num_channels,                     \
+      std::tuple<int64_t, int64_t, int64_t> inferred_input_dims,    \
+      std::tuple<int64_t, int64_t, int64_t> inferred_output_dims,   \
+      std::tuple<float, float, float> inferred_dim_rscales,         \
+      const TArray<fast_divmod>& output_div_pitches,                \
+      gsl::span<const float> roi_vals,                              \
+      const std::optional<float>& extrapolation_value,              \
+      bool exclude_outside,                                         \
+      TempSpaceAllocateFunc allocate_temp_space,                    \
+      const uint8_t* clip8_lookups,                                 \
+      const T* input_data,                                          \
+      T* output_data,                                               \
+      const size_t N);
+
+SPECIALIZED_ANTIALIAS_IMPL(float)
+SPECIALIZED_ANTIALIAS_IMPL(double)
+SPECIALIZED_ANTIALIAS_IMPL(half)
+SPECIALIZED_ANTIALIAS_IMPL(int32_t)
+SPECIALIZED_ANTIALIAS_IMPL(uint8_t)
+
+}  // namespace cuda
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu
index 1a94c7705e913..e788f24052985 100644
--- a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu
+++ b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu
@@ -12,7 +12,7 @@ using onnxruntime::ResizeNearestMode;
 using onnxruntime::UpsampleMode;
 
 struct NearestPixel_SIMPLE {
-  __device__ __forceinline__ int operator() (float x_original, bool is_down_sampling) const {
+  __device__ __forceinline__ int operator()(float x_original, bool is_down_sampling) const {
     if (is_down_sampling) {
       return static_cast<int>(_Ceil(x_original));
     }
@@ -21,7 +21,7 @@ struct NearestPixel_SIMPLE {
 };
 
 struct NearestPixel_ROUND_PREFER_FLOOR {
-  __device__ __forceinline__ int operator() (float x_original, bool) const {
+  __device__ __forceinline__ int operator()(float x_original, bool) const {
     if (x_original == static_cast<int>(x_original) + 0.5f) {
       return static_cast<int>(_Floor(x_original));
     }
@@ -30,62 +30,23 @@ struct NearestPixel_ROUND_PREFER_FLOOR {
 };
 
 struct NearestPixel_ROUND_PREFER_CEIL {
-  __device__ __forceinline__ int operator() (float x_original, bool) const {
+  __device__ __forceinline__ int operator()(float x_original, bool) const {
     return static_cast<int>(roundf(x_original));
   }
 };
 
 struct NearestPixel_FLOOR {
-  __device__ __forceinline__ int operator() (float x_original, bool) const {
+  __device__ __forceinline__ int operator()(float x_original, bool) const {
     return static_cast<int>(_Floor(x_original));
   }
 };
 
 struct NearestPixel_CEIL {
-  __device__ __forceinline__ int operator() (float x_original, bool) const {
+  __device__ __forceinline__ int operator()(float x_original, bool) const {
     return static_cast<int>(_Ceil(x_original));
   }
 };
 
-struct TransformCoordinate_ASYMMETRIC {
-  __device__ __forceinline__ float operator() (float x_resized, float x_scale, float, float, float, float) const {
-    return x_resized / x_scale;
-  }
-};
-
-struct TransformCoordinate_HALF_PIXEL {
-  __device__ __forceinline__ float operator() (float x_resized, float x_scale, float, float, float, float) const {
-    return ((x_resized + 0.5f) / x_scale) - 0.5f;
-  }
-};
-
-struct TransformCoordinate_PYTORCH_HALF_PIXEL {
-  __device__ __forceinline__ float operator() (float x_resized, float x_scale, float length_resized, float, float, float) const {
-    return length_resized > 1 ? (x_resized + 0.5f) / x_scale - 0.5f : 0.0f;
-  }
-};
-
-struct TransformCoordinate_TF_HALF_PIXEL_FOR_NN {
-  __device__ __forceinline__ float operator() (float x_resized, float x_scale, float, float, float, float) const {
-    return (x_resized + 0.5f) / x_scale;
-  }
-};
-
-struct TransformCoordinate_ALIGN_CORNERS {
-  __device__ __forceinline__ float operator() (float x_resized, float, float length_resized, float length_original, float, float) const {
-    return length_resized == 1 ? 0 : x_resized * (length_original - 1) / (length_resized - 1);
-  }
-};
-
-struct TransformCoordinate_TF_CROP_AND_RESIZE {
-  __device__ __forceinline__ float operator() (float x_resized, float, float length_resized, float length_original, float roi_start, float roi_end) const {
-    auto orig = length_resized > 1
-      ? roi_start * (length_original - 1) + (x_resized * (roi_end - roi_start) * (length_original - 1)) / (length_resized - 1)
-      : 0.5 * (roi_start + roi_end) * (length_original - 1);
-    return static_cast<float>(orig);
-  }
-};
-
 #define CASE_TYPE_USING_HINT(enum_type, type, HINT, ...) \
   case enum_type: {                                      \
     using HINT = type;                                   \
@@ -95,20 +56,24 @@ struct TransformCoordinate_TF_CROP_AND_RESIZE {
 #define CASE_TYPE_COORD(enum_type, type, ...) \
   CASE_TYPE_USING_HINT(enum_type, type, coord_t, __VA_ARGS__)
 
-#define DISPATCH_RESIZE_COORDINATE_TRANSFORMATION_MODE(TYPE, ...)                                                                      \
-  [&] {                                                                                                                                \
-    const auto& the_type = TYPE;                                                                                                       \
-    /* don't use TYPE again in case it is an expensive or side-effect op */                                                            \
-    switch (the_type) {                                                                                                                \
-      CASE_TYPE_COORD(ResizeCoordinateTransformationMode::HALF_PIXEL,           TransformCoordinate_HALF_PIXEL, __VA_ARGS__)           \
-      CASE_TYPE_COORD(ResizeCoordinateTransformationMode::ASYMMETRIC,           TransformCoordinate_ASYMMETRIC, __VA_ARGS__)           \
-      CASE_TYPE_COORD(ResizeCoordinateTransformationMode::PYTORCH_HALF_PIXEL,   TransformCoordinate_PYTORCH_HALF_PIXEL, __VA_ARGS__)   \
-      CASE_TYPE_COORD(ResizeCoordinateTransformationMode::ALIGN_CORNERS,        TransformCoordinate_ALIGN_CORNERS, __VA_ARGS__)        \
-      CASE_TYPE_COORD(ResizeCoordinateTransformationMode::TF_HALF_PIXEL_FOR_NN, TransformCoordinate_TF_HALF_PIXEL_FOR_NN, __VA_ARGS__) \
-      CASE_TYPE_COORD(ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE,   TransformCoordinate_TF_CROP_AND_RESIZE, __VA_ARGS__)   \
-      default:                                                                                                                         \
-        ORT_THROW("unknown ResizeCoordinateTransformationMode");                                                                       \
-    }                                                                                                                                  \
+#define DISPATCH_RESIZE_COORDINATE_TRANSFORMATION_MODE(TYPE, ...)                                                  \
+  [&] {                                                                                                            \
+    const auto& the_type = TYPE;                                                                                   \
+    /* don't use TYPE again in case it is an expensive or side-effect op */                                        \
+    switch (the_type) {                                                                                            \
+      CASE_TYPE_COORD(ResizeCoordinateTransformationMode::HALF_PIXEL, TransformCoordinate_HALF_PIXEL, __VA_ARGS__) \
+      CASE_TYPE_COORD(ResizeCoordinateTransformationMode::ASYMMETRIC, TransformCoordinate_ASYMMETRIC, __VA_ARGS__) \
+      CASE_TYPE_COORD(ResizeCoordinateTransformationMode::PYTORCH_HALF_PIXEL,                                      \
+                      TransformCoordinate_PYTORCH_HALF_PIXEL, __VA_ARGS__)                                         \
+      CASE_TYPE_COORD(ResizeCoordinateTransformationMode::ALIGN_CORNERS,                                           \
+                      TransformCoordinate_ALIGN_CORNERS, __VA_ARGS__)                                              \
+      CASE_TYPE_COORD(ResizeCoordinateTransformationMode::TF_HALF_PIXEL_FOR_NN,                                    \
+                      TransformCoordinate_TF_HALF_PIXEL_FOR_NN, __VA_ARGS__)                                       \
+      CASE_TYPE_COORD(ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE,                                      \
+                      TransformCoordinate_TF_CROP_AND_RESIZE, __VA_ARGS__)                                         \
+      default:                                                                                                     \
+        ORT_THROW("unknown ResizeCoordinateTransformationMode");                                                   \
+    }                                                                                                              \
   }()
 
 #define CASE_TYPE_NEAREST(enum_type, type, ...) \
@@ -119,11 +84,11 @@ struct TransformCoordinate_TF_CROP_AND_RESIZE {
     const auto& the_type = TYPE;                                                                             \
     /* don't use TYPE again in case it is an expensive or side-effect op */                                  \
     switch (the_type) {                                                                                      \
-      CASE_TYPE_NEAREST(ResizeNearestMode::SIMPLE,             NearestPixel_SIMPLE, __VA_ARGS__)             \
+      CASE_TYPE_NEAREST(ResizeNearestMode::SIMPLE, NearestPixel_SIMPLE, __VA_ARGS__)                         \
       CASE_TYPE_NEAREST(ResizeNearestMode::ROUND_PREFER_FLOOR, NearestPixel_ROUND_PREFER_FLOOR, __VA_ARGS__) \
-      CASE_TYPE_NEAREST(ResizeNearestMode::ROUND_PREFER_CEIL,  NearestPixel_ROUND_PREFER_CEIL, __VA_ARGS__)  \
-      CASE_TYPE_NEAREST(ResizeNearestMode::FLOOR,              NearestPixel_FLOOR, __VA_ARGS__)              \
-      CASE_TYPE_NEAREST(ResizeNearestMode::CEIL,               NearestPixel_CEIL, __VA_ARGS__)               \
+      CASE_TYPE_NEAREST(ResizeNearestMode::ROUND_PREFER_CEIL, NearestPixel_ROUND_PREFER_CEIL, __VA_ARGS__)   \
+      CASE_TYPE_NEAREST(ResizeNearestMode::FLOOR, NearestPixel_FLOOR, __VA_ARGS__)                           \
+      CASE_TYPE_NEAREST(ResizeNearestMode::CEIL, NearestPixel_CEIL, __VA_ARGS__)                             \
       default:                                                                                               \
         ORT_THROW("unknown ResizeNearestMode");                                                              \
     }                                                                                                        \
@@ -151,10 +116,12 @@ __global__ void _ResizeNearestMappingKernel2D(
 
     // only apply co-ordinate transformation if scale != 1.0
     if (scales_height == 1.0f) {
-        dims_mapping[id].extrapolate_ = 0;
+      dims_mapping[id].extrapolate_ = 0;
     } else {
-      float orig_coord = transform_coordinate(static_cast<float>(dim), scales_height, static_cast<float>(output_height),
-                                              static_cast<float>(input_height), roi_start_height, roi_end_height);
+      float orig_coord = transform_coordinate(static_cast<float>(dim), scales_height,
+                                              static_cast<float>(output_height),
+                                              static_cast<float>(input_height),
+                                              roi_start_height, roi_end_height);
       dims_mapping[id].extrapolate_ = static_cast<int>(
           extrapolation_enabled && (orig_coord < 0.f || orig_coord > static_cast<float>(input_height - 1)));
       dim = calc_nearest_pixel(orig_coord, scales_height < 1);
@@ -210,9 +177,12 @@ __global__ void _ResizeNearestMappingKernel(
       if (scales[axis] == 1.0f) {
         dims_mapping[id].extrapolate_ = 0;
       } else {
-        float orig_coord = transform_coordinate(static_cast<float>(dim), scales[axis], static_cast<float>(output_shape[axis]),
+        float orig_coord = transform_coordinate(static_cast<float>(dim), scales[axis],
+                                                static_cast<float>(output_shape[axis]),
                                                 static_cast<float>(input_shape[axis]), roi[axis], roi[axis + rank]);
-        dims_mapping[id].extrapolate_ = static_cast<int>(extrapolation_enabled && (orig_coord < 0.f || orig_coord > static_cast<float>(input_shape[axis] - 1)));
+        dims_mapping[id].extrapolate_ = static_cast<int>(extrapolation_enabled &&
+                                                         (orig_coord < 0.f ||
+                                                          orig_coord > static_cast<float>(input_shape[axis] - 1)));
         dim = calc_nearest_pixel(orig_coord, scales[axis] < 1);
         if (dim >= input_shape[axis]) dim = input_shape[axis] - 1;
         if (dim < 0) dim = 0;
@@ -293,21 +263,27 @@ __global__ void _ResizeBilinearCoordinateMapping(
     LinearMappingInfo* dims_mapping) {
   CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, SumHW);
   if (id < output_height) {  //  y = id
-    float input_y = scale_height == 1 ? static_cast<float>(id) :
-                                        transform_coordinate(static_cast<float>(id), scale_height,
-                                        static_cast<float>(output_height), static_cast<float>(input_height),
-                                        roi_height_start, roi_height_end);
-    dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_y < 0 || input_y > static_cast<float>(input_height - 1)));
+    float input_y = scale_height == 1 ? static_cast<float>(id)
+                                      : transform_coordinate(static_cast<float>(id), scale_height,
+                                                             static_cast<float>(output_height),
+                                                             static_cast<float>(input_height),
+                                                             roi_height_start, roi_height_end);
+    dims_mapping[id].extrapolate_ = static_cast<int>((extrapolation_enabled &&
+                                                      (input_y < 0 ||
+                                                       input_y > static_cast<float>(input_height - 1))));
     input_y = max(0.0f, min(input_y, static_cast<float>(input_height - 1)));
     int y_int = static_cast<int>(input_y);
     dims_mapping[id].origin_ = y_int;
     dims_mapping[id].weight_ = (y_int >= input_height - 1) ? 0.5f : input_y - y_int;
-  } else {  //x = id - output_height
-    float input_x = scale_width == 1 ? static_cast<float>(id - output_height) :
-                                       transform_coordinate(static_cast<float>(id - output_height), scale_width,
-                                       static_cast<float>(output_width), static_cast<float>(input_width),
-                                       roi_width_start, roi_width_end);
-    dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_x < 0 || input_x > static_cast<float>(input_width - 1)));
+  } else {  // x = id - output_height
+    float input_x = scale_width == 1 ? static_cast<float>(id - output_height)
+                                     : transform_coordinate(static_cast<float>(id - output_height),
+                                                            scale_width, static_cast<float>(output_width),
+                                                            static_cast<float>(input_width), roi_width_start,
+                                                            roi_width_end);
+    dims_mapping[id].extrapolate_ = static_cast<int>((extrapolation_enabled &&
+                                                      (input_x < 0 ||
+                                                       input_x > static_cast<float>(input_width - 1))));
     input_x = max(0.0f, min(input_x, static_cast<float>(input_width - 1)));
     int x_int = static_cast<int>(input_x);
     dims_mapping[id].origin_ = x_int;
@@ -371,32 +347,40 @@ __global__ void _ResizeTrilinearCoordinateMapping(
     LinearMappingInfo* dims_mapping) {
   CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, SumDHW);
   if (id < output_depth) {  //  z = id
-    float input_z = scale_depth == 1 ? static_cast<float>(id)  :
-                                       transform_coordinate(static_cast<float>(id), scale_depth,
-                                       static_cast<float>(output_depth), static_cast<float>(input_depth),
-                                       roi_depth_start, roi_depth_end);
-    dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_z < 0 || input_z > static_cast<float>(input_depth - 1)));
+    float input_z = scale_depth == 1 ? static_cast<float>(id)
+                                     : transform_coordinate(static_cast<float>(id), scale_depth,
+                                                            static_cast<float>(output_depth),
+                                                            static_cast<float>(input_depth),
+                                                            roi_depth_start, roi_depth_end);
+    dims_mapping[id].extrapolate_ = static_cast<int>((extrapolation_enabled &&
+                                                      (input_z < 0 ||
+                                                       input_z > static_cast<float>(input_depth - 1))));
     input_z = max(0.0f, min(input_z, static_cast<float>(input_depth - 1)));
     int z_int = static_cast<int>(input_z);
     dims_mapping[id].origin_ = z_int;
     dims_mapping[id].weight_ = (z_int >= input_depth - 1) ? 0.5f : input_z - z_int;
   } else if (id >= output_depth && id < (output_depth + output_height)) {  //  y = id - output_depth
-    float input_y = scale_height == 1 ? static_cast<float>(id - output_depth) :
-                                        transform_coordinate(static_cast<float>(id - output_depth), scale_height,
-                                        static_cast<float>(output_height), static_cast<float>(input_height),
-                                        roi_height_start, roi_height_end);
-
-    dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_y < 0 || input_y > static_cast<float>(input_height - 1)));
+    float input_y = scale_height == 1 ? static_cast<float>(id - output_depth)
+                                      : transform_coordinate(static_cast<float>(id - output_depth),
+                                                             scale_height, static_cast<float>(output_height),
+                                                             static_cast<float>(input_height),
+                                                             roi_height_start, roi_height_end);
+
+    dims_mapping[id].extrapolate_ = static_cast<int>((extrapolation_enabled &&
+                                                      (input_y < 0 ||
+                                                       input_y > static_cast<float>(input_height - 1))));
     input_y = max(0.0f, min(input_y, static_cast<float>(input_height - 1)));
     int y_int = static_cast<int>(input_y);
     dims_mapping[id].origin_ = y_int;
     dims_mapping[id].weight_ = (y_int >= input_height - 1) ? 0.5f : input_y - y_int;
-  } else {  //x = id - output_depth - output_height
-    float input_x = scale_width == 1 ? static_cast<float>(id - output_depth - output_height) :
-                                       transform_coordinate(static_cast<float>(id - output_depth - output_height), scale_width,
-                                       static_cast<float>(output_width), static_cast<float>(input_width),
-                                       roi_width_start, roi_width_end);
-    dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_x < 0 || input_x > static_cast<float>(input_width - 1)));
+  } else {  // x = id - output_depth - output_height
+    float input_x = scale_width == 1 ? static_cast<float>(id - output_depth - output_height)
+                                     : transform_coordinate(static_cast<float>(id - output_depth - output_height),
+                                                            scale_width, static_cast<float>(output_width),
+                                                            static_cast<float>(input_width),
+                                                            roi_width_start, roi_width_end);
+    dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_x < 0 ||
+                                                                    input_x > static_cast<float>(input_width - 1)));
     input_x = max(0.0f, min(input_x, static_cast<float>(input_width - 1)));
     int x_int = static_cast<int>(input_x);
     dims_mapping[id].origin_ = x_int;
@@ -513,21 +497,33 @@ __global__ void _ResizeCubicCoordinateMapping(
   int max_input_coord = static_cast<int>(is_y_axis ? input_height : input_width);
 
   float scale = is_y_axis ? scale_height : scale_width;
-  float input_coordinat = scale == 1 ? (is_y_axis ? id : id - output_height) :
-      transform_coordinate(
-      static_cast<float>(is_y_axis ? id : id - output_height),
-      scale,
-      static_cast<float>(is_y_axis ? output_height : output_width),
-      static_cast<float>(max_input_coord),
-      (is_y_axis ? roi_height_start : roi_width_start),
-      (is_y_axis ? roi_height_end : roi_width_end));
+  float input_coordinat = scale == 1 ? (is_y_axis ? id : id - output_height)
+                                     : transform_coordinate(
+                                           static_cast<float>(is_y_axis ? id : id - output_height),
+                                           scale,
+                                           static_cast<float>(is_y_axis ? output_height : output_width),
+                                           static_cast<float>(max_input_coord),
+                                           (is_y_axis ? roi_height_start : roi_width_start),
+                                           (is_y_axis ? roi_height_end : roi_width_end));
   int coord_int = static_cast<int>(_Floor(input_coordinat));
   float s_coord = abs(input_coordinat - coord_int);
   float coeff_sum = 1.0f;
-  float coeff_0 = static_cast<float>(((cubic_coeff_a * (s_coord + 1) - 5 * cubic_coeff_a) * (s_coord + 1) + 8 * cubic_coeff_a) * (s_coord + 1) - 4 * cubic_coeff_a);
-  float coeff_1 = static_cast<float>(((cubic_coeff_a + 2) * s_coord - (cubic_coeff_a + 3)) * s_coord * s_coord + 1);
-  float coeff_2 = static_cast<float>(((cubic_coeff_a + 2) * (1 - s_coord) - (cubic_coeff_a + 3)) * (1 - s_coord) * (1 - s_coord) + 1);
-  float coeff_3 = static_cast<float>(((cubic_coeff_a * (2 - s_coord) - 5 * cubic_coeff_a) * (2 - s_coord) + 8 * cubic_coeff_a) * (2 - s_coord) - 4 * cubic_coeff_a);
+  float coeff_0 = static_cast<float>(((cubic_coeff_a * (s_coord + 1) - 5 * cubic_coeff_a) *
+                                          (s_coord + 1) +
+                                      8 * cubic_coeff_a) *
+                                         (s_coord + 1) -
+                                     4 * cubic_coeff_a);
+  float coeff_1 = static_cast<float>(((cubic_coeff_a + 2) * s_coord - (cubic_coeff_a + 3)) *
+                                         s_coord * s_coord +
+                                     1);
+  float coeff_2 = static_cast<float>(((cubic_coeff_a + 2) * (1 - s_coord) - (cubic_coeff_a + 3)) *
+                                         (1 - s_coord) * (1 - s_coord) +
+                                     1);
+  float coeff_3 = static_cast<float>(((cubic_coeff_a * (2 - s_coord) - 5 * cubic_coeff_a) *
+                                          (2 - s_coord) +
+                                      8 * cubic_coeff_a) *
+                                         (2 - s_coord) -
+                                     4 * cubic_coeff_a);
   if (exclude_outside) {
     coeff_0 = (coord_int - 1 < 0 || coord_int - 1 >= max_input_coord) ? 0.0 : coeff_0;
     coeff_1 = (coord_int + 0 < 0 || coord_int + 0 >= max_input_coord) ? 0.0 : coeff_1;
@@ -540,7 +536,8 @@ __global__ void _ResizeCubicCoordinateMapping(
   dm.coeff1_ = coeff_1 / coeff_sum;
   dm.coeff2_ = coeff_2 / coeff_sum;
   dm.coeff3_ = coeff_3 / coeff_sum;
-  dm.extrapolate_ = (int)(extrapolation_enabled && (input_coordinat < 0 || input_coordinat > static_cast<float>(max_input_coord - 1)));
+  dm.extrapolate_ = (int)(extrapolation_enabled && (input_coordinat < 0 ||
+                                                    input_coordinat > static_cast<float>(max_input_coord - 1)));
 }
 
 template <typename T>
@@ -569,21 +566,30 @@ __global__ void _ResizeBiCubicKernel(
   int x_int = x_info.origin_;
   int y_int = y_info.origin_;
   const T* image = input_data + input_index;
-  output_data[id] = y_info.coeff0_ * CubicInterpolationRowwise(image, x_int, y_int - 1, input_height, input_width, w0, w1, w2, w3) +
-                    y_info.coeff1_ * CubicInterpolationRowwise(image, x_int, y_int, input_height, input_width, w0, w1, w2, w3) +
-                    y_info.coeff2_ * CubicInterpolationRowwise(image, x_int, y_int + 1, input_height, input_width, w0, w1, w2, w3) +
-                    y_info.coeff3_ * CubicInterpolationRowwise(image, x_int, y_int + 2, input_height, input_width, w0, w1, w2, w3);
+  output_data[id] = y_info.coeff0_ *
+                        CubicInterpolationRowwise(image, x_int, y_int - 1, input_height, input_width, w0, w1, w2, w3) +
+                    y_info.coeff1_ *
+                        CubicInterpolationRowwise(image, x_int, y_int, input_height, input_width, w0, w1, w2, w3) +
+                    y_info.coeff2_ *
+                        CubicInterpolationRowwise(image, x_int, y_int + 1, input_height, input_width, w0, w1, w2, w3) +
+                    y_info.coeff3_ *
+                        CubicInterpolationRowwise(image, x_int, y_int + 2, input_height, input_width, w0, w1, w2, w3);
 }
 
 size_t CalcResizeBufferSize(const onnxruntime::UpsampleMode upsample_mode,
                             const gsl::span<const int64_t>& output_dims) {
   switch (upsample_mode) {
     case UpsampleMode::NN:
-      return sizeof(int64_t) * output_dims.size() + sizeof(NearestMappingInfo) * static_cast<size_t>(std::accumulate(output_dims.begin(), output_dims.end(), (int64_t)0));
+      return sizeof(int64_t) * output_dims.size() +
+             sizeof(NearestMappingInfo) *
+                 static_cast<size_t>(std::accumulate(output_dims.begin(),
+                                                     output_dims.end(), (int64_t)0));
     case UpsampleMode::LINEAR:
-      return sizeof(LinearMappingInfo) * static_cast<size_t>(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 2, (int64_t)0));
+      return sizeof(LinearMappingInfo) *
+             static_cast<size_t>(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 2, (int64_t)0));
     case UpsampleMode::CUBIC:
-      return sizeof(CubicMappingInfo) * static_cast<size_t>(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 2, (int64_t)0));
+      return sizeof(CubicMappingInfo) *
+             static_cast<size_t>(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 2, (int64_t)0));
   }
   return 0;
 }
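As a sanity check of the sizing arithmetic above, a small host-only sketch that reproduces just the element counts (struct byte sizes are omitted since LinearMappingInfo and NearestMappingInfo are defined elsewhere; the output dims are an illustrative assumption):

// Illustrative only: mirrors the per-mode element counts used by CalcResizeBufferSize.
#include <cstdint>
#include <numeric>
#include <vector>

int main() {
  const std::vector<int64_t> output_dims{1, 3, 32, 64};

  // LINEAR: one mapping entry per output row plus one per output column of the
  // innermost 2-D plane, i.e. (64 + 32) entries here.
  const int64_t linear_entries =
      std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 2, int64_t{0});
  // => 96 entries; the byte size would be 96 * sizeof(LinearMappingInfo).

  // NN: one int64_t prefix sum per dimension plus one mapping entry per element
  // of the *sum* of the output dims (1 + 3 + 32 + 64 = 100 here).
  const int64_t nn_entries =
      std::accumulate(output_dims.begin(), output_dims.end(), int64_t{0});
  // => 100 entries plus 4 * sizeof(int64_t) for the per-dimension prefix sums.

  return (linear_entries == 96 && nn_entries == 100) ? 0 : 1;
}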
@@ -603,7 +609,7 @@ void ResizeNearestImpl(
     const size_t N,
     bool extrapolation_enabled,
     const T extrapolation_value,
-    float cubic_coeff_a,
+    float /*cubic_coeff_a*/,
     ResizeCoordinateTransformationMode transform_coordinate,
     ResizeNearestMode calc_nearest_pixel,
     int64_t* /* prefix_dim_sum */,
@@ -616,7 +622,8 @@ void ResizeNearestImpl(
   if (could2d) {
     int64_t output_height = output_shape[rank - 2];
     int64_t output_width = output_shape[rank - 1];
-    fast_divmod div_output_image = (rank > 2) ? output_div_pitches[rank - 3] : fast_divmod(static_cast<int>(output_height * output_width));
+    fast_divmod div_output_image = (rank > 2) ? output_div_pitches[rank - 3]
+                                              : fast_divmod(static_cast<int>(output_height * output_width));
     int blocksPerDimsMappingGrid = static_cast<int>(ceil((output_height + output_width) / 32.0));
 
     DISPATCH_RESIZE_COORDINATE_TRANSFORMATION_MODE(transform_coordinate, [&]() {
@@ -694,13 +701,6 @@ void ResizeImpl(
     ResizeCoordinateTransformationMode coordinate_transform_mode,
     ResizeNearestMode nearest_mode,
     void* dims_mapping) {
-  bool isSame = std::all_of(scales_vals.Data(), scales_vals.Data() + rank, [](float v) { return v == 1.0f; }) &&
-                (coordinate_transform_mode != ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE);
-  if (isSame) {
-    CUDA_CALL_THROW(cudaMemcpyAsync(output_data, input_data, N * sizeof(T), cudaMemcpyDeviceToDevice, stream));
-    return;
-  }
-
   if (upsample_mode == UpsampleMode::NN) {
     ResizeNearestImpl(
         stream, rank, input_shape, output_shape, input_strides, output_div_pitches,
@@ -761,7 +761,7 @@ void ResizeImpl(
       } else if (is_3D) {
         DISPATCH_RESIZE_COORDINATE_TRANSFORMATION_MODE(coordinate_transform_mode, [&]() {
           _ResizeTrilinearCoordinateMapping<T><<<blocksPerDimsMappingGrid, 32, 0, stream>>>(
-              input_shape[rank - 3] , input_shape[rank - 2], input_shape[rank - 1],
+              input_shape[rank - 3], input_shape[rank - 2], input_shape[rank - 1],
               output_depth, output_height, output_width,
               scales_vals[rank - 3], scales_vals[rank - 2], scales_vals[rank - 1],
               roi_vals[rank - 3], roi_vals[rank - 3 + rank],
@@ -778,7 +778,7 @@ void ResizeImpl(
             reinterpret_cast<LinearMappingInfo*>(dims_mapping));
         return;
       }
-      ORT_THROW("Only bilinear/trilinear and bicubic modes are supported in Resize");
+      ORT_THROW("Resize support 2-D and 3-D dimensions in LINEAR mode.");
       break;
     case UpsampleMode::CUBIC:
       if (is_2D) {
@@ -801,7 +801,7 @@ void ResizeImpl(
             reinterpret_cast<CubicMappingInfo*>(dims_mapping));
         return;
       }
-      ORT_THROW("Only bilinear/trilinear and bicubic modes are supported in Resize");
+      ORT_THROW("Resize supports only 2-D in CUBIC mode.");
     case UpsampleMode::NN:
       ORT_THROW("Only bilinear/trilinear and bicubic modes are supported in Resize");
   }
@@ -809,7 +809,7 @@ void ResizeImpl(
 
 #define SPECIALIZED_IMPL(T)                                         \
   template void ResizeImpl<T>(                                      \
-      cudaStream_t stream,                                    \
+      cudaStream_t stream,                                          \
       const UpsampleMode upsample_mode,                             \
       const int rank,                                               \
       TArray<int64_t>& input_shape,                                 \
diff --git a/onnxruntime/core/providers/cuda/tensor/resize_impl.h b/onnxruntime/core/providers/cuda/tensor/resize_impl.h
index d459dbff18d3e..ad06eebb9efb1 100644
--- a/onnxruntime/core/providers/cuda/tensor/resize_impl.h
+++ b/onnxruntime/core/providers/cuda/tensor/resize_impl.h
@@ -2,15 +2,69 @@
 // Licensed under the MIT License.
 
 #pragma once
+
 #include <stdint.h>
+
+#include <tuple>
+
 #include "core/providers/cuda/shared_inc/cuda_utils.h"
 #include "core/common/common.h"
 #include "core/providers/cpu/tensor/upsamplebase.h"
 #include "core/providers/cuda/cuda_common.h"
 
 namespace onnxruntime {
+template <>
+struct AccumulateType<half> {
+  using type = float;
+};
 namespace cuda {
 
+struct TransformCoordinate_ASYMMETRIC {
+  __device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale,
+                                                       float, float, float, float) const {
+    return x_resized / x_scale;
+  }
+};
+
+struct TransformCoordinate_HALF_PIXEL {
+  __device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale,
+                                                       float, float, float, float) const {
+    return ((x_resized + 0.5f) / x_scale) - 0.5f;
+  }
+};
+
+struct TransformCoordinate_PYTORCH_HALF_PIXEL {
+  __device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale, float length_resized, float,
+                                                       float, float) const {
+    return length_resized > 1 ? (x_resized + 0.5f) / x_scale - 0.5f : 0.0f;
+  }
+};
+
+struct TransformCoordinate_TF_HALF_PIXEL_FOR_NN {
+  __device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale,
+                                                       float, float, float, float) const {
+    return (x_resized + 0.5f) / x_scale;
+  }
+};
+
+struct TransformCoordinate_ALIGN_CORNERS {
+  __device__ __host__ __forceinline__ float operator()(float x_resized, float, float length_resized,
+                                                       float length_original, float, float) const {
+    return length_resized == 1 ? 0 : x_resized * (length_original - 1) / (length_resized - 1);
+  }
+};
+
+struct TransformCoordinate_TF_CROP_AND_RESIZE {
+  __device__ __host__ __forceinline__ float operator()(float x_resized, float, float length_resized,
+                                                       float length_original, float roi_start, float roi_end) const {
+    auto orig = length_resized > 1
+                    ? roi_start * (length_original - 1) +
+                          (x_resized * (roi_end - roi_start) * (length_original - 1)) / (length_resized - 1)
+                    : 0.5 * (roi_start + roi_end) * (length_original - 1);
+    return static_cast<float>(orig);
+  }
+};
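To make the transform semantics concrete, a standalone host-side sketch that evaluates simplified copies of two of the functors above for a 4-element axis resized to 8 elements (hedged illustration; the signatures are reduced to the arguments each mode actually uses):

// Simplified host-side check of half_pixel vs. align_corners mapping (illustrative).
#include <cassert>

struct HalfPixel {
  float operator()(float x_resized, float x_scale) const {
    return (x_resized + 0.5f) / x_scale - 0.5f;
  }
};

struct AlignCorners {
  float operator()(float x_resized, float length_resized, float length_original) const {
    return length_resized == 1 ? 0.f
                               : x_resized * (length_original - 1) / (length_resized - 1);
  }
};

int main() {
  // half_pixel: output index 0 maps to input coordinate -0.25, index 7 to 3.25.
  assert(HalfPixel{}(0.f, 2.f) == -0.25f);
  assert(HalfPixel{}(7.f, 2.f) == 3.25f);

  // align_corners: output index 0 maps to 0 and index 7 maps exactly to 3,
  // i.e. the first and last input coordinates are pinned to the corners.
  assert(AlignCorners{}(0.f, 8.f, 4.f) == 0.f);
  assert(AlignCorners{}(7.f, 8.f, 4.f) == 3.f);
  return 0;
}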
+
 size_t CalcResizeBufferSize(const onnxruntime::UpsampleMode upsample_mode,
                             const gsl::span<const int64_t>& output_dims);
 
@@ -36,5 +90,62 @@ void ResizeImpl(
     onnxruntime::ResizeNearestMode nearest_mode,
     void* dims_mapping);
 
+using TempSpaceAllocateFunc = std::function<onnxruntime::IAllocatorUniquePtr<uint8_t>(size_t buffer_size)>;
+
+template <class T>
+void ResizeAntiAliasImpl(
+    cudaStream_t stream,
+    int rank,
+    const UpsampleMode upsample_mode,
+    ResizeCoordinateTransformationMode coordinate_transform_mode,
+    gsl::span<const int64_t> input_shape,
+    gsl::span<const int64_t> output_shape,
+    int64_t batch_size, int64_t num_channels,
+    std::tuple<int64_t, int64_t, int64_t> inferred_input_dims,
+    std::tuple<int64_t, int64_t, int64_t> inferred_output_dims,
+    std::tuple<float, float, float> inferred_dim_rscales,
+    const TArray<fast_divmod>& output_div_pitches,
+    gsl::span<const float> roi_vals,  // CPU
+    const std::optional<float>& extrapolation_value,
+    bool exclude_outside,
+    TempSpaceAllocateFunc allocate_temp_space,
+    const uint8_t* clip8_lookups,
+    const T* input_data,
+    T* output_data,
+    const size_t N);
+
+/// <summary>
+/// Compute the scaled support value for a given dimension's inverse scale.
+/// </summary>
+/// <param name="support_value">Support value from the filter parameters</param>
+/// <param name="rscale">Inverse scale value that comes from the input/attributes for this dimension</param>
+/// <returns>Scaled support value for the dimension</returns>
+inline float ComputeScaledSupportValue(float support_value, float rscale) {
+  const float scale = 1.0f / rscale;
+  float scaled_support = (scale >= 1.0f) ? (support_value * 0.5f) * scale : support_value * 0.5f;
+  return scaled_support;
+}
+
+/// <summary>
+/// Compute the window size for a given dimension's scaled support value.
+/// </summary>
+/// <param name="scaled_support">Scaled support value for the dimension</param>
+/// <returns>Window size (number of contributing input elements)</returns>
+inline int32_t ComputeWindowSize(float scaled_support) {
+  SafeInt<int32_t> window_size(ceilf(scaled_support));
+  return window_size * 2 + 1;
+}
+
+/// <summary>
+/// Computes the weighted coefficient buffer size in number of elements for allocation purposes.
+/// </summary>
+/// <param name="output_size">Number of output elements along the dimension</param>
+/// <param name="window_size">Window size for the dimension</param>
+/// <returns>Number of elements to fit in the buffer</returns>
+inline SafeInt<int64_t> ComputeWeightedCoeffBufferSize(int64_t output_size, int32_t window_size) {
+  SafeInt<int64_t> buffer_size(output_size);
+  return buffer_size * window_size;
+}
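A hedged usage sketch chaining the three helpers above the way an antialiased downscale would presumably size its coefficient buffer; the support_value of 2.0 and the axis sizes are illustrative assumptions, and the helper bodies are inlined so the snippet stands alone:

// Downscaling a 100-element axis to 25 elements: rscale = 25/100 = 0.25.
#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  const float support_value = 2.0f;    // illustrative filter support
  const float rscale = 0.25f;          // output/input ratio for this axis
  const int64_t output_size = 25;

  // ComputeScaledSupportValue: scale = 1/rscale = 4, so support = (2 * 0.5) * 4 = 4.
  const float scale = 1.0f / rscale;
  const float scaled_support = (scale >= 1.0f) ? (support_value * 0.5f) * scale
                                               : support_value * 0.5f;

  // ComputeWindowSize: ceil(4) * 2 + 1 = 9 contributing input elements per output.
  const int32_t window_size = static_cast<int32_t>(std::ceil(scaled_support)) * 2 + 1;

  // ComputeWeightedCoeffBufferSize: 25 outputs * 9 taps = 225 weights to allocate.
  const int64_t coeff_buffer_elems = output_size * window_size;

  std::cout << scaled_support << " " << window_size << " " << coeff_buffer_elems << "\n";
  return 0;  // prints: 4 9 225
}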
+
 }  // namespace cuda
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/tensor/space_depth_ops.cc b/onnxruntime/core/providers/cuda/tensor/space_depth_ops.cc
index 407a2ef3981f1..aaaf3600b676e 100644
--- a/onnxruntime/core/providers/cuda/tensor/space_depth_ops.cc
+++ b/onnxruntime/core/providers/cuda/tensor/space_depth_ops.cc
@@ -20,7 +20,22 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
                         {DataTypeImpl::GetTensorType<float>(),
                          DataTypeImpl::GetTensorType<double>(),
                          DataTypeImpl::GetTensorType<MLFloat16>()}),
-    SpaceToDepth);
+    SpaceToDepth<LAYOUT_NCHW>);
+
+#ifdef ENABLE_CUDA_NHWC_OPS
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(
+    SpaceToDepth,
+    kMSInternalNHWCDomain,
+    1,
+    12,
+    kCudaExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T",
+                        {DataTypeImpl::GetTensorType<float>(),
+                         DataTypeImpl::GetTensorType<double>(),
+                         DataTypeImpl::GetTensorType<MLFloat16>()}),
+    SpaceToDepth<LAYOUT_NHWC>);
+#endif
 
 ONNX_OPERATOR_KERNEL_EX(
     SpaceToDepth,
@@ -32,7 +47,21 @@ ONNX_OPERATOR_KERNEL_EX(
                         {DataTypeImpl::GetTensorType<float>(),
                          DataTypeImpl::GetTensorType<double>(),
                          DataTypeImpl::GetTensorType<MLFloat16>()}),
-    SpaceToDepth);
+    SpaceToDepth<LAYOUT_NCHW>);
+
+#ifdef ENABLE_CUDA_NHWC_OPS
+ONNX_OPERATOR_KERNEL_EX(
+    SpaceToDepth,
+    kMSInternalNHWCDomain,
+    13,
+    kCudaExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T",
+                        {DataTypeImpl::GetTensorType<float>(),
+                         DataTypeImpl::GetTensorType<double>(),
+                         DataTypeImpl::GetTensorType<MLFloat16>()}),
+    SpaceToDepth<LAYOUT_NHWC>);
+#endif
 
 ONNX_OPERATOR_VERSIONED_KERNEL_EX(
     DepthToSpace,
@@ -45,7 +74,22 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
                         {DataTypeImpl::GetTensorType<float>(),
                          DataTypeImpl::GetTensorType<double>(),
                          DataTypeImpl::GetTensorType<MLFloat16>()}),
-    DepthToSpace);
+    DepthToSpace<LAYOUT_NCHW>);
+
+#ifdef ENABLE_CUDA_NHWC_OPS
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(
+    DepthToSpace,
+    kMSInternalNHWCDomain,
+    1,
+    10,
+    kCudaExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T",
+                        {DataTypeImpl::GetTensorType<float>(),
+                         DataTypeImpl::GetTensorType<double>(),
+                         DataTypeImpl::GetTensorType<MLFloat16>()}),
+    DepthToSpace<LAYOUT_NHWC>);
+#endif
 
 ONNX_OPERATOR_VERSIONED_KERNEL_EX(
     DepthToSpace,
@@ -58,7 +102,22 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
                         {DataTypeImpl::GetTensorType<float>(),
                          DataTypeImpl::GetTensorType<double>(),
                          DataTypeImpl::GetTensorType<MLFloat16>()}),
-    DepthToSpace);
+    DepthToSpace<LAYOUT_NCHW>);
+
+#ifdef ENABLE_CUDA_NHWC_OPS
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(
+    DepthToSpace,
+    kMSInternalNHWCDomain,
+    11,
+    12,
+    kCudaExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T",
+                        {DataTypeImpl::GetTensorType<float>(),
+                         DataTypeImpl::GetTensorType<double>(),
+                         DataTypeImpl::GetTensorType<MLFloat16>()}),
+    DepthToSpace<LAYOUT_NHWC>);
+#endif
 
 ONNX_OPERATOR_KERNEL_EX(
     DepthToSpace,
@@ -70,23 +129,35 @@ ONNX_OPERATOR_KERNEL_EX(
                         {DataTypeImpl::GetTensorType<float>(),
                          DataTypeImpl::GetTensorType<double>(),
                          DataTypeImpl::GetTensorType<MLFloat16>()}),
-    DepthToSpace);
+    DepthToSpace<LAYOUT_NCHW>);
+
+#ifdef ENABLE_CUDA_NHWC_OPS
+ONNX_OPERATOR_KERNEL_EX(
+    DepthToSpace,
+    kMSInternalNHWCDomain,
+    13,
+    kCudaExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T",
+                        {DataTypeImpl::GetTensorType<float>(),
+                         DataTypeImpl::GetTensorType<double>(),
+                         DataTypeImpl::GetTensorType<MLFloat16>()}),
+    DepthToSpace<LAYOUT_NHWC>);
+#endif
 
 static Status SpaceDepthOpCudaImpl(const cudaDeviceProp& prop,
                                    cudaStream_t stream,
                                    const cublasHandle_t cublas_handle,
                                    const Tensor& input, Tensor& output,
                                    const std::vector<size_t>& permutation,
-                                   const int64_t batch_size,
-                                   const int64_t in_dim1, const int64_t in_dim2, const int64_t in_dim3,
-                                   const int64_t in_dim4, const int64_t in_dim5,
+                                   const TensorShape& virtual_input_shape,
                                    const TensorShape& virtual_output_shape) {
-  TensorShape virtual_input_shape{batch_size, in_dim1, in_dim2, in_dim3, in_dim4, in_dim5};
   return Transpose::DoTranspose(prop, stream, cublas_handle, permutation, input, output,
                                 &virtual_input_shape, &virtual_output_shape);
 }
 
-Status SpaceToDepth::ComputeInternal(OpKernelContext* context) const {
+template <bool Layout>
+Status SpaceToDepth<Layout>::ComputeInternal(OpKernelContext* context) const {
   const auto* tensor_pointer = context->Input<Tensor>(0);
   if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
   const Tensor& input = *tensor_pointer;
@@ -101,29 +172,44 @@ Status SpaceToDepth::ComputeInternal(OpKernelContext* context) const {
   int64_t output_height = -1;
   int64_t output_width = -1;
 
-  ORT_RETURN_IF_ERROR(InputValidationsAndOutputDimsCalc(input,
-                                                        batch,
-                                                        input_depth, input_height, input_width,
-                                                        output_depth, output_height, output_width,
-                                                        true));
+  ORT_RETURN_IF_ERROR(
+      InputValidationsAndOutputDimsCalc<Layout == LAYOUT_NHWC>(input,
+                                                               batch,
+                                                               input_depth, input_height, input_width,
+                                                               output_depth, output_height, output_width,
+                                                               true));
 
   // We use the "actual" output shape to construct the output tensor
-  Tensor& output = *context->Output(0, {batch, output_depth, output_height, output_width});
+  Tensor& output = (Layout == LAYOUT_NCHW)
+                       ? *context->Output(0, {batch, output_depth, output_height, output_width})
+                       : *context->Output(0, {batch, output_height, output_width, output_depth});
+
+  TensorShape virtual_input_shape = (Layout == LAYOUT_NCHW)
+                                        ? TensorShape{batch, input_depth, input_height / blocksize_,
+                                                      blocksize_, input_width / blocksize_, blocksize_}
+                                        : TensorShape{batch, input_height / blocksize_, blocksize_,
+                                                      input_width / blocksize_, blocksize_, input_depth};
 
   // We will pass in the "virtual" output shape to be used by DoTranspose() in SpaceDepthOpCudaImpl(...)
-  TensorShape virtual_output_shape{batch, blocksize_, blocksize_, input_depth,
-                                   input_height / blocksize_, input_width / blocksize_};
+  TensorShape virtual_output_shape = (Layout == LAYOUT_NCHW)
+                                         ? TensorShape{batch, blocksize_, blocksize_, input_depth,
+                                                       input_height / blocksize_, input_width / blocksize_}
+                                         : TensorShape{batch, input_height / blocksize_, input_width / blocksize_,
+                                                       blocksize_, blocksize_, input_depth};
 
-  std::vector<size_t> permutation = {0, 3, 5, 1, 2, 4};
+  std::vector<size_t> permutation = (Layout == LAYOUT_NCHW)
+                                        ? std::vector<size_t>{0, 3, 5, 1, 2, 4}
+                                        : std::vector<size_t>{0, 1, 3, 2, 4, 5};
 
-  ORT_RETURN_IF_ERROR(SpaceDepthOpCudaImpl(GetDeviceProp(), Stream(context), GetCublasHandle(context), input, output, permutation, batch,
-                                           input_depth, input_height / blocksize_, blocksize_, input_width / blocksize_, blocksize_,
-                                           virtual_output_shape));
+  ORT_RETURN_IF_ERROR(
+      SpaceDepthOpCudaImpl(GetDeviceProp(), Stream(context), GetCublasHandle(context), input, output, permutation,
+                           virtual_input_shape, virtual_output_shape));
 
   return Status::OK();
 }
 
-Status DepthToSpace::ComputeInternal(OpKernelContext* context) const {
+template <bool Layout>
+Status DepthToSpace<Layout>::ComputeInternal(OpKernelContext* context) const {
   const auto* tensor_pointer = context->Input<Tensor>(0);
   if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
   const Tensor& input = *tensor_pointer;
@@ -138,46 +224,56 @@ Status DepthToSpace::ComputeInternal(OpKernelContext* context) const {
   int64_t output_height = -1;
   int64_t output_width = -1;
 
-  ORT_RETURN_IF_ERROR(InputValidationsAndOutputDimsCalc(input,
-                                                        batch,
-                                                        input_depth, input_height, input_width,
-                                                        output_depth, output_height, output_width,
-                                                        false));
+  ORT_RETURN_IF_ERROR(
+      InputValidationsAndOutputDimsCalc<Layout == LAYOUT_NHWC>(input,
+                                                               batch,
+                                                               input_depth, input_height, input_width,
+                                                               output_depth, output_height, output_width,
+                                                               false));
 
   // We use the "actual" output shape to construct the output tensor
-  Tensor& output = *context->Output(0, {batch, output_depth, output_height, output_width});
+  Tensor& output = (Layout == LAYOUT_NCHW)
+                       ? *context->Output(0, {batch, output_depth, output_height, output_width})
+                       : *context->Output(0, {batch, output_height, output_width, output_depth});
+
+  int64_t virtual_input_depth = input_depth / blocksize_ / blocksize_;
+  TensorShape virtual_input_shape;
+
+  // DCR and CRD modes split the channel dimension differently, so they need different virtual input shapes.
+  if (is_dcr_) {
+    virtual_input_shape = (Layout == LAYOUT_NCHW)
+                              ? TensorShape{batch, blocksize_, blocksize_,
+                                            virtual_input_depth, input_height, input_width}
+                              : TensorShape{batch, input_height, input_width,
+                                            blocksize_, blocksize_, virtual_input_depth};
+  } else {
+    virtual_input_shape = (Layout == LAYOUT_NCHW)
+                              ? TensorShape{batch, virtual_input_depth, blocksize_,
+                                            blocksize_, input_height, input_width}
+                              : TensorShape{batch, input_height, input_width,
+                                            virtual_input_depth, blocksize_, blocksize_};
+  }
 
   // We will pass in the "virtual" output shape to be used by DoTranspose() in SpaceDepthOpCudaImpl(...)
-  TensorShape virtual_output_shape{batch, input_depth / blocksize_ / blocksize_,
-                                   input_height, blocksize_, input_width, blocksize_};
+  TensorShape virtual_output_shape = (Layout == LAYOUT_NCHW)
+                                         ? TensorShape{batch, virtual_input_depth, input_height,
+                                                       blocksize_, input_width, blocksize_}
+                                         : TensorShape{batch, input_height, blocksize_,
+                                                       input_width, blocksize_, virtual_input_depth};
 
   std::vector<size_t> permutation;
-  permutation.reserve(6);
-  permutation.push_back(0);
 
   if (is_dcr_) {
-    permutation.push_back(3);
-    permutation.push_back(4);
-    permutation.push_back(1);
-    permutation.push_back(5);
-    permutation.push_back(2);
+    permutation = (Layout == LAYOUT_NCHW)
+                      ? std::vector<size_t>({0, 3, 4, 1, 5, 2})
+                      : std::vector<size_t>({0, 1, 3, 2, 4, 5});
 
   } else {
-    permutation.push_back(1);
-    permutation.push_back(4);
-    permutation.push_back(2);
-    permutation.push_back(5);
-    permutation.push_back(3);
+    permutation = std::vector<size_t>({0, 1, 4, 2, 5, 3});
   }
 
-  int64_t dim1 = is_dcr_ ? blocksize_ : input_depth / blocksize_ / blocksize_;
-  int64_t dim3 = is_dcr_ ? input_depth / blocksize_ / blocksize_ : blocksize_;
-
   ORT_RETURN_IF_ERROR(SpaceDepthOpCudaImpl(GetDeviceProp(), Stream(context), GetCublasHandle(context), input, output,
-                                           permutation,
-                                           batch,
-                                           dim1, blocksize_, dim3, input_height, input_width,
-                                           virtual_output_shape));
+                                           permutation, virtual_input_shape, virtual_output_shape));
 
   return Status::OK();
 }
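For the NCHW DCR path above, a small host-only sketch that checks the virtual 6-D view plus the {0, 3, 4, 1, 5, 2} permutation against the ONNX DepthToSpace (DCR) definition; index math only, no CUDA, and the tensor sizes are illustrative:

// Verify: view [N, C, H, W] as [N, b, b, C/(b*b), H, W], transpose with
// {0, 3, 4, 1, 5, 2}, and compare element-wise with the direct DCR formula.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int64_t N = 1, C = 8, H = 2, W = 3, b = 2;
  const int64_t Cv = C / (b * b);  // virtual_input_depth

  std::vector<int> input(N * C * H * W);
  for (size_t i = 0; i < input.size(); ++i) input[i] = static_cast<int>(i);

  // Transpose of the virtual 6-D view: out[n][cv][h][bh][w][bw] = in[n][bh][bw][cv][h][w].
  std::vector<int> transposed(input.size());
  for (int64_t n = 0; n < N; ++n)
    for (int64_t bh = 0; bh < b; ++bh)
      for (int64_t bw = 0; bw < b; ++bw)
        for (int64_t cv = 0; cv < Cv; ++cv)
          for (int64_t h = 0; h < H; ++h)
            for (int64_t w = 0; w < W; ++w) {
              const int64_t src = ((((n * b + bh) * b + bw) * Cv + cv) * H + h) * W + w;
              const int64_t dst = ((((n * Cv + cv) * H + h) * b + bh) * W + w) * b + bw;
              transposed[dst] = input[src];
            }

  // Direct DCR formula: output[n][cv][h*b+bh][w*b+bw] = input[n][(bh*b+bw)*Cv+cv][h][w].
  for (int64_t n = 0; n < N; ++n)
    for (int64_t cv = 0; cv < Cv; ++cv)
      for (int64_t oh = 0; oh < H * b; ++oh)
        for (int64_t ow = 0; ow < W * b; ++ow) {
          const int64_t bh = oh % b, h = oh / b, bw = ow % b, w = ow / b;
          const int64_t c = (bh * b + bw) * Cv + cv;
          const int expected = input[((n * C + c) * H + h) * W + w];
          const int got = transposed[((n * Cv + cv) * (H * b) + oh) * (W * b) + ow];
          assert(expected == got);
        }
  return 0;
}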
diff --git a/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h b/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h
index 57b85556f1dbe..8780d9b365005 100644
--- a/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h
+++ b/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h
@@ -9,6 +9,7 @@
 namespace onnxruntime {
 namespace cuda {
 
+template <bool Layout>
 class SpaceToDepth final : public CudaKernel, SpaceDepthBase {
  public:
   explicit SpaceToDepth(const OpKernelInfo& info) : CudaKernel(info), SpaceDepthBase(info) {
@@ -17,6 +18,7 @@ class SpaceToDepth final : public CudaKernel, SpaceDepthBase {
   Status ComputeInternal(OpKernelContext* context) const override;
 };
 
+template <bool Layout>
 class DepthToSpace final : public CudaKernel, SpaceDepthBase {
  public:
   explicit DepthToSpace(const OpKernelInfo& info) : CudaKernel(info), SpaceDepthBase(info) {
diff --git a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu
index 9f9c365d2a53d..6344845359b32 100644
--- a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu
+++ b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu
@@ -80,7 +80,7 @@ bool CanDoTranspose3D(const cudaDeviceProp& prop, size_t rank, const gsl::span<c
   } break
 
 Status Transpose3DImpl(cudaStream_t stream, size_t element_size, const TArray<int64_t>& input_shape,
-                       const TArray<int64_t>& input_strides, const void* input_data, void* output_data, int64_t N,
+                       const TArray<int64_t>& input_strides, const void* input_data, void* output_data, int64_t /*N*/,
                        const dim3& grid_size, const dim3& block_size) {
   switch (element_size) {
     HANDLE_TRANSPOSE_3D_TILE_DIM(int8_t);
@@ -248,10 +248,10 @@ __global__ void Transpose4DKernelParallelizeOneElementPerThread(
 }
 
 bool CanDoTranspose4DParallelizeOneElementPerThread(const cudaDeviceProp& prop,
-                                                    size_t element_size,
+                                                    size_t /*element_size*/,
                                                     int32_t rank,
                                                     const gsl::span<const int64_t>& input_dims,
-                                                    const gsl::span<const size_t>& permutations,
+                                                    const gsl::span<const size_t>& /*permutations*/,
                                                     dim3& grid_size, dim3& block_size) {
   if (rank == 4) {
     // dims[3]: block.x
diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.cc b/onnxruntime/core/providers/cuda/tensor/upsample.cc
index ae12ca328bc7c..17533eb3d9a72 100644
--- a/onnxruntime/core/providers/cuda/tensor/upsample.cc
+++ b/onnxruntime/core/providers/cuda/tensor/upsample.cc
@@ -2,6 +2,9 @@
 // Licensed under the MIT License.
 
 #include "upsample.h"
+
+#include <utility>
+
 #include "upsample_impl.h"
 #include "core/providers/cuda/tensor/resize_impl.h"
 #include "core/providers/cpu/tensor/utils.h"
@@ -37,11 +40,23 @@ REGISTER_VERSIONED_TYPED_KERNEL(MLFloat16, 9, 9);
 REGISTER_VERSIONED_TYPED_KERNEL(int32_t, 9, 9);
 REGISTER_VERSIONED_TYPED_KERNEL(uint8_t, 9, 9);
 
+template <typename T>
+Upsample<T>::Upsample(const OpKernelInfo& info) : UpsampleBase(info), CudaKernel(info) {
+  if (UpsampleBase::antialias_) {
+    // Copy the lookup table to device memory
+    const uint8_t* lookup_table = GetLookupTableShared();
+    auto alloc = info.GetAllocator(OrtMemTypeDefault);
+    shared_lookup_table_ondevice_ = IAllocator::MakeUniquePtr<uint8_t>(std::move(alloc), kLookupTableSize);
+    CUDA_CALL_THROW(cudaMemcpyAsync(shared_lookup_table_ondevice_.get(), lookup_table, kLookupTableSize,
+                                    cudaMemcpyHostToDevice, nullptr));
+  }
+}
+
 template <typename T>
 Status Upsample<T>::BaseCompute(OpKernelContext* context,
-                                const std::vector<float>& roi,
-                                const std::vector<float>& scales,
-                                const gsl::span<const int64_t>& output_dims) const {
+                                gsl::span<const float> roi,
+                                gsl::span<const float> scales,
+                                gsl::span<const int64_t> output_dims) const {
   const Tensor* X = context->Input<Tensor>(0);
   auto X_dims = X->Shape().GetDims();
   int32_t rank = static_cast<int32_t>(X_dims.size());
@@ -52,7 +67,8 @@ Status Upsample<T>::BaseCompute(OpKernelContext* context,
                   is_resize_ ? "Resize: input tensor cannot be scalar." : "Upsample: input tensor cannot be scalar.");
   if (rank != static_cast<int32_t>(scales.size()))
     return Status(ONNXRUNTIME, INVALID_ARGUMENT,
-                  is_resize_ ? "Resize: input tensor's dimension does not match the scales." : "Upsample: input tensor's dimension does not match the scales.");
+                  is_resize_ ? "Resize: input tensor's dimension does not match the scales."
+                             : "Upsample: input tensor's dimension does not match the scales.");
   if (roi.size() != 2 * X_dims.size())
     return Status(ONNXRUNTIME, INVALID_ARGUMENT,
                   "Resize: size of roi array should be 2 * N where N is the rank of input tensor X.");
@@ -79,22 +95,194 @@ Status Upsample<T>::BaseCompute(OpKernelContext* context,
   size_t output_count = Y->Shape().Size();
 
   if (is_resize_) {
-    TArray<int64_t> input_shape(X_dims);
-    TArray<int64_t> output_shape(output_dims);
-    TArray<float, 10> roi_vals(roi);
-    TArray<float> scales_vals(scales);
-
-    size_t temp_buffer_size = CalcResizeBufferSize(mode_, output_dims);
-    auto dims_mapping_buffer = GetScratchBuffer<unsigned char>(temp_buffer_size, context->GetComputeStream());
-    void* dims_mapping = reinterpret_cast<void*>(dims_mapping_buffer.get());
-    ResizeImpl(Stream(context), mode_, (int)rank, input_shape, output_shape,
-               input_strides, output_div_pitches, scales_vals, roi_vals,
-               reinterpret_cast<const CudaT*>(X->Data<T>()),
-               reinterpret_cast<CudaT*>(Y->MutableData<T>()),
-               output_count, use_extrapolation_, ToCudaType<T>::FromFloat(extrapolation_value_),
-               cubic_coeff_a_, exclude_outside_,
-               coordinate_transform_mode_, nearest_mode_,
-               dims_mapping);
+    const bool is_same = std::all_of(scales.begin(), scales.end(), [](float v) { return v == 1.0f; }) &&
+                         (coordinate_transform_mode_ != ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE);
+    if (is_same) {
+      CUDA_CALL_THROW(cudaMemcpyAsync(Y->MutableData<T>(), X->Data<T>(),
+                                      output_count * sizeof(T), cudaMemcpyDeviceToDevice, Stream(context)));
+      return Status::OK();
+    }
+
+    if (antialias_) {
+      TempSpaceAllocateFunc allocate_temp_space = [&](size_t bytes_size) {
+        return GetScratchBuffer<uint8_t>(bytes_size, context->GetComputeStream());
+      };
+
+      std::optional<float> extrapolation_value;
+      if (use_extrapolation_)
+        extrapolation_value.emplace(extrapolation_value_);
+
+      switch (mode_) {
+        case UpsampleMode::LINEAR: {
+          if (X_dims.size() == 2 || X_dims.size() == 4) {
+            const bool is_2D = X_dims.size() == 2;
+
+            int64_t batch_size = 1;
+            int64_t num_channels = 1;
+
+            int64_t input_height;
+            int64_t input_width;
+
+            int64_t output_height;
+            int64_t output_width;
+
+            float height_scale;
+            float width_scale;
+
+            if (is_2D) {
+              input_height = X_dims[0];
+              input_width = X_dims[1];
+
+              output_height = output_dims[0];
+              output_width = output_dims[1];
+
+              height_scale = scales[0];
+              width_scale = scales[1];
+            } else {
+              if (scales[0] == 1.0f && scales[1] == 1.0f) {
+                batch_size = X_dims[Channels<LAYOUT_NCHW>::N];
+                num_channels = X_dims[Channels<LAYOUT_NCHW>::C];
+                input_height = X_dims[Channels<LAYOUT_NCHW>::H];
+                input_width = X_dims[Channels<LAYOUT_NCHW>::W];
+
+                output_height = output_dims[Channels<LAYOUT_NCHW>::H];
+                output_width = output_dims[Channels<LAYOUT_NCHW>::W];
+
+                height_scale = scales[2];
+                width_scale = scales[3];
+              } else {
+                return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Resize", ": NHWC is not supported yet");
+              }
+            }
+
+            ResizeAntiAliasImpl(Stream(context),
+                                rank,
+                                mode_,
+                                coordinate_transform_mode_,
+                                X_dims, output_dims,
+                                batch_size, num_channels,
+                                std::make_tuple(0, input_height, input_width),
+                                std::make_tuple(0, output_height, output_width),
+                                std::make_tuple(0.f, height_scale, width_scale),
+                                output_div_pitches,
+                                roi,
+                                extrapolation_value,
+                                exclude_outside_,
+                                allocate_temp_space,
+                                shared_lookup_table_ondevice_.get(),
+                                reinterpret_cast<const CudaT*>(X->Data<T>()),
+                                reinterpret_cast<CudaT*>(Y->MutableData<T>()),
+                                output_count);
+
+          } else if (X_dims.size() == 3 || X_dims.size() == 5) {
+            const bool is_3D = X_dims.size() == 3;
+
+            if (!is_3D) {
+              if (!(scales[0] == 1.0f && scales[1] == 1.0f)) {
+                return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Resize", ": NDHWC is not supported yet");
+              }
+            }
+
+            const int64_t batch_size = is_3D ? 1 : X_dims[0];
+            const int64_t num_channels = is_3D ? 1 : X_dims[1];
+            const int64_t input_depth = is_3D ? X_dims[0] : X_dims[2];
+            const int64_t input_height = is_3D ? X_dims[1] : X_dims[3];
+            const int64_t input_width = is_3D ? X_dims[2] : X_dims[4];
+
+            const int64_t output_depth = is_3D ? output_dims[0] : output_dims[2];
+            const int64_t output_height = is_3D ? output_dims[1] : output_dims[3];
+            const int64_t output_width = is_3D ? output_dims[2] : output_dims[4];
+
+            const float depth_scale = is_3D ? scales[0] : scales[2];
+            const float height_scale = is_3D ? scales[1] : scales[3];
+            const float width_scale = is_3D ? scales[2] : scales[4];
+
+            ResizeAntiAliasImpl(Stream(context),
+                                rank,
+                                mode_,
+                                coordinate_transform_mode_,
+                                X_dims, output_dims,
+                                batch_size, num_channels,
+                                std::make_tuple(input_depth, input_height, input_width),
+                                std::make_tuple(output_depth, output_height, output_width),
+                                std::make_tuple(depth_scale, height_scale, width_scale),
+                                output_div_pitches,
+                                roi,
+                                extrapolation_value,
+                                exclude_outside_,
+                                allocate_temp_space,
+                                shared_lookup_table_ondevice_.get(),
+                                reinterpret_cast<const CudaT*>(X->Data<T>()),
+                                reinterpret_cast<CudaT*>(Y->MutableData<T>()),
+                                output_count);
+          } else {
+            return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Resize",
+                                   ": 'Linear' mode only support 2-D inputs or 3-D inputs ('Bilinear', 'Trilinear') "
+                                   "or 4-D inputs or 5-D inputs with the corresponding outermost 2 scale values "
+                                   "being 1.");
+          }
+        } break;
+        case UpsampleMode::CUBIC: {
+          if (X_dims.size() != 2 && X_dims.size() != 4) {
+            return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Resize",
+                                   ": 'Cubic' mode only support 2-D inputs ('Bicubic') or 4-D inputs "
+                                   "with the corresponding outermost 2 scale values being 1.");
+          }
+
+          const bool is_2D = X_dims.size() == 2;
+          const bool is_nchw = is_2D ? true : (scales[0] == 1.0f && scales[1] == 1.0f);
+
+          ORT_RETURN_IF_NOT(is_nchw,
+                            "Resize 'Cubic' mode only supports the NCHW layout "
+                            "with 2-D inputs or 4-D inputs with the outermost 2 scale values being 1");
+
+          const int64_t batch_size = is_2D ? 1 : X_dims[Channels<LAYOUT_NCHW>::N];
+          const int64_t num_channels = is_2D ? 1 : X_dims[Channels<LAYOUT_NCHW>::C];
+          const int64_t input_height = is_2D ? X_dims[0] : X_dims[Channels<LAYOUT_NCHW>::H];
+          const int64_t input_width = is_2D ? X_dims[1] : X_dims[Channels<LAYOUT_NCHW>::W];
+
+          const int64_t output_height = is_2D ? output_dims[0] : output_dims[Channels<LAYOUT_NCHW>::H];
+          const int64_t output_width = is_2D ? output_dims[1] : output_dims[Channels<LAYOUT_NCHW>::W];
+          const float height_scale = is_2D ? scales[0] : scales[2];
+          const float width_scale = is_2D ? scales[1] : scales[3];
+
+          ResizeAntiAliasImpl(Stream(context), rank, mode_, coordinate_transform_mode_,
+                              X_dims, output_dims,
+                              batch_size, num_channels,
+                              std::make_tuple(0, input_height, input_width),
+                              std::make_tuple(0, output_height, output_width),
+                              std::make_tuple(0.f, height_scale, width_scale),
+                              output_div_pitches,
+                              roi,
+                              extrapolation_value,
+                              exclude_outside_,
+                              allocate_temp_space,
+                              shared_lookup_table_ondevice_.get(),
+                              reinterpret_cast<const CudaT*>(X->Data<T>()),
+                              reinterpret_cast<CudaT*>(Y->MutableData<T>()),
+                              output_count);
+        } break;
+        default:
+          return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Resize: unexpected mode");
+      }
+    } else {
+      TArray<int64_t> input_shape(X_dims);
+      TArray<int64_t> output_shape(output_dims);
+      TArray<float, 10> roi_vals(roi);
+      TArray<float> scales_vals(scales);
+
+      size_t temp_buffer_size = CalcResizeBufferSize(mode_, output_dims);
+      auto dims_mapping_buffer = GetScratchBuffer<unsigned char>(temp_buffer_size, context->GetComputeStream());
+      void* dims_mapping = reinterpret_cast<void*>(dims_mapping_buffer.get());
+      ResizeImpl(Stream(context), mode_, rank, input_shape, output_shape,
+                 input_strides, output_div_pitches, scales_vals, roi_vals,
+                 reinterpret_cast<const CudaT*>(X->Data<T>()),
+                 reinterpret_cast<CudaT*>(Y->MutableData<T>()),
+                 output_count, use_extrapolation_, ToCudaType<T>::FromFloat(extrapolation_value_),
+                 cubic_coeff_a_, exclude_outside_,
+                 coordinate_transform_mode_, nearest_mode_,
+                 dims_mapping);
+    }
   } else {
     TArray<fast_divmod> scales_div(rank);
 
@@ -124,7 +312,7 @@ Status Upsample<T>::ComputeInternal(OpKernelContext* context) const {
   auto input_dims = X->Shape().GetDims();
 
   TensorShapeVector output_dims(input_dims.size());
-  std::vector<float> roi_array(input_dims.size() * 2, 0.0f);
+  InlinedVector<float> roi_array(input_dims.size() * 2, 0.0f);
   if (!roi_cached_) {
     bool use_default_roi = true;
     if (need_roi_input_) {
@@ -147,29 +335,37 @@ Status Upsample<T>::ComputeInternal(OpKernelContext* context) const {
     }
   }
 
-  const std::vector<float>& roi = roi_cached_ ? roi_ : roi_array;
-  std::vector<float> scales_array = scales_;
+  ComputeROIWithAxes(roi_array, input_dims.size());
 
+  InlinedVector<float> scales_array(input_dims.size());
+  // opset < 10
   if (OpKernel::Node().InputDefs().size() == 1) {
-    // Compute output shape from scales and input dims
+    // Compute output shape from scales attributes and input dims
+    scales_array = scales_;
+
     ComputeOutputShape(scales_array, input_dims, output_dims);
-    return BaseCompute(context, roi, scales_, output_dims);
+    return BaseCompute(context, roi_array, scales_, output_dims);
   }
 
   const Tensor* scales = context->Input<Tensor>(scales_input_idx_);
   const Tensor* sizes = context->Input<Tensor>(sizes_input_idx_);
 
+  // This is when scales are obtained and cached from a constant initializer
   if (scales_cached_) {
-    ORT_ENFORCE(sizes == nullptr, "Only one of scales or sizes must be provided as input.");
+    ORT_RETURN_IF_NOT(sizes == nullptr, "Only one of scales or sizes must be provided as input.");
+    scales_array = scales_;
+    // Compute output shape from scales and input dims
     ComputeOutputShape(scales_array, input_dims, output_dims);
-    return BaseCompute(context, roi, scales_, output_dims);
+    return BaseCompute(context, roi_array, scales_array, output_dims);
   }
 
-  scales_array.resize((input_dims.size()));
+  // Scales and sizes are input to the node
   if (scales != nullptr && scales->Shape().Size() != 0) {
     // use scales input data
     ORT_ENFORCE(sizes == nullptr, "Only one of scales or sizes must be provided as input.");
     ORT_RETURN_IF_ERROR(ParseScalesData(scales, scales_array, input_dims.size()));
+
+    // Compute output shape from scales and input dims
     ComputeOutputShape(scales_array, input_dims, output_dims);
   } else {
     // When sizes input is available directly populate it into the output_dims array.
@@ -179,7 +375,7 @@ Status Upsample<T>::ComputeInternal(OpKernelContext* context) const {
     ORT_RETURN_IF_ERROR(ParseScalesDataAndAdjustOutputSize(output_dims, input_dims, scales_array));
   }
 
-  return BaseCompute(context, roi, scales_array, output_dims);
+  return BaseCompute(context, roi_array, scales_array, output_dims);
 }
 
 }  // namespace cuda
diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.h b/onnxruntime/core/providers/cuda/tensor/upsample.h
index 7bf2a23ede399..50597e0fba1b9 100644
--- a/onnxruntime/core/providers/cuda/tensor/upsample.h
+++ b/onnxruntime/core/providers/cuda/tensor/upsample.h
@@ -13,12 +13,14 @@ namespace cuda {
 template <typename T>
 class Upsample : public UpsampleBase, public CudaKernel {
  public:
-  Upsample(const OpKernelInfo& info) : UpsampleBase(info), CudaKernel(info) {
-  }
+  explicit Upsample(const OpKernelInfo& info);
 
   Status ComputeInternal(OpKernelContext* context) const override;
-  Status BaseCompute(OpKernelContext* context, const std::vector<float>& roi, const std::vector<float>& scales,
-                     const gsl::span<const int64_t>& output_dims) const;
+  Status BaseCompute(OpKernelContext* context, gsl::span<const float> roi, gsl::span<const float> scales,
+                     gsl::span<const int64_t> output_dims) const;
+
+ private:
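+  // Device-side buffer holding the lookup table shared by anti-aliased Resize kernels.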
+  IAllocatorUniquePtr<uint8_t> shared_lookup_table_ondevice_;
 };
 
 }  // namespace cuda
diff --git a/onnxruntime/core/providers/cuda/triton_kernel.cu b/onnxruntime/core/providers/cuda/triton_kernel.cu
index 6ffbf0420a15f..b42dbd0291b7a 100644
--- a/onnxruntime/core/providers/cuda/triton_kernel.cu
+++ b/onnxruntime/core/providers/cuda/triton_kernel.cu
@@ -130,27 +130,11 @@ void LoadOrtTritonKernel() {
   std::call_once(load_ort_triton_kernel_flag, TryToLoadKernel);
 }
 
-Status LaunchTritonKernel(cudaStream_t stream, std::string fname,
-                          int grid0, int grid1, int grid2, void* args, size_t args_size) {
-#ifdef USE_TRITON_KERNEL
-  if (ort_triton_kernel_map.count(fname) == 0) {
-    // Return unsupported status if function name not found in registry.
-    // This error status will be used by TunableOp
-    std::ostringstream message_stream;
-    message_stream << "Can't find ort triton kernel name: " << fname;
-    std::string message = message_stream.str();
-    TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(true, message);
-  }
-  auto idx = ort_triton_kernel_map[fname];
-  return LaunchTritonKernel(stream, idx, grid0, grid1, grid2, args, args_size);
-#else
-  return Status::OK();
-#endif
-}
 
-Status LaunchTritonKernel(cudaStream_t stream, size_t idx,
-                          int grid0, int grid1, int grid2, void* args, size_t args_size) {
+
 #ifdef USE_TRITON_KERNEL
+Status LaunchTritonKernel(cudaStream_t stream, size_t idx, int grid0, int grid1, int grid2,
+                          void* args, size_t args_size) {
   if (idx >= ort_triton_kernel_metadata.size()) {
     // Return unsupported status when idx exceeds the size of ort_triton_kernel_metadata.
     // This error status will be used by TunableOp
@@ -181,11 +165,37 @@ Status LaunchTritonKernel(cudaStream_t stream, size_t idx,
                                   nullptr,
                                   (void**)&config),
                    "Launching kernel failed.");
-#endif
 
   return Status::OK();
 }
 
+Status LaunchTritonKernel(cudaStream_t stream, std::string fname, int grid0, int grid1, int grid2,
+                          void* args, size_t args_size) {
+  if (ort_triton_kernel_map.count(fname) == 0) {
+    // Return unsupported status if function name not found in registry.
+    // This error status will be used by TunableOp
+    std::ostringstream message_stream;
+    message_stream << "Can't find ort triton kernel name: " << fname;
+    std::string message = message_stream.str();
+    TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(true, message);
+  }
+  auto idx = ort_triton_kernel_map[fname];
+  return LaunchTritonKernel(stream, idx, grid0, grid1, grid2, args, args_size);
+}
+
+#else
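+// Stubs for builds without USE_TRITON_KERNEL: callers link against the same signatures and get Status::OK().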
+Status LaunchTritonKernel(cudaStream_t /*stream*/, std::string /*fname*/, int /*grid0*/, int /*grid1*/, int /*grid2*/,
+                          void* /*args*/, size_t /*args_size*/) {
+  return Status::OK();
+}
+
+Status LaunchTritonKernel(cudaStream_t /*stream*/, size_t /*idx*/, int /*grid0*/, int /*grid1*/, int /*grid2*/,
+                          void* /*args*/, size_t /*args_size*/) {
+  return Status::OK();
+}
+#endif
+
+
 const TritonKernelMetaData* GetOrtTritonKernelMetadata(size_t idx) {
   if (idx >= ort_triton_kernel_metadata.size()) {
     return nullptr;
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h
index f29cc3afc3cda..88e3dd487d427 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h
@@ -80,15 +80,10 @@ namespace Windows::AI::MachineLearning::Adapter
     };
 
     // This is the counterpart to the MLOperatorGraphDesc ABI struct which owns its memory and uses containers.
-    // Either nodesAsOperatorDesc or nodesAsIDMLOperator can have non-zero size.
     struct DmlGraphNodeCreateInfo
     {
         uint32_t nodeCount = 0;
-        std::vector<std::unique_ptr<AbstractOperatorDesc>> nodesAsOperatorDesc;
-
-        // TODO (jeffbloo): Remove this
-        std::vector<Microsoft::WRL::ComPtr<IDMLOperator>> nodesAsIDMLOperator;
-
+        std::vector<std::unique_ptr<AbstractOperatorDesc>> nodes;
         std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
         std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
         std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ApiTraits.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ApiTraits.cpp
new file mode 100644
index 0000000000000..bf9800458102b
--- /dev/null
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ApiTraits.cpp
@@ -0,0 +1,570 @@
+//---------------------------------------------------------------------------
+// Copyright (c) Microsoft Corporation. All rights reserved.
+//
+// This file is automatically generated. Please do not edit it directly.
+// To modify this file, edit the schema: dml/Tools/DirectMLSchema.json
+// And run this script to regenerate: dml/Tools/GenerateSchema.ps1
+//
+// #dml-new-operator-location
+//---------------------------------------------------------------------------
+
+#pragma once
+
+#include "precomp.h"
+
+template <typename T>
+T ApiTraits::StringifyHelpers::FromString(std::string_view value)
+{
+#ifndef WAI_BUILD_LINUX
+    // Clang will instantiate this template even if it isn't used,
+    // so this static_assert will always fire and break the build.
+    static_assert(false, "Not implemented for this type");
+#endif
+}
+
+template <>
+DML_TENSOR_DATA_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value)
+{
+    constexpr StringUtil::NameAndIndex mapping[] =
+    {
+        {"DML_TENSOR_DATA_TYPE_UNKNOWN", DML_TENSOR_DATA_TYPE_UNKNOWN},
+        {"DML_TENSOR_DATA_TYPE_FLOAT32", DML_TENSOR_DATA_TYPE_FLOAT32},
+        {"DML_TENSOR_DATA_TYPE_FLOAT16", DML_TENSOR_DATA_TYPE_FLOAT16},
+        {"DML_TENSOR_DATA_TYPE_UINT32", DML_TENSOR_DATA_TYPE_UINT32},
+        {"DML_TENSOR_DATA_TYPE_UINT16", DML_TENSOR_DATA_TYPE_UINT16},
+        {"DML_TENSOR_DATA_TYPE_UINT8", DML_TENSOR_DATA_TYPE_UINT8},
+        {"DML_TENSOR_DATA_TYPE_INT32", DML_TENSOR_DATA_TYPE_INT32},
+        {"DML_TENSOR_DATA_TYPE_INT16", DML_TENSOR_DATA_TYPE_INT16},
+        {"DML_TENSOR_DATA_TYPE_INT8", DML_TENSOR_DATA_TYPE_INT8},
+        {"DML_TENSOR_DATA_TYPE_FLOAT64", DML_TENSOR_DATA_TYPE_FLOAT64},
+        {"DML_TENSOR_DATA_TYPE_UINT64", DML_TENSOR_DATA_TYPE_UINT64},
+        {"DML_TENSOR_DATA_TYPE_INT64", DML_TENSOR_DATA_TYPE_INT64},
+    };
+    auto index = StringUtil::MapToIndex(value, mapping);
+    if (!index)
+    {
+        assert(false);
+        return static_cast<DML_TENSOR_DATA_TYPE>(0);
+    }
+    return static_cast<DML_TENSOR_DATA_TYPE>(*index);
+}
+
+
+template <>
+DML_TENSOR_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value)
+{
+    constexpr StringUtil::NameAndIndex mapping[] =
+    {
+        {"DML_TENSOR_TYPE_INVALID", DML_TENSOR_TYPE_INVALID},
+        {"DML_TENSOR_TYPE_BUFFER", DML_TENSOR_TYPE_BUFFER},
+    };
+    auto index = StringUtil::MapToIndex(value, mapping);
+    if (!index)
+    {
+        assert(false);
+        return static_cast<DML_TENSOR_TYPE>(0);
+    }
+    return static_cast<DML_TENSOR_TYPE>(*index);
+}
+
+
+template <>
+DML_OPERATOR_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value)
+{
+    constexpr StringUtil::NameAndIndex mapping[] =
+    {
+        {"DML_OPERATOR_INVALID", DML_OPERATOR_INVALID},
+        {"DML_OPERATOR_ELEMENT_WISE_IDENTITY", DML_OPERATOR_ELEMENT_WISE_IDENTITY},
+        {"DML_OPERATOR_ELEMENT_WISE_ABS", DML_OPERATOR_ELEMENT_WISE_ABS},
+        {"DML_OPERATOR_ELEMENT_WISE_ACOS", DML_OPERATOR_ELEMENT_WISE_ACOS},
+        {"DML_OPERATOR_ELEMENT_WISE_ADD", DML_OPERATOR_ELEMENT_WISE_ADD},
+        {"DML_OPERATOR_ELEMENT_WISE_ASIN", DML_OPERATOR_ELEMENT_WISE_ASIN},
+        {"DML_OPERATOR_ELEMENT_WISE_ATAN", DML_OPERATOR_ELEMENT_WISE_ATAN},
+        {"DML_OPERATOR_ELEMENT_WISE_CEIL", DML_OPERATOR_ELEMENT_WISE_CEIL},
+        {"DML_OPERATOR_ELEMENT_WISE_CLIP", DML_OPERATOR_ELEMENT_WISE_CLIP},
+        {"DML_OPERATOR_ELEMENT_WISE_COS", DML_OPERATOR_ELEMENT_WISE_COS},
+        {"DML_OPERATOR_ELEMENT_WISE_DIVIDE", DML_OPERATOR_ELEMENT_WISE_DIVIDE},
+        {"DML_OPERATOR_ELEMENT_WISE_EXP", DML_OPERATOR_ELEMENT_WISE_EXP},
+        {"DML_OPERATOR_ELEMENT_WISE_FLOOR", DML_OPERATOR_ELEMENT_WISE_FLOOR},
+        {"DML_OPERATOR_ELEMENT_WISE_LOG", DML_OPERATOR_ELEMENT_WISE_LOG},
+        {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND", DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND},
+        {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS", DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS},
+        {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN", DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN},
+        {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN", DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN},
+        {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL", DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL},
+        {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL", DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL},
+        {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT", DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT},
+        {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR", DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR},
+        {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR", DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR},
+        {"DML_OPERATOR_ELEMENT_WISE_MAX", DML_OPERATOR_ELEMENT_WISE_MAX},
+        {"DML_OPERATOR_ELEMENT_WISE_MEAN", DML_OPERATOR_ELEMENT_WISE_MEAN},
+        {"DML_OPERATOR_ELEMENT_WISE_MIN", DML_OPERATOR_ELEMENT_WISE_MIN},
+        {"DML_OPERATOR_ELEMENT_WISE_MULTIPLY", DML_OPERATOR_ELEMENT_WISE_MULTIPLY},
+        {"DML_OPERATOR_ELEMENT_WISE_POW", DML_OPERATOR_ELEMENT_WISE_POW},
+        {"DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW", DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW},
+        {"DML_OPERATOR_ELEMENT_WISE_RECIP", DML_OPERATOR_ELEMENT_WISE_RECIP},
+        {"DML_OPERATOR_ELEMENT_WISE_SIN", DML_OPERATOR_ELEMENT_WISE_SIN},
+        {"DML_OPERATOR_ELEMENT_WISE_SQRT", DML_OPERATOR_ELEMENT_WISE_SQRT},
+        {"DML_OPERATOR_ELEMENT_WISE_SUBTRACT", DML_OPERATOR_ELEMENT_WISE_SUBTRACT},
+        {"DML_OPERATOR_ELEMENT_WISE_TAN", DML_OPERATOR_ELEMENT_WISE_TAN},
+        {"DML_OPERATOR_ELEMENT_WISE_THRESHOLD", DML_OPERATOR_ELEMENT_WISE_THRESHOLD},
+        {"DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR", DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR},
+        {"DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR", DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR},
+        {"DML_OPERATOR_ACTIVATION_ELU", DML_OPERATOR_ACTIVATION_ELU},
+        {"DML_OPERATOR_ACTIVATION_CELU", DML_OPERATOR_ACTIVATION_CELU},
+        {"DML_OPERATOR_ACTIVATION_HARDMAX", DML_OPERATOR_ACTIVATION_HARDMAX},
+        {"DML_OPERATOR_ACTIVATION_HARDMAX1", DML_OPERATOR_ACTIVATION_HARDMAX1},
+        {"DML_OPERATOR_ACTIVATION_HARD_SIGMOID", DML_OPERATOR_ACTIVATION_HARD_SIGMOID},
+        {"DML_OPERATOR_ACTIVATION_IDENTITY", DML_OPERATOR_ACTIVATION_IDENTITY},
+        {"DML_OPERATOR_ACTIVATION_LEAKY_RELU", DML_OPERATOR_ACTIVATION_LEAKY_RELU},
+        {"DML_OPERATOR_ACTIVATION_LINEAR", DML_OPERATOR_ACTIVATION_LINEAR},
+        {"DML_OPERATOR_ACTIVATION_LOG_SOFTMAX", DML_OPERATOR_ACTIVATION_LOG_SOFTMAX},
+        {"DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1", DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1},
+        {"DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU", DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU},
+        {"DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS", DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS},
+        {"DML_OPERATOR_ACTIVATION_RELU", DML_OPERATOR_ACTIVATION_RELU},
+        {"DML_OPERATOR_ACTIVATION_SCALED_ELU", DML_OPERATOR_ACTIVATION_SCALED_ELU},
+        {"DML_OPERATOR_ACTIVATION_SCALED_TANH", DML_OPERATOR_ACTIVATION_SCALED_TANH},
+        {"DML_OPERATOR_ACTIVATION_SIGMOID", DML_OPERATOR_ACTIVATION_SIGMOID},
+        {"DML_OPERATOR_ACTIVATION_SOFTMAX", DML_OPERATOR_ACTIVATION_SOFTMAX},
+        {"DML_OPERATOR_ACTIVATION_SOFTMAX1", DML_OPERATOR_ACTIVATION_SOFTMAX1},
+        {"DML_OPERATOR_ACTIVATION_SOFTPLUS", DML_OPERATOR_ACTIVATION_SOFTPLUS},
+        {"DML_OPERATOR_ACTIVATION_SOFTSIGN", DML_OPERATOR_ACTIVATION_SOFTSIGN},
+        {"DML_OPERATOR_ACTIVATION_TANH", DML_OPERATOR_ACTIVATION_TANH},
+        {"DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU", DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU},
+        {"DML_OPERATOR_CONVOLUTION", DML_OPERATOR_CONVOLUTION},
+        {"DML_OPERATOR_GEMM", DML_OPERATOR_GEMM},
+        {"DML_OPERATOR_REDUCE", DML_OPERATOR_REDUCE},
+        {"DML_OPERATOR_AVERAGE_POOLING", DML_OPERATOR_AVERAGE_POOLING},
+        {"DML_OPERATOR_AVERAGE_POOLING1", DML_OPERATOR_AVERAGE_POOLING1},
+        {"DML_OPERATOR_LP_POOLING", DML_OPERATOR_LP_POOLING},
+        {"DML_OPERATOR_LP_POOLING1", DML_OPERATOR_LP_POOLING1},
+        {"DML_OPERATOR_MAX_POOLING", DML_OPERATOR_MAX_POOLING},
+        {"DML_OPERATOR_ROI_POOLING", DML_OPERATOR_ROI_POOLING},
+        {"DML_OPERATOR_SLICE", DML_OPERATOR_SLICE},
+        {"DML_OPERATOR_CAST", DML_OPERATOR_CAST},
+        {"DML_OPERATOR_SPLIT", DML_OPERATOR_SPLIT},
+        {"DML_OPERATOR_JOIN", DML_OPERATOR_JOIN},
+        {"DML_OPERATOR_PADDING", DML_OPERATOR_PADDING},
+        {"DML_OPERATOR_PADDING1", DML_OPERATOR_PADDING1},
+        {"DML_OPERATOR_VALUE_SCALE_2D", DML_OPERATOR_VALUE_SCALE_2D},
+        {"DML_OPERATOR_UPSAMPLE_2D", DML_OPERATOR_UPSAMPLE_2D},
+        {"DML_OPERATOR_GATHER", DML_OPERATOR_GATHER},
+        {"DML_OPERATOR_SPACE_TO_DEPTH", DML_OPERATOR_SPACE_TO_DEPTH},
+        {"DML_OPERATOR_DEPTH_TO_SPACE", DML_OPERATOR_DEPTH_TO_SPACE},
+        {"DML_OPERATOR_TILE", DML_OPERATOR_TILE},
+        {"DML_OPERATOR_TOP_K", DML_OPERATOR_TOP_K},
+        {"DML_OPERATOR_BATCH_NORMALIZATION", DML_OPERATOR_BATCH_NORMALIZATION},
+        {"DML_OPERATOR_BATCH_NORMALIZATION_TRAINING", DML_OPERATOR_BATCH_NORMALIZATION_TRAINING},
+        {"DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION", DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION},
+        {"DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION", DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION},
+        {"DML_OPERATOR_LP_NORMALIZATION", DML_OPERATOR_LP_NORMALIZATION},
+        {"DML_OPERATOR_RNN", DML_OPERATOR_RNN},
+        {"DML_OPERATOR_LSTM", DML_OPERATOR_LSTM},
+        {"DML_OPERATOR_GRU", DML_OPERATOR_GRU},
+        {"DML_OPERATOR_ELEMENT_WISE_SIGN", DML_OPERATOR_ELEMENT_WISE_SIGN},
+        {"DML_OPERATOR_ELEMENT_WISE_IS_NAN", DML_OPERATOR_ELEMENT_WISE_IS_NAN},
+        {"DML_OPERATOR_ELEMENT_WISE_ERF", DML_OPERATOR_ELEMENT_WISE_ERF},
+        {"DML_OPERATOR_ELEMENT_WISE_SINH", DML_OPERATOR_ELEMENT_WISE_SINH},
+        {"DML_OPERATOR_ELEMENT_WISE_COSH", DML_OPERATOR_ELEMENT_WISE_COSH},
+        {"DML_OPERATOR_ELEMENT_WISE_TANH", DML_OPERATOR_ELEMENT_WISE_TANH},
+        {"DML_OPERATOR_ELEMENT_WISE_ASINH", DML_OPERATOR_ELEMENT_WISE_ASINH},
+        {"DML_OPERATOR_ELEMENT_WISE_ACOSH", DML_OPERATOR_ELEMENT_WISE_ACOSH},
+        {"DML_OPERATOR_ELEMENT_WISE_ATANH", DML_OPERATOR_ELEMENT_WISE_ATANH},
+        {"DML_OPERATOR_ELEMENT_WISE_IF", DML_OPERATOR_ELEMENT_WISE_IF},
+        {"DML_OPERATOR_ELEMENT_WISE_ADD1", DML_OPERATOR_ELEMENT_WISE_ADD1},
+        {"DML_OPERATOR_ACTIVATION_SHRINK", DML_OPERATOR_ACTIVATION_SHRINK},
+        {"DML_OPERATOR_MAX_POOLING1", DML_OPERATOR_MAX_POOLING1},
+        {"DML_OPERATOR_MAX_UNPOOLING", DML_OPERATOR_MAX_UNPOOLING},
+        {"DML_OPERATOR_DIAGONAL_MATRIX", DML_OPERATOR_DIAGONAL_MATRIX},
+        {"DML_OPERATOR_SCATTER", DML_OPERATOR_SCATTER},
+        {"DML_OPERATOR_ONE_HOT", DML_OPERATOR_ONE_HOT},
+        {"DML_OPERATOR_RESAMPLE", DML_OPERATOR_RESAMPLE},
+        {"DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT", DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT},
+        {"DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT", DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT},
+        {"DML_OPERATOR_ELEMENT_WISE_ROUND", DML_OPERATOR_ELEMENT_WISE_ROUND},
+        {"DML_OPERATOR_ELEMENT_WISE_IS_INFINITY", DML_OPERATOR_ELEMENT_WISE_IS_INFINITY},
+        {"DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE", DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE},
+        {"DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR", DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR},
+        {"DML_OPERATOR_FILL_VALUE_SEQUENCE", DML_OPERATOR_FILL_VALUE_SEQUENCE},
+        {"DML_OPERATOR_FILL_VALUE_CONSTANT", DML_OPERATOR_FILL_VALUE_CONSTANT},
+        {"DML_OPERATOR_CUMULATIVE_SUMMATION", DML_OPERATOR_CUMULATIVE_SUMMATION},
+        {"DML_OPERATOR_REVERSE_SUBSEQUENCES", DML_OPERATOR_REVERSE_SUBSEQUENCES},
+        {"DML_OPERATOR_GATHER_ELEMENTS", DML_OPERATOR_GATHER_ELEMENTS},
+        {"DML_OPERATOR_GATHER_ND", DML_OPERATOR_GATHER_ND},
+        {"DML_OPERATOR_SCATTER_ND", DML_OPERATOR_SCATTER_ND},
+        {"DML_OPERATOR_MAX_POOLING2", DML_OPERATOR_MAX_POOLING2},
+        {"DML_OPERATOR_SLICE1", DML_OPERATOR_SLICE1},
+        {"DML_OPERATOR_TOP_K1", DML_OPERATOR_TOP_K1},
+        {"DML_OPERATOR_DEPTH_TO_SPACE1", DML_OPERATOR_DEPTH_TO_SPACE1},
+        {"DML_OPERATOR_SPACE_TO_DEPTH1", DML_OPERATOR_SPACE_TO_DEPTH1},
+        {"DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1", DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1},
+        {"DML_OPERATOR_RESAMPLE1", DML_OPERATOR_RESAMPLE1},
+        {"DML_OPERATOR_MATRIX_MULTIPLY_INTEGER", DML_OPERATOR_MATRIX_MULTIPLY_INTEGER},
+        {"DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY", DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY},
+        {"DML_OPERATOR_CONVOLUTION_INTEGER", DML_OPERATOR_CONVOLUTION_INTEGER},
+        {"DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION", DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION},
+        {"DML_OPERATOR_ELEMENT_WISE_BIT_AND", DML_OPERATOR_ELEMENT_WISE_BIT_AND},
+        {"DML_OPERATOR_ELEMENT_WISE_BIT_OR", DML_OPERATOR_ELEMENT_WISE_BIT_OR},
+        {"DML_OPERATOR_ELEMENT_WISE_BIT_XOR", DML_OPERATOR_ELEMENT_WISE_BIT_XOR},
+        {"DML_OPERATOR_ELEMENT_WISE_BIT_NOT", DML_OPERATOR_ELEMENT_WISE_BIT_NOT},
+        {"DML_OPERATOR_ELEMENT_WISE_BIT_COUNT", DML_OPERATOR_ELEMENT_WISE_BIT_COUNT},
+        {"DML_OPERATOR_ACTIVATION_RELU_GRAD", DML_OPERATOR_ACTIVATION_RELU_GRAD},
+        {"DML_OPERATOR_AVERAGE_POOLING_GRAD", DML_OPERATOR_AVERAGE_POOLING_GRAD},
+        {"DML_OPERATOR_MAX_POOLING_GRAD", DML_OPERATOR_MAX_POOLING_GRAD},
+        {"DML_OPERATOR_RANDOM_GENERATOR", DML_OPERATOR_RANDOM_GENERATOR},
+        {"DML_OPERATOR_NONZERO_COORDINATES", DML_OPERATOR_NONZERO_COORDINATES},
+        {"DML_OPERATOR_RESAMPLE_GRAD", DML_OPERATOR_RESAMPLE_GRAD},
+        {"DML_OPERATOR_SLICE_GRAD", DML_OPERATOR_SLICE_GRAD},
+        {"DML_OPERATOR_ADAM_OPTIMIZER", DML_OPERATOR_ADAM_OPTIMIZER},
+        {"DML_OPERATOR_ARGMIN", DML_OPERATOR_ARGMIN},
+        {"DML_OPERATOR_ARGMAX", DML_OPERATOR_ARGMAX},
+        {"DML_OPERATOR_ROI_ALIGN", DML_OPERATOR_ROI_ALIGN},
+        {"DML_OPERATOR_GATHER_ND1", DML_OPERATOR_GATHER_ND1},
+        {"DML_OPERATOR_ELEMENT_WISE_ATAN_YX", DML_OPERATOR_ELEMENT_WISE_ATAN_YX},
+        {"DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD", DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD},
+        {"DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE", DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE},
+        {"DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD", DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD},
+        {"DML_OPERATOR_CUMULATIVE_PRODUCT", DML_OPERATOR_CUMULATIVE_PRODUCT},
+        {"DML_OPERATOR_BATCH_NORMALIZATION_GRAD", DML_OPERATOR_BATCH_NORMALIZATION_GRAD},
+        {"DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD", DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD},
+        {"DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD", DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD},
+        {"DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR", DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR},
+        {"DML_OPERATOR_ROI_ALIGN1", DML_OPERATOR_ROI_ALIGN1},
+        {"DML_OPERATOR_ELEMENT_WISE_CLIP1", DML_OPERATOR_ELEMENT_WISE_CLIP1},
+        {"DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1", DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1},
+        {"DML_OPERATOR_ELEMENT_WISE_NEGATE", DML_OPERATOR_ELEMENT_WISE_NEGATE},
+        {"DML_OPERATOR_ACTIVATION_GELU", DML_OPERATOR_ACTIVATION_GELU},
+        {"DML_OPERATOR_ACTIVATION_SWISH", DML_OPERATOR_ACTIVATION_SWISH},
+        {"DML_OPERATOR_ACTIVATION_HARD_SWISH", DML_OPERATOR_ACTIVATION_HARD_SWISH},
+        {"DML_OPERATOR_RESAMPLE2", DML_OPERATOR_RESAMPLE2},
+        {"DML_OPERATOR_RESAMPLE_GRAD1", DML_OPERATOR_RESAMPLE_GRAD1},
+        {"DML_OPERATOR_DIAGONAL_MATRIX1", DML_OPERATOR_DIAGONAL_MATRIX1},
+        {"DML_OPERATOR_MULTIHEAD_ATTENTION", DML_OPERATOR_MULTIHEAD_ATTENTION},
+        {"DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING", DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING},
+        {"DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT", DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT},
+    };
+    auto index = StringUtil::MapToIndex(value, mapping);
+    if (!index)
+    {
+        assert(false);
+        return static_cast<DML_OPERATOR_TYPE>(0);
+    }
+    return static_cast<DML_OPERATOR_TYPE>(*index);
+}
+
+
+template <>
+DML_BINDING_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value)
+{
+    constexpr StringUtil::NameAndIndex mapping[] =
+    {
+        {"DML_BINDING_TYPE_NONE", DML_BINDING_TYPE_NONE},
+        {"DML_BINDING_TYPE_BUFFER", DML_BINDING_TYPE_BUFFER},
+        {"DML_BINDING_TYPE_BUFFER_ARRAY", DML_BINDING_TYPE_BUFFER_ARRAY},
+    };
+    auto index = StringUtil::MapToIndex(value, mapping);
+    if (!index)
+    {
+        assert(false);
+        return static_cast<DML_BINDING_TYPE>(0);
+    }
+    return static_cast<DML_BINDING_TYPE>(*index);
+}
+
+
+template <>
+DML_REDUCE_FUNCTION ApiTraits::StringifyHelpers::FromString(std::string_view value)
+{
+    constexpr StringUtil::NameAndIndex mapping[] =
+    {
+        {"DML_REDUCE_FUNCTION_ARGMAX", DML_REDUCE_FUNCTION_ARGMAX},
+        {"DML_REDUCE_FUNCTION_ARGMIN", DML_REDUCE_FUNCTION_ARGMIN},
+        {"DML_REDUCE_FUNCTION_AVERAGE", DML_REDUCE_FUNCTION_AVERAGE},
+        {"DML_REDUCE_FUNCTION_L1", DML_REDUCE_FUNCTION_L1},
+        {"DML_REDUCE_FUNCTION_L2", DML_REDUCE_FUNCTION_L2},
+        {"DML_REDUCE_FUNCTION_LOG_SUM", DML_REDUCE_FUNCTION_LOG_SUM},
+        {"DML_REDUCE_FUNCTION_LOG_SUM_EXP", DML_REDUCE_FUNCTION_LOG_SUM_EXP},
+        {"DML_REDUCE_FUNCTION_MAX", DML_REDUCE_FUNCTION_MAX},
+        {"DML_REDUCE_FUNCTION_MIN", DML_REDUCE_FUNCTION_MIN},
+        {"DML_REDUCE_FUNCTION_MULTIPLY", DML_REDUCE_FUNCTION_MULTIPLY},
+        {"DML_REDUCE_FUNCTION_SUM", DML_REDUCE_FUNCTION_SUM},
+        {"DML_REDUCE_FUNCTION_SUM_SQUARE", DML_REDUCE_FUNCTION_SUM_SQUARE},
+    };
+    auto index = StringUtil::MapToIndex(value, mapping);
+    if (!index)
+    {
+        assert(false);
+        return static_cast<DML_REDUCE_FUNCTION>(0);
+    }
+    return static_cast<DML_REDUCE_FUNCTION>(*index);
+}
+
+template <>
+DML_MATRIX_TRANSFORM ApiTraits::StringifyHelpers::FromString(std::string_view value)
+{
+    constexpr StringUtil::NameAndIndex mapping[] =
+    {
+        {"DML_MATRIX_TRANSFORM_NONE", DML_MATRIX_TRANSFORM_NONE},
+        {"DML_MATRIX_TRANSFORM_TRANSPOSE", DML_MATRIX_TRANSFORM_TRANSPOSE},
+    };
+    auto index = StringUtil::MapToIndex(value, mapping);
+    if (!index)
+    {
+        assert(false);
+        return static_cast<DML_MATRIX_TRANSFORM>(0);
+    }
+    return static_cast<DML_MATRIX_TRANSFORM>(*index);
+}
+
+
+template <>
+DML_CONVOLUTION_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value)
+{
+    constexpr StringUtil::NameAndIndex mapping[] =
+    {
+        {"DML_CONVOLUTION_MODE_CONVOLUTION", DML_CONVOLUTION_MODE_CONVOLUTION},
+        {"DML_CONVOLUTION_MODE_CROSS_CORRELATION", DML_CONVOLUTION_MODE_CROSS_CORRELATION},
+    };
+    auto index = StringUtil::MapToIndex(value, mapping);
+    if (!index)
+    {
+        assert(false);
+        return static_cast<DML_CONVOLUTION_MODE>(0);
+    }
+    return static_cast<DML_CONVOLUTION_MODE>(*index);
+}
+
+
+template <>
+DML_CONVOLUTION_DIRECTION ApiTraits::StringifyHelpers::FromString(std::string_view value)
+{
+    constexpr StringUtil::NameAndIndex mapping[] =
+    {
+        {"DML_CONVOLUTION_DIRECTION_FORWARD", DML_CONVOLUTION_DIRECTION_FORWARD},
+        {"DML_CONVOLUTION_DIRECTION_BACKWARD", DML_CONVOLUTION_DIRECTION_BACKWARD},
+    };
+    auto index = StringUtil::MapToIndex(value, mapping);
+    if (!index)
+    {
+        assert(false);
+        return static_cast<DML_CONVOLUTION_DIRECTION>(0);
+    }
+    return static_cast<DML_CONVOLUTION_DIRECTION>(*index);
+}
+
+template <>
+DML_PADDING_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value)
+{
+    constexpr StringUtil::NameAndIndex mapping[] =
+    {
+        {"DML_PADDING_MODE_CONSTANT", DML_PADDING_MODE_CONSTANT},
+        {"DML_PADDING_MODE_EDGE", DML_PADDING_MODE_EDGE},
+        {"DML_PADDING_MODE_REFLECTION", DML_PADDING_MODE_REFLECTION},
+        {"DML_PADDING_MODE_SYMMETRIC", DML_PADDING_MODE_SYMMETRIC},
+    };
+    auto index = StringUtil::MapToIndex(value, mapping);
+    if (!index)
+    {
+        assert(false);
+        return static_cast<DML_PADDING_MODE>(0);
+    }
+    return static_cast<DML_PADDING_MODE>(*index);
+}
+
+
+template <>
+DML_INTERPOLATION_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value)
+{
+    constexpr StringUtil::NameAndIndex mapping[] =
+    {
+        {"DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR", DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR},
+        {"DML_INTERPOLATION_MODE_LINEAR", DML_INTERPOLATION_MODE_LINEAR},
+    };
+    auto index = StringUtil::MapToIndex(value, mapping);
+    if (!index)
+    {
+        assert(false);
+        return static_cast<DML_INTERPOLATION_MODE>(0);
+    }
+    return static_cast<DML_INTERPOLATION_MODE>(*index);
+}
+
+
+template <>
+DML_RECURRENT_NETWORK_DIRECTION ApiTraits::StringifyHelpers::FromString(std::string_view value)
+{
+    constexpr StringUtil::NameAndIndex mapping[] =
+    {
+        {"DML_RECURRENT_NETWORK_DIRECTION_FORWARD", DML_RECURRENT_NETWORK_DIRECTION_FORWARD},
+        {"DML_RECURRENT_NETWORK_DIRECTION_BACKWARD", DML_RECURRENT_NETWORK_DIRECTION_BACKWARD},
+        {"DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL", DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL},
+    };
+    auto index = StringUtil::MapToIndex(value, mapping);
+    if (!index)
+    {
+        assert(false);
+        return static_cast<DML_RECURRENT_NETWORK_DIRECTION>(0);
+    }
+    return static_cast<DML_RECURRENT_NETWORK_DIRECTION>(*index);
+}
+
+
+template <>
+DML_FEATURE ApiTraits::StringifyHelpers::FromString(std::string_view value)
+{
+    constexpr StringUtil::NameAndIndex mapping[] =
+    {
+        {"DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT", DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT},
+        {"DML_FEATURE_FEATURE_LEVELS", DML_FEATURE_FEATURE_LEVELS},
+    };
+    auto index = StringUtil::MapToIndex(value, mapping);
+    if (!index)
+    {
+        assert(false);
+        return static_cast<DML_FEATURE>(0);
+    }
+    return static_cast<DML_FEATURE>(*index);
+}
+
+
+template <>
+DML_FEATURE_LEVEL ApiTraits::StringifyHelpers::FromString(std::string_view value)
+{
+    constexpr StringUtil::NameAndIndex mapping[] =
+    {
+        {"DML_FEATURE_LEVEL_1_0", DML_FEATURE_LEVEL_1_0},
+        {"DML_FEATURE_LEVEL_2_0", DML_FEATURE_LEVEL_2_0},
+        {"DML_FEATURE_LEVEL_2_1", DML_FEATURE_LEVEL_2_1},
+        {"DML_FEATURE_LEVEL_3_0", DML_FEATURE_LEVEL_3_0},
+        {"DML_FEATURE_LEVEL_3_1", DML_FEATURE_LEVEL_3_1},
+        {"DML_FEATURE_LEVEL_4_0", DML_FEATURE_LEVEL_4_0},
+        {"DML_FEATURE_LEVEL_4_1", DML_FEATURE_LEVEL_4_1},
+        {"DML_FEATURE_LEVEL_5_0", DML_FEATURE_LEVEL_5_0},
+        {"DML_FEATURE_LEVEL_5_1", DML_FEATURE_LEVEL_5_1},
+        {"DML_FEATURE_LEVEL_5_2", DML_FEATURE_LEVEL_5_2},
+        {"DML_FEATURE_LEVEL_6_0", DML_FEATURE_LEVEL_6_0},
+        {"DML_FEATURE_LEVEL_6_1", DML_FEATURE_LEVEL_6_1},
+        {"DML_FEATURE_LEVEL_6_2", DML_FEATURE_LEVEL_6_2},
+    };
+    auto index = StringUtil::MapToIndex(value, mapping);
+    if (!index)
+    {
+        assert(false);
+        return static_cast<DML_FEATURE_LEVEL>(0);
+    }
+    return static_cast<DML_FEATURE_LEVEL>(*index);
+}
+
+
+template <>
+DML_IS_INFINITY_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value)
+{
+    constexpr StringUtil::NameAndIndex mapping[] =
+    {
+        {"DML_IS_INFINITY_MODE_EITHER", DML_IS_INFINITY_MODE_EITHER},
+        {"DML_IS_INFINITY_MODE_POSITIVE", DML_IS_INFINITY_MODE_POSITIVE},
+        {"DML_IS_INFINITY_MODE_NEGATIVE", DML_IS_INFINITY_MODE_NEGATIVE},
+    };
+    auto index = StringUtil::MapToIndex(value, mapping);
+    if (!index)
+    {
+        assert(false);
+        return static_cast<DML_IS_INFINITY_MODE>(0);
+    }
+    return static_cast<DML_IS_INFINITY_MODE>(*index);
+}
+
+
+template <>
+DML_DEPTH_SPACE_ORDER ApiTraits::StringifyHelpers::FromString(std::string_view value)
+{
+    constexpr StringUtil::NameAndIndex mapping[] =
+    {
+        {"DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW", DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW},
+        {"DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH", DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH},
+    };
+    auto index = StringUtil::MapToIndex(value, mapping);
+    if (!index)
+    {
+        assert(false);
+        return static_cast<DML_DEPTH_SPACE_ORDER>(0);
+    }
+    return static_cast<DML_DEPTH_SPACE_ORDER>(*index);
+}
+
+
+template <>
+DML_AXIS_DIRECTION ApiTraits::StringifyHelpers::FromString(std::string_view value)
+{
+    constexpr StringUtil::NameAndIndex mapping[] =
+    {
+        {"DML_AXIS_DIRECTION_INCREASING", DML_AXIS_DIRECTION_INCREASING},
+        {"DML_AXIS_DIRECTION_DECREASING", DML_AXIS_DIRECTION_DECREASING},
+    };
+    auto index = StringUtil::MapToIndex(value, mapping);
+    if (!index)
+    {
+        assert(false);
+        return static_cast<DML_AXIS_DIRECTION>(0);
+    }
+    return static_cast<DML_AXIS_DIRECTION>(*index);
+}
+
+
+template <>
+DML_ROUNDING_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value)
+{
+    constexpr StringUtil::NameAndIndex mapping[] =
+    {
+        {"DML_ROUNDING_MODE_HALVES_TO_NEAREST_EVEN", DML_ROUNDING_MODE_HALVES_TO_NEAREST_EVEN},
+        {"DML_ROUNDING_MODE_TOWARD_ZERO", DML_ROUNDING_MODE_TOWARD_ZERO},
+        {"DML_ROUNDING_MODE_TOWARD_INFINITY", DML_ROUNDING_MODE_TOWARD_INFINITY},
+    };
+    auto index = StringUtil::MapToIndex(value, mapping);
+    if (!index)
+    {
+        assert(false);
+        return static_cast<DML_ROUNDING_MODE>(0);
+    }
+    return static_cast<DML_ROUNDING_MODE>(*index);
+}
+
+
+template <>
+DML_RANDOM_GENERATOR_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value)
+{
+    constexpr StringUtil::NameAndIndex mapping[] =
+    {
+        {"DML_RANDOM_GENERATOR_TYPE_PHILOX_4X32_10", DML_RANDOM_GENERATOR_TYPE_PHILOX_4X32_10},
+    };
+    auto index = StringUtil::MapToIndex(value, mapping);
+    if (!index)
+    {
+        assert(false);
+        return static_cast<DML_RANDOM_GENERATOR_TYPE>(0);
+    }
+    return static_cast<DML_RANDOM_GENERATOR_TYPE>(*index);
+}
+
+
+template <>
+DML_MULTIHEAD_ATTENTION_MASK_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value)
+{
+    constexpr StringUtil::NameAndIndex mapping[] =
+    {
+        {"DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE", DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE},
+        {"DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH", DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH},
+        {"DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START", DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START},
+        {"DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END", DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END},
+        {"DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN", DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN},
+    };
+    auto index = StringUtil::MapToIndex(value, mapping);
+    if (!index)
+    {
+        assert(false);
+        return static_cast<DML_MULTIHEAD_ATTENTION_MASK_TYPE>(0);
+    }
+    return static_cast<DML_MULTIHEAD_ATTENTION_MASK_TYPE>(*index);
+}
+
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp
new file mode 100644
index 0000000000000..7d8ed17e7d925
--- /dev/null
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp
@@ -0,0 +1,554 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+
+#pragma once
+#include "precomp.h"
+
+OperatorFieldVariant CreateAttribute(
+    const DML_SCHEMA_FIELD* schemaField,
+    const dml::ir::operatorFieldTypes::AttributeDesc* attributeDesc);
+
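+// Reconstructs a fused-activation operator desc (schema plus field values) from its serialized flatbuffer form.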
+OperatorFieldVariant CreateActivation(
+    const dml::ir::operatorFieldTypes::Activation* activationDesc)
+{
+    DML_OPERATOR_TYPE activationOperatorType = ApiTraits::StringifyHelpers::FromString<DML_OPERATOR_TYPE>(activationDesc->type()->c_str());
+    const DML_OPERATOR_SCHEMA& activationSchema = SchemaHelpers::GetSchema(activationOperatorType);
+    std::vector<OperatorField> activationOperatorFields(activationSchema.FieldCount);
+    uint32_t attributeIndex = 0;
+
+    for (uint32_t fieldIndex = 0; fieldIndex < activationSchema.FieldCount; fieldIndex++)
+    {
+        const DML_SCHEMA_FIELD* schemaField = &activationSchema.Fields[fieldIndex];
+        OperatorFieldVariant field;
+        switch (schemaField->Kind)
+        {
+            case DML_SCHEMA_FIELD_KIND_INPUT_TENSOR:
+            case DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR:
+            {
+                if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC)
+                {
+                    field = OperatorFieldTypes::TensorDesc();
+                }
+                else if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY)
+                {
+                    field = OperatorFieldTypes::TensorDescArray();
+                }
+                break;
+            }
+            case DML_SCHEMA_FIELD_KIND_ATTRIBUTE:
+            {
+                const dml::ir::operatorFieldTypes::AttributeDesc* attributeDesc = 
+                    attributeIndex >= activationDesc->attributes()->size() ?
+                    nullptr : 
+                    activationDesc->attributes()->Get(attributeIndex++);
+                field = CreateAttribute(schemaField, attributeDesc);
+                break;
+            }
+        }
+
+        activationOperatorFields[fieldIndex] = OperatorField(schemaField, std::move(field));
+    }
+
+    return AbstractOperatorDesc(&activationSchema, std::move(activationOperatorFields));
+}
+
+OperatorFieldVariant CreateActivations(
+    const dml::ir::operatorFieldTypes::ActivationArray* activationDescs)
+{
+    std::vector<AbstractOperatorDesc> activations;
+    for (uint32_t index = 0; index < static_cast<uint32_t>(activationDescs->data()->size()); index++)
+    {
+        OperatorFieldVariant activation = CreateActivation(activationDescs->data()->Get(index));
+        activations.push_back(std::get<OperatorFieldTypes::FusedActivationOperatorDesc>(activation).value());
+    }
+    return activations;
+}
+
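+// Converts a serialized attribute into the OperatorFieldVariant matching the schema field's type;
+// a missing (null) attribute yields a default-constructed value.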
+OperatorFieldVariant CreateAttribute(
+    const DML_SCHEMA_FIELD* schemaField,
+    const dml::ir::operatorFieldTypes::AttributeDesc* attributeDesc)
+{
+    switch (schemaField->Type)
+    {
+        case DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC:
+        {
+            return attributeDesc != nullptr && attributeDesc->val_as_Activation() != nullptr ?  
+                CreateActivation(attributeDesc->val_as_Activation()) : 
+                OperatorFieldTypes::FusedActivationOperatorDesc();
+        }
+        case DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY:
+        {
+            return attributeDesc != nullptr && attributeDesc->val_as_ActivationArray() != nullptr ?  
+                CreateActivations(attributeDesc->val_as_ActivationArray()) : 
+                OperatorFieldTypes::FusedActivationOperatorDescArray();
+        }
+        case DML_SCHEMA_FIELD_TYPE_UINT:
+        {
+            OperatorFieldTypes::UInt data;
+            if (attributeDesc != nullptr)
+            {
+                data = attributeDesc->val_as_UInt32()->data();
+            }
+            return data;
+        }
+        case DML_SCHEMA_FIELD_TYPE_UINT64:
+        {
+            OperatorFieldTypes::UInt64 data;
+            if (attributeDesc != nullptr)
+            {
+                data = attributeDesc->val_as_UInt64()->data();
+            }
+            return data;
+        }
+        case DML_SCHEMA_FIELD_TYPE_INT:
+        {
+            OperatorFieldTypes::Int data;
+            if (attributeDesc != nullptr)
+            {
+                data = attributeDesc->val_as_Int32()->data();
+            }
+            return data;
+        }
+        case DML_SCHEMA_FIELD_TYPE_FLOAT:
+        {
+            OperatorFieldTypes::Float data;
+            if (attributeDesc != nullptr)
+            {
+                data = attributeDesc->val_as_Float32()->data();
+            }
+            return data;
+        }
+        case DML_SCHEMA_FIELD_TYPE_UINT_ARRAY:
+        {
+            OperatorFieldTypes::UIntArray data;
+            if (attributeDesc != nullptr)
+            {
+                data.assign(attributeDesc->val_as_UIntArray()->data()->begin(), attributeDesc->val_as_UIntArray()->data()->end());
+            }
+            return data;
+        }
+        case DML_SCHEMA_FIELD_TYPE_INT_ARRAY:
+        {
+            OperatorFieldTypes::IntArray data;
+            if (attributeDesc != nullptr)
+            {
+                data.assign(attributeDesc->val_as_IntArray()->data()->begin(), attributeDesc->val_as_IntArray()->data()->end());
+            }
+            return data;
+        }
+        case DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY:
+        {
+            OperatorFieldTypes::FloatArray data;
+            if (attributeDesc != nullptr)
+            {
+                data.assign(attributeDesc->val_as_FloatArray()->data()->begin(), attributeDesc->val_as_FloatArray()->data()->end());
+            }
+            return data;
+        }	
+        case DML_SCHEMA_FIELD_TYPE_SCALE_BIAS:
+        {
+            OperatorFieldTypes::ScaleBias scaleBias;
+            if (attributeDesc != nullptr)
+            {
+                const dml::ir::operatorFieldTypes::ScaleBias* scaleBiasAttribute = attributeDesc->val_as_ScaleBias();
+                if (scaleBiasAttribute != nullptr)
+                {
+                    scaleBias = {scaleBiasAttribute->scale(), scaleBiasAttribute->bias()};
+                }
+            }
+            return scaleBias;
+        }
+        case DML_SCHEMA_FIELD_TYPE_SIZE_2D:
+        {
+            OperatorFieldTypes::Size2D size2d = {};
+            if (attributeDesc != nullptr)
+            {
+                size2d.Height = attributeDesc->val_as_Size2D()->height();
+                size2d.Width = attributeDesc->val_as_Size2D()->width();
+            }
+            return size2d;
+        }
+        case DML_SCHEMA_FIELD_TYPE_SCALAR_UNION:
+        {
+            DML_SCALAR_UNION scalarUnion = {};
+            if (attributeDesc != nullptr)
+            {
+                const dml::ir::operatorFieldTypes::ByteArray* byteArr = attributeDesc->val_as_ScalarUnionData()->data_as_ByteArray();
+                std::copy(byteArr->data()->begin(), byteArr->data()->end(), scalarUnion.Bytes);
+            }
+            return scalarUnion;
+        }
+        case DML_SCHEMA_FIELD_TYPE_BOOL:
+        {
+            OperatorFieldTypes::Bool data;
+            if (attributeDesc != nullptr)
+            {
+                data = attributeDesc->val_as_Bool()->data();
+            }
+            return data;
+        }
+        default:
+        {
+            throw std::invalid_argument("Invalid attribute type.");
+        }
+    }
+}
+
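+// Builds a DmlBufferTensorDesc from its serialized form; constant tensors are flagged DML_TENSOR_FLAG_OWNED_BY_DML.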
+OperatorFieldTypes::TensorDesc CreateBufferTensorDesc(
+    const dml::ir::DmlBufferTensorDesc* tensorDesc,
+    const bool isConstantTensor = false)
+{
+    DmlBufferTensorDesc bufferTensorDesc = {};
+    bufferTensorDesc.dataType = ApiTraits::StringifyHelpers::FromString<DML_TENSOR_DATA_TYPE>(tensorDesc->dataType()->c_str());
+    if (isConstantTensor)
+    {
+        bufferTensorDesc.flags = DML_TENSOR_FLAG_OWNED_BY_DML;
+    }
+    bufferTensorDesc.sizes.assign(tensorDesc->sizes()->begin(), tensorDesc->sizes()->end());
+    if (flatbuffers::IsFieldPresent(tensorDesc, dml::ir::DmlBufferTensorDesc::VT_STRIDES))
+    {
+        bufferTensorDesc.strides.emplace(tensorDesc->strides()->begin(), tensorDesc->strides()->end());
+    }
+    bufferTensorDesc.totalTensorSizeInBytes = tensorDesc->totalTensorSizeInBytes();
+    return bufferTensorDesc;
+}
+
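+// Walks the operator schema for the node's type and fills each field from the serialized inputs,
+// outputs, and attributes; empty input/output names correspond to absent optional tensors.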
+AbstractOperatorDesc CreateAbstractOperatorDesc(
+    uint32_t nodeIndex,
+    const dml::ir::OperatorNodeDesc* flatbufferOperatorNodeDesc,
+    const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>* nodeInputNames,
+    const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>* nodeOutputNames,
+    const std::unordered_set<std::string_view>& constantInputs)
+{
+    DML_OPERATOR_TYPE type = ApiTraits::StringifyHelpers::FromString<DML_OPERATOR_TYPE>(flatbufferOperatorNodeDesc->type()->c_str());
+    if (type == DML_OPERATOR_INVALID)
+    {
+        throw std::invalid_argument("Graph operator node at index:" + std::to_string(nodeIndex) +
+                                    " has an empty or invalid operator type.");
+    }
+    const DML_OPERATOR_SCHEMA& schema = SchemaHelpers::GetSchema(type);
+    std::vector<OperatorField> operatorFields(schema.FieldCount);
+    
+    auto inputNameItr = nodeInputNames->begin();
+    uint32_t inputTensorDescIndex = 0;
+    
+    uint32_t outputTensorDescIndex = 0;
+    auto outputNameItr = nodeOutputNames->begin();
+
+    uint32_t attributeIndex = 0;
+    
+
+    for (uint32_t fieldIndex = 0; fieldIndex < schema.FieldCount; fieldIndex++)
+    {
+        const DML_SCHEMA_FIELD* schemaField = &schema.Fields[fieldIndex];
+        
+        OperatorFieldVariant field;
+        switch (schemaField->Kind)
+        {
+            case DML_SCHEMA_FIELD_KIND_INPUT_TENSOR:
+            {
+                if (inputNameItr == nodeInputNames->end())
+                {
+                    throw std::invalid_argument("Missing input names for node at index:" + std::to_string(nodeIndex));
+                }
+
+                if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC)
+                {
+                    const flatbuffers::String* inputName = *inputNameItr;
+                    inputNameItr++;
+                    if (inputName->size() == 0)
+                    {
+                        field = OperatorFieldTypes::TensorDesc();
+                        break;
+                    }
+                    bool isConstantTensor = !constantInputs.empty() && constantInputs.find(inputName->c_str()) != constantInputs.end();
+
+                    if (flatbufferOperatorNodeDesc->inputs()->size() <= inputTensorDescIndex)
+                    {
+                        throw std::invalid_argument("Expecting at least " + std::to_string(inputTensorDescIndex + 1) + 
+                                                    " input tensor desc for graph operator node at index:" + std::to_string(nodeIndex));
+                    }
+                    const dml::ir::DmlBufferTensorDesc* tensorDesc = flatbufferOperatorNodeDesc->inputs()->Get(inputTensorDescIndex++);
+                    field = CreateBufferTensorDesc(tensorDesc, isConstantTensor);
+                }
+                else if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY)
+                {
+                    std::vector<DmlBufferTensorDesc> tensors;
+                    while (inputTensorDescIndex < static_cast<uint32_t>(flatbufferOperatorNodeDesc->inputs()->size()))
+                    {
+                        const flatbuffers::String* inputName = *inputNameItr;
+                        inputNameItr++;
+                        bool isConstantTensor = !constantInputs.empty() && constantInputs.find(inputName->c_str()) != constantInputs.end();
+                        
+                        if (flatbufferOperatorNodeDesc->inputs()->size() <= inputTensorDescIndex)
+                        {
+                            throw std::invalid_argument("Expecting at least " + std::to_string(inputTensorDescIndex + 1) + 
+                                                        " input tensor desc for graph operator node at index:" + std::to_string(nodeIndex));
+                        }
+                        const dml::ir::DmlBufferTensorDesc* tensorDesc = flatbufferOperatorNodeDesc->inputs()->Get(inputTensorDescIndex++);
+                        tensors.push_back(CreateBufferTensorDesc(tensorDesc, isConstantTensor).value());
+                    }
+                    field = tensors;
+                }
+                break;
+            }
+            case DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR:
+            {
+                if (outputNameItr == nodeOutputNames->end())
+                {
+                    throw std::invalid_argument("Missing output names for node at index:" + std::to_string(nodeIndex));
+                }
+
+                if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC)
+                {
+                    const flatbuffers::String* outputName = *outputNameItr;
+                    outputNameItr++;
+
+                    if (outputName->size() == 0)
+                    {
+                        field = OperatorFieldTypes::TensorDesc();
+                        break;
+                    }
+
+                    if (flatbufferOperatorNodeDesc->outputs()->size() <= outputTensorDescIndex)
+                    {
+                        throw std::invalid_argument("Expecting at least " + std::to_string(outputTensorDescIndex + 1) + 
+                                                    " output tensor desc for graph operator node at index:" + std::to_string(nodeIndex));
+                    }
+                    const dml::ir::DmlBufferTensorDesc* tensorDesc = flatbufferOperatorNodeDesc->outputs()->Get(outputTensorDescIndex++);
+                    field = CreateBufferTensorDesc(tensorDesc);
+                }
+                else if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY)
+                {
+                    std::vector<DmlBufferTensorDesc> tensors;
+                    while (outputTensorDescIndex < static_cast<uint32_t>(flatbufferOperatorNodeDesc->outputs()->size()))
+                    {
+                        if (flatbufferOperatorNodeDesc->outputs()->size() <= outputTensorDescIndex)
+                        {
+                            throw std::invalid_argument("Expecting at least " + std::to_string(outputTensorDescIndex + 1) + 
+                                                        " output tensor desc for graph operator node at index:" + std::to_string(nodeIndex));
+                        }
+                        const dml::ir::DmlBufferTensorDesc* tensorDesc = flatbufferOperatorNodeDesc->outputs()->Get(outputTensorDescIndex++);
+                        tensors.push_back(CreateBufferTensorDesc(tensorDesc).value());
+                    }
+                    field = tensors;
+                }
+                break;
+            }
+            case DML_SCHEMA_FIELD_KIND_ATTRIBUTE:
+            {
+                if (flatbufferOperatorNodeDesc->attributes()->size() <= attributeIndex)
+                {
+                    throw std::invalid_argument("Expecting at least " + std::to_string(attributeIndex + 1) + 
+                                                " attributes for graph operator node at index:" + std::to_string(nodeIndex));
+                }
+                const dml::ir::operatorFieldTypes::AttributeDesc* attributeDesc = 
+                    attributeIndex >= flatbufferOperatorNodeDesc->attributes()->size() ?
+                    nullptr : 
+                    flatbufferOperatorNodeDesc->attributes()->Get(attributeIndex++);
+                field = CreateAttribute(schemaField, attributeDesc);
+                break;
+            }
+        }
+
+        operatorFields[fieldIndex] = OperatorField(schemaField, std::move(field));
+    }
+
+    return AbstractOperatorDesc(&schema, std::move(operatorFields));
+}
+
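+// Maps each non-empty edge name in the list to its positional index; empty names (absent optional edges) are skipped.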
+std::unordered_map<std::string_view, uint32_t> ConvertToEdgeNameToIndexMap(
+    const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>* list)
+{
+    std::unordered_map<std::string_view, uint32_t> nameToIndexMap;
+    for (uint32_t index = 0; index < list->size(); index++)
+    {
+        const flatbuffers::String* name = list->GetAsString(index);
+        if (name->size() == 0)
+        {
+            continue;
+        }
+        nameToIndexMap[name->string_view()] = index;
+    }
+    return nameToIndexMap; // NRVO (or implicit move on return) applies here; no need for std::move.
+}
+
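+// Classifies a node's named edges: names present in edgeNameToIndexMap become graph input/output edges,
+// while other names become intermediate edges, resolved through edgeToOutgoingNodeIndexMap, which records
+// the node/output index that produces each name.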
+template <typename EdgeType> void PopulateEdges(
+    const uint32_t nodeIndex,
+    const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>* edgeNames,
+    const std::unordered_map<std::string_view, uint32_t>& edgeNameToIndexMap,
+    /*out*/ std::vector<EdgeType>& edges,
+    /*out*/ std::vector<DmlIntermediateSerializedGraphEdge>& intermediateEdges,
+    /*out*/ std::unordered_map<std::string_view, NodeIndex>& edgeToOutgoingNodeIndexMap)
+{
+    for (flatbuffers::uoffset_t edgeIndex = 0; edgeIndex < edgeNames->size(); edgeIndex++)
+    {
+        const flatbuffers::String* edgeName = edgeNames->Get(edgeIndex);
+        if (edgeName->size() == 0)
+        {
+            // This must be an optional input/output.
+            continue;
+        }
+        // edge can be graphInput or graphOutput
+        if (edgeNameToIndexMap.find(edgeName->string_view()) != edgeNameToIndexMap.end())
+        {
+            EdgeType edge = {};
+            edge.Name = edgeName->str();
+            
+            if constexpr (std::is_same_v<EdgeType, DmlInputSerializedGraphEdge>)
+            {
+                edge.GraphInputIndex = edgeNameToIndexMap.at(edgeName->string_view());
+                edge.ToNodeIndex = nodeIndex;
+                edge.ToNodeInputIndex = edgeIndex;
+            }
+            else if constexpr (std::is_same_v<EdgeType, DmlOutputSerializedGraphEdge>)
+            {
+                edge.GraphOutputIndex = edgeNameToIndexMap.at(edgeName->string_view());
+                edge.FromNodeIndex = nodeIndex;
+                edge.FromNodeOutputIndex = edgeIndex;
+                edgeToOutgoingNodeIndexMap[edgeName->string_view()] = {nodeIndex, edgeIndex};
+            }
+
+            edges.push_back(edge);
+        }
+        // edge is intermediate edge
+        else 
+        {
+            if constexpr (std::is_same_v<EdgeType, DmlInputSerializedGraphEdge>)
+            {
+                if (edgeToOutgoingNodeIndexMap.find(edgeName->string_view()) == edgeToOutgoingNodeIndexMap.end())
+                {
+                    throw std::range_error("There is no graph input named " + edgeName->str() +
+                                           ", nor any node that produces " + edgeName->str() + " as an output.");
+                }
+                auto& intermediateEdgeNodeIndex = edgeToOutgoingNodeIndexMap[edgeName->string_view()];
+                DmlIntermediateSerializedGraphEdge intermediateEdge = {};
+                intermediateEdge.Name = edgeName->str();
+                intermediateEdge.FromNodeIndex = intermediateEdgeNodeIndex.nodeIndex;
+                intermediateEdge.FromNodeOutputIndex = intermediateEdgeNodeIndex.nodeOutputIndex;
+                intermediateEdge.ToNodeIndex = nodeIndex;
+                intermediateEdge.ToNodeInputIndex = edgeIndex;
+                intermediateEdges.push_back(std::move(intermediateEdge));
+            }
+            else if constexpr (std::is_same_v<EdgeType, DmlOutputSerializedGraphEdge>)
+            {
+                edgeToOutgoingNodeIndexMap[edgeName->string_view()] = {nodeIndex, edgeIndex};
+            }
+        }
+    }
+}
+
+/*
+* - Handling of empty optional input/output/attribute for a non-constant node:
+*   input/output
+*   - <DmlGraphNode.inputNames> and <DmlGraphNode.outputNames> will have a null entry,
+*      but the actual OperatorNodeDesc variant's <OperatorNodeDesc.inputs>
+*      and <OperatorNodeDesc.outputs> will not have any entry.
+*   attribute
+*   - <OperatorNodeDesc.attributes> will have a null entry
+*/
+DmlSerializedGraphDesc DeserializeDmlGraph(
+    const uint8_t* flatbufferGraphDescBlob,
+    /*out*/ std::vector<std::unique_ptr<std::byte[]>>& rawData)
+{
+    if (flatbufferGraphDescBlob == nullptr)
+    {
+        throw std::invalid_argument("Given pointer to flatbuffer blob is null");
+    }
+    const dml::ir::DmlGraphDesc* flatbufferGraphDesc = dml::ir::GetDmlGraphDesc(flatbufferGraphDescBlob);
+    
+    std::unordered_map<std::string_view, uint32_t> graphInputEdgeToIndexMap = ConvertToEdgeNameToIndexMap(flatbufferGraphDesc->graphInputNames());
+    std::unordered_map<std::string_view, uint32_t> graphOutputEdgeToIndexMap = ConvertToEdgeNameToIndexMap(flatbufferGraphDesc->graphOutputNames());
+    
+    std::unordered_map<std::string_view, NodeIndex> edgeToOutgoingNodeIndexMap;
+    std::unordered_set<std::string_view> constantInputs;
+
+    std::vector<DmlSerializedGraphNode> nodes(flatbufferGraphDesc->nodes()->size());
+    std::vector<DmlInputSerializedGraphEdge> inputEdges;
+    std::vector<DmlOutputSerializedGraphEdge> outputEdges;
+    std::vector<DmlIntermediateSerializedGraphEdge> intermediateEdges;
+
+    for (uint32_t nodeIndex = 0; nodeIndex < flatbufferGraphDesc->nodes()->size(); nodeIndex++)
+    {
+        const dml::ir::DmlGraphNode* flatbufferNode = flatbufferGraphDesc->nodes()->Get(nodeIndex);
+
+        PopulateEdges<DmlInputSerializedGraphEdge>(
+            nodeIndex,
+            flatbufferNode->inputNames(),
+            graphInputEdgeToIndexMap,
+            inputEdges,
+            intermediateEdges,
+            edgeToOutgoingNodeIndexMap);
+        PopulateEdges<DmlOutputSerializedGraphEdge>(
+            nodeIndex,
+            flatbufferNode->outputNames(),
+            graphOutputEdgeToIndexMap,
+            outputEdges,
+            intermediateEdges,
+            edgeToOutgoingNodeIndexMap);
+
+        DmlSerializedGraphNode node = {};
+        if (flatbufferNode->name()->size() == 0)
+        {
+            throw std::invalid_argument("Graph node at index:" + std::to_string(nodeIndex) + " doesn't have any name");
+        }
+        node.Name = flatbufferNode->name()->c_str();
+
+        if (flatbufferNode->desc_type() == dml::ir::NodeDesc_ConstantNodeDesc)
+        {
+            const dml::ir::ConstantNodeDesc* flatbufferConstantNode = flatbufferNode->desc_as_ConstantNodeDesc();
+            if (flatbufferConstantNode->data_type() == dml::ir::ConstantNodeDescDetail_ConstantName)
+            {
+                if (flatbufferConstantNode->data_as_ConstantName()->name()->size() == 0)
+                {
+                    throw std::invalid_argument("Constant node at index:" + std::to_string(nodeIndex) + 
+                                                " doesn't have constant data name.");
+                }
+
+                ConstantName constantNode = {flatbufferConstantNode->data_as_ConstantName()->name()->c_str()};
+                node.Desc = constantNode;
+                // The outputs of this node will be part of the constantInputs list.
+                for (uint32_t outputIndex = 0; outputIndex < flatbufferNode->outputNames()->size(); outputIndex++)
+                {
+                    constantInputs.insert(flatbufferNode->outputNames()->Get(outputIndex)->c_str());
+                }
+            }
+            else if (flatbufferConstantNode->data_type() == dml::ir::ConstantNodeDescDetail_ConstantRawData)
+            {
+                uint32_t rawDataSize = flatbufferConstantNode->data_as_ConstantRawData()->data()->size();
+                rawData.push_back(std::make_unique<std::byte[]>(rawDataSize));
+                std::transform(
+                    flatbufferConstantNode->data_as_ConstantRawData()->data()->begin(),
+                    flatbufferConstantNode->data_as_ConstantRawData()->data()->end(),
+                    rawData.back().get(),
+                    [](uint8_t b) {return static_cast<std::byte>(b);});
+
+                ConstantData constantData = {};
+                constantData.dataSize = rawDataSize;
+                constantData.data = rawData.back().get();
+                node.Desc = constantData;
+            }
+        }
+        else if (flatbufferNode->desc_type() == dml::ir::NodeDesc::NodeDesc_OperatorNodeDesc)
+        {
+            // convert dml::ir::OperatorNodeDesc to AbstractOperatorDesc
+            const dml::ir::OperatorNodeDesc* flatbufferOperatorNodeDesc = flatbufferNode->desc_as_OperatorNodeDesc();
+            node.Desc = CreateAbstractOperatorDesc(
+                nodeIndex,
+                flatbufferOperatorNodeDesc,
+                flatbufferNode->inputNames(),
+                flatbufferNode->outputNames(),
+                constantInputs);
+        }
+
+        nodes[nodeIndex] = node;
+    }
+
+    DmlSerializedGraphDesc graphDesc;
+    graphDesc.InputCount = flatbufferGraphDesc->graphInputNames()->size();
+    graphDesc.OutputCount = flatbufferGraphDesc->graphOutputNames()->size();
+    graphDesc.InputEdges = std::move(inputEdges);
+    graphDesc.IntermediateEdges = std::move(intermediateEdges);
+    graphDesc.OutputEdges = std::move(outputEdges);
+    graphDesc.Nodes = std::move(nodes);
+    return graphDesc;
+}
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp
index 642d9aa03eeef..202b762d99e01 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp
@@ -135,8 +135,10 @@ namespace DmlGraphFusionHelper
 
     void ProcessInputData(
         const ExecutionProviderImpl* providerImpl,
+        const bool graphSerializationEnabled,
         const std::vector<uint8_t>& isInputsUploadedByDmlEP,
-        const std::vector<DML_INPUT_GRAPH_EDGE_DESC>& inputEdges,
+        const std::unordered_map<uint32_t, uint32_t>* serializedGraphInputIndexToSubgraphInputIndex,
+        const std::unordered_map<std::string_view, uint32_t>* serializedGraphLargeConstantNameToSubgraphInputIndex,
         const gsl::span<const std::string> subGraphInputArgNames,
         const std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>>& initializerNameToInitializerMap,
         onnxruntime::Graph& graph,
@@ -162,8 +164,17 @@ namespace DmlGraphFusionHelper
 
         // Walk through each graph edge and mark used inputs
         inputsUsed.assign(fusedNodeInputCount, false);
-        for (const DML_INPUT_GRAPH_EDGE_DESC& edge : inputEdges) {
-            inputsUsed[edge.GraphInputIndex] = true;
+        for (auto it = serializedGraphInputIndexToSubgraphInputIndex->begin(); it != serializedGraphInputIndexToSubgraphInputIndex->end(); it++) {
+            inputsUsed[it->second] = true;
+        }
+        for (auto it = serializedGraphLargeConstantNameToSubgraphInputIndex->begin(); it != serializedGraphLargeConstantNameToSubgraphInputIndex->end(); it++) {
+            inputsUsed[it->second] = true;
+        }
+
+        std::wstring modelName;
+        if (graphSerializationEnabled)
+        {
+            modelName = GetModelName(graph.ModelPath());
         }
 
         for (uint32_t i = 0; i < initInputBindings.size(); i++)
@@ -209,6 +220,10 @@ namespace DmlGraphFusionHelper
 
                 // Tensor sizes in DML must be a multiple of 4 bytes large.
                 tensorByteSize = AlignToPow2<size_t>(tensorByteSize, 4);
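+                // When graph serialization is enabled, also dump this initializer's bytes via
+                // WriteToFile(modelName, "<initializer name>.bin", ...) so they can be inspected
+                // alongside the serialized partition files.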
+                if (graphSerializationEnabled)
+                {
+                    WriteToFile(modelName, ConvertToWString(iter->first) + L".bin", reinterpret_cast<uint8_t*>(tensorPtr), tensorByteSize);
+                }
 
                 if (inputRawData)
                 {
@@ -287,55 +302,158 @@ namespace DmlGraphFusionHelper
         return initializerPartitionMap;
     }
 
+    inline uint32_t GetConstantNodeGraphInputIndex(
+        const std::string& constantName,
+        const std::unordered_map<std::string_view, uint32_t>* serializedGraphConstantNameToMainGraphInputIndex,
+        uint32_t& graphMaxInputIndex,
+        std::unordered_map<std::string_view, uint32_t>& localConstantNameToIndexMap)
+    {
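+        // Large constants that are fed to the graph as inputs need a graph input index. If the caller
+        // supplied a name-to-main-graph-input-index map, use it; otherwise hand out a new sequential
+        // index past the current maximum and remember it in localConstantNameToIndexMap so repeated
+        // uses of the same constant name agree.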
+        if (serializedGraphConstantNameToMainGraphInputIndex == nullptr)
+        {
+            if (localConstantNameToIndexMap.find(constantName) == localConstantNameToIndexMap.end())
+            {
+                localConstantNameToIndexMap[constantName] = ++graphMaxInputIndex;
+            }
+            return localConstantNameToIndexMap[constantName];
+        }
+        else
+        {
+            graphMaxInputIndex = std::max(graphMaxInputIndex, serializedGraphConstantNameToMainGraphInputIndex->at(constantName));
+            return serializedGraphConstantNameToMainGraphInputIndex->at(constantName);
+        }
+    }
+
+    template <size_t AllocatorSize>
     void ConvertGraphDesc(
         const Dml::GraphDescBuilder::GraphDesc& graphDesc,
-        _Out_ DML_GRAPH_DESC& dmlGraphDesc,
         const uint32_t inputCount,
         const uint32_t outputCount,
-        _Inout_ std::vector<DML_OPERATOR_GRAPH_NODE_DESC>& dmlOperatorGraphNodes,
-        _Inout_ std::vector<DML_CONSTANT_DATA_GRAPH_NODE_DESC>& dmlConstantGraphNodes,
+        IDMLDevice* device,
+        StackAllocator<AllocatorSize>& allocator,
+        const std::unordered_map<uint32_t, uint32_t>* serializedGraphInputIndexToSubgraphInputIndex,
+        const std::unordered_map<std::string_view, uint32_t>* serializedGraphLargeConstantNameToSubgraphInputIndex,
+        _Out_ DML_GRAPH_DESC& dmlGraphDesc,
+        _Inout_ std::vector<ComPtr<IDMLOperator>>& dmlOperators,
         _Inout_ std::vector<DML_GRAPH_NODE_DESC>& dmlGraphNodes,
         _Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlInputEdges,
         _Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlOutputEdges,
         _Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlIntermediateEdges)
     {
-        for (size_t i = 0; i < graphDesc.nodes.size(); ++i)
+        std::unordered_map<uint32_t, uint32_t> oldNodeIndexToNewNodeIndexMap;
+        for (uint32_t index = 0; index < static_cast<uint32_t>(graphDesc.Nodes.size()); index++)
         {
-            auto& nodeInfo = graphDesc.nodes[i];
-
-            if (std::holds_alternative<Microsoft::WRL::ComPtr<IDMLOperator>>(nodeInfo.nodeDef))
+            const DmlSerializedGraphNode& node = graphDesc.Nodes[index];
+            if (std::holds_alternative<AbstractOperatorDesc>(node.Desc))
             {
-                dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{std::get<Microsoft::WRL::ComPtr<IDMLOperator>>(nodeInfo.nodeDef).Get(), nodeInfo.name.data()};
-                dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]};
+                oldNodeIndexToNewNodeIndexMap[index] = static_cast<uint32_t>(dmlGraphNodes.size());
+                DML_OPERATOR_DESC dmlDesc = SchemaHelpers::ConvertOperatorDesc<AllocatorSize>(std::get<AbstractOperatorDesc>(node.Desc), &allocator);
+                ComPtr<IDMLOperator> op;
+                ORT_THROW_IF_FAILED(device->CreateOperator(&dmlDesc, IID_PPV_ARGS(&op)));
+                dmlOperators.push_back(op);
+                DML_OPERATOR_GRAPH_NODE_DESC* dmlOperatorGraphNode = allocator.template Allocate<DML_OPERATOR_GRAPH_NODE_DESC>();
+                dmlOperatorGraphNode->Name = node.Name.data();
+                dmlOperatorGraphNode->Operator = op.Get();
+                dmlGraphNodes.push_back(DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, dmlOperatorGraphNode});
             }
             else
             {
-                auto& nodeDefinitionData = std::get<std::vector<uint8_t>>(nodeInfo.nodeDef);
-                dmlConstantGraphNodes[i] = DML_CONSTANT_DATA_GRAPH_NODE_DESC{
-                    nodeDefinitionData.data(),
-                    nodeDefinitionData.size(),
-                    nodeInfo.name.data()
-                };
-
-                // TODO: Change as new header is ingested
-                dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{static_cast<DML_GRAPH_NODE_TYPE>(2), &dmlConstantGraphNodes[i]};
+                auto& constantNodeVariant = std::get<DmlSerializedGraphNodeConstantVariant>(node.Desc);
+                if (std::holds_alternative<ConstantData>(constantNodeVariant))
+                {
+                    oldNodeIndexToNewNodeIndexMap[index] = static_cast<uint32_t>(dmlGraphNodes.size());
+
+                    auto& constantData = std::get<ConstantData>(constantNodeVariant);
+                    
+                    DML_CONSTANT_DATA_GRAPH_NODE_DESC* constantNode = allocator.template Allocate<DML_CONSTANT_DATA_GRAPH_NODE_DESC>();
+                    constantNode->Name = node.Name.data();
+                    constantNode->DataSize = constantData.dataSize;
+                    constantNode->Data = constantData.data;
+                    dmlGraphNodes.push_back(DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_CONSTANT, constantNode});
+                }
             }
         }
 
-        for (size_t i = 0; i < graphDesc.inputEdges.size(); ++i)
+        uint32_t graphMaxInputIndex = 0;
+
+        for (size_t i = 0; i < graphDesc.InputEdges.size(); ++i)
         {
-            dmlInputEdges[i] = DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INPUT, &graphDesc.inputEdges[i]};
+            DML_INPUT_GRAPH_EDGE_DESC* edge = allocator.template Allocate<DML_INPUT_GRAPH_EDGE_DESC>();
+            // 1. If serializedGraphInputIndexToSubgraphInputIndex is not null:
+            //      use the corresponding main graph input index, because the caller will use that index
+            //      to extract the actual input tensor from the main graph and does not own the creation
+            //      of DML bindings directly.
+            //      Use case: when the caller is ORT (DML EP) or DmlEngine.
+            //
+            // 2. If serializedGraphInputIndexToSubgraphInputIndex is null:
+            //      use the serialized graph input index as-is, because the caller owns the creation of
+            //      DML bindings directly.
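+            //
+            // e.g. (illustrative) serialized graph input index 3 might map to subgraph input index 7,
+            //      in which case edge->GraphInputIndex below is set to 7.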
+            edge->GraphInputIndex = serializedGraphInputIndexToSubgraphInputIndex == nullptr ?
+                graphDesc.InputEdges[i].GraphInputIndex :
+                serializedGraphInputIndexToSubgraphInputIndex->at(graphDesc.InputEdges[i].GraphInputIndex);
+            edge->ToNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.InputEdges[i].ToNodeIndex];
+            edge->ToNodeInputIndex = graphDesc.InputEdges[i].ToNodeInputIndex;
+            edge->Name = graphDesc.InputEdges[i].Name.data();
+
+            graphMaxInputIndex = std::max(graphMaxInputIndex, edge->GraphInputIndex);
+            dmlInputEdges.push_back(DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INPUT, edge});
         }
 
-        for (size_t i = 0; i < graphDesc.outputEdges.size(); ++i)
+        for (size_t i = 0; i < graphDesc.OutputEdges.size(); ++i)
         {
-            dmlOutputEdges[i] = DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_OUTPUT, &graphDesc.outputEdges[i]};
+            DML_OUTPUT_GRAPH_EDGE_DESC* edge = allocator.template Allocate<DML_OUTPUT_GRAPH_EDGE_DESC>();
+            edge->GraphOutputIndex = graphDesc.OutputEdges[i].GraphOutputIndex;
+            edge->FromNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.OutputEdges[i].FromNodeIndex];
+            edge->FromNodeOutputIndex = graphDesc.OutputEdges[i].FromNodeOutputIndex;
+            edge->Name = graphDesc.OutputEdges[i].Name.data();
+
+            dmlOutputEdges.push_back(DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_OUTPUT, edge});
         }
 
-        for (size_t i = 0; i < graphDesc.intermediateEdges.size(); ++i)
+        std::unordered_map<std::string_view, uint32_t> localConstantNameToIndexMap;
+        for (uint32_t i = 0; i < static_cast<uint32_t>(graphDesc.IntermediateEdges.size()); ++i)
         {
-            dmlIntermediateEdges[i] =
-                DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INTERMEDIATE, &graphDesc.intermediateEdges[i]};
+            DmlSerializedGraphNodeDescVariant descVariant = graphDesc.Nodes[graphDesc.IntermediateEdges[i].FromNodeIndex].Desc;
+            bool isConstantEdge = std::holds_alternative<DmlSerializedGraphNodeConstantVariant>(descVariant);
+            if (isConstantEdge)
+            {
+                auto& constantNodeVariant = std::get<DmlSerializedGraphNodeConstantVariant>(descVariant);
+                if (std::holds_alternative<ConstantData>(constantNodeVariant))
+                {
+                    DML_INTERMEDIATE_GRAPH_EDGE_DESC* edge = allocator.template Allocate<DML_INTERMEDIATE_GRAPH_EDGE_DESC>();
+                    edge->FromNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.IntermediateEdges[i].FromNodeIndex];
+                    edge->FromNodeOutputIndex = graphDesc.IntermediateEdges[i].FromNodeOutputIndex;
+                    edge->ToNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.IntermediateEdges[i].ToNodeIndex];
+                    edge->ToNodeInputIndex = graphDesc.IntermediateEdges[i].ToNodeInputIndex;
+                    edge->Name = graphDesc.IntermediateEdges[i].Name.data();
+                    dmlIntermediateEdges.push_back(DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INTERMEDIATE, edge});
+                }
+                else
+                {
+                    const std::string& constantName = graphDesc.Nodes[graphDesc.IntermediateEdges[i].FromNodeIndex].Name;
+
+                    DML_INPUT_GRAPH_EDGE_DESC* edge = allocator.template Allocate<DML_INPUT_GRAPH_EDGE_DESC>();
+                    edge->GraphInputIndex = GetConstantNodeGraphInputIndex(
+                        constantName,
+                        serializedGraphLargeConstantNameToSubgraphInputIndex,
+                        graphMaxInputIndex,
+                        localConstantNameToIndexMap);
+                    edge->ToNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.IntermediateEdges[i].ToNodeIndex];
+                    edge->ToNodeInputIndex = graphDesc.IntermediateEdges[i].ToNodeInputIndex;
+                    edge->Name = graphDesc.IntermediateEdges[i].Name.data();
+
+                    dmlInputEdges.push_back({DML_GRAPH_EDGE_TYPE_INPUT, edge});
+                }
+            }
+            else
+            {
+                DML_INTERMEDIATE_GRAPH_EDGE_DESC* edge = allocator.template Allocate<DML_INTERMEDIATE_GRAPH_EDGE_DESC>();
+                edge->FromNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.IntermediateEdges[i].FromNodeIndex];
+                edge->FromNodeOutputIndex = graphDesc.IntermediateEdges[i].FromNodeOutputIndex;
+                edge->ToNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.IntermediateEdges[i].ToNodeIndex];
+                edge->ToNodeInputIndex = graphDesc.IntermediateEdges[i].ToNodeInputIndex;
+                edge->Name = graphDesc.IntermediateEdges[i].Name.data();
+                dmlIntermediateEdges.push_back(DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INTERMEDIATE, edge});
+            }
         }
 
         dmlGraphDesc.InputCount = inputCount;
@@ -400,27 +518,34 @@ namespace DmlGraphFusionHelper
     Microsoft::WRL::ComPtr<IDMLCompiledOperator> TryCreateCompiledOperator(
         const GraphDescBuilder::GraphDesc& graphDesc,
         const onnxruntime::IndexedSubGraph& indexedSubGraph,
-        const ExecutionProviderImpl* providerImpl)
+        const ExecutionProviderImpl* providerImpl,
+        const std::unordered_map<uint32_t, uint32_t>* serializedGraphInputIndexToSubgraphInputIndex,
+        const std::unordered_map<std::string_view, uint32_t>* serializedGraphLargeConstantNameToSubgraphInputIndex)
     {
         const uint32_t fusedNodeInputCount = gsl::narrow_cast<uint32_t>(indexedSubGraph.GetMetaDef()->inputs.size());
         const uint32_t fusedNodeOutputCount = gsl::narrow_cast<uint32_t>(indexedSubGraph.GetMetaDef()->outputs.size());
 
         // convert DML EP GraphDesc into DML_GRAPH_DESC and create IDMLCompiledOperator
-        DML_GRAPH_DESC dmlGraphDesc = {};
-        std::vector<DML_OPERATOR_GRAPH_NODE_DESC> dmlOperatorGraphNodes(graphDesc.nodes.size());
-        std::vector<DML_CONSTANT_DATA_GRAPH_NODE_DESC> dmlConstantGraphNodes(graphDesc.nodes.size());
+        ComPtr<IDMLDevice> device;
+        ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf()));
 
-        std::vector<DML_GRAPH_NODE_DESC> dmlGraphNodes(graphDesc.nodes.size());
-        std::vector<DML_GRAPH_EDGE_DESC> dmlInputEdges(graphDesc.inputEdges.size());
-        std::vector<DML_GRAPH_EDGE_DESC> dmlOutputEdges(graphDesc.outputEdges.size());
-        std::vector<DML_GRAPH_EDGE_DESC> dmlIntermediateEdges(graphDesc.intermediateEdges.size());
+        StackAllocator<1024> allocator;
+        DML_GRAPH_DESC dmlGraphDesc = {};
+        std::vector<ComPtr<IDMLOperator>> dmlOperators;
+        std::vector<DML_GRAPH_NODE_DESC> dmlGraphNodes;
+        std::vector<DML_GRAPH_EDGE_DESC> dmlInputEdges;
+        std::vector<DML_GRAPH_EDGE_DESC> dmlOutputEdges;
+        std::vector<DML_GRAPH_EDGE_DESC> dmlIntermediateEdges;
         ConvertGraphDesc(
             graphDesc,
-            dmlGraphDesc,
             fusedNodeInputCount,
             fusedNodeOutputCount,
-            dmlOperatorGraphNodes,
-            dmlConstantGraphNodes,
+            device.Get(),
+            allocator,
+            serializedGraphInputIndexToSubgraphInputIndex,
+            serializedGraphLargeConstantNameToSubgraphInputIndex,
+            dmlGraphDesc,
+            dmlOperators,
             dmlGraphNodes,
             dmlInputEdges,
             dmlOutputEdges,
@@ -438,8 +563,6 @@ namespace DmlGraphFusionHelper
             executionFlags |= DML_EXECUTION_FLAG_DISABLE_META_COMMANDS;
         }
 
-        ComPtr<IDMLDevice> device;
-        ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf()));
 
         ComPtr<IDMLDevice1> device1;
         ORT_THROW_IF_FAILED(device.As(&device1));
@@ -460,6 +583,7 @@ namespace DmlGraphFusionHelper
     }
 
     void FusePartitionAndRegisterKernel(
+        const uint32_t partitionIndex,
         onnxruntime::Graph& graph,
         onnxruntime::KernelRegistry* registryForPartitionKernels,
         const std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>>& initializerNameToInitializerMap,
@@ -467,8 +591,43 @@ namespace DmlGraphFusionHelper
         const onnxruntime::IndexedSubGraph& indexedSubGraph,
         std::vector<uint8_t>&& isInputsUploadedByDmlEP,
         const GraphDescBuilder::GraphDesc& graphDesc,
-        Microsoft::WRL::ComPtr<IDMLCompiledOperator> compiledExecutionPlanOperator)
+        Microsoft::WRL::ComPtr<IDMLCompiledOperator> compiledExecutionPlanOperator,
+        const bool graphSerializationEnabled,
+        const std::unordered_map<uint32_t, uint32_t>* serializedGraphInputIndexToSubgraphInputIndex,
+        const std::unordered_map<std::string_view, uint32_t>* serializedGraphLargeConstantNameToSubgraphInputIndex)
     {
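+      // When graph serialization is enabled: serialize the partition's graph desc, dump it to
+      // "Partition_<index>.bin", then deserialize it and recompile the operator from the
+      // deserialized desc, replacing the previously compiled operator for this partition.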
+      if (graphSerializationEnabled)
+      {
+
+        const std::wstring modelName = GetModelName(graph.ModelPath());
+        auto buffer = SerializeDmlGraph(graphDesc);
+
+        const std::wstring partitionName =
+            L"Partition_" +
+            std::to_wstring(partitionIndex) +
+            L".bin";
+        WriteToFile(modelName, partitionName, buffer.data(), buffer.size());
+
+        std::vector<std::unique_ptr<std::byte[]>> rawData;
+        DmlSerializedGraphDesc deserializedGraphDesc = DeserializeDmlGraph(buffer.data(), rawData);
+        GraphDescBuilder::GraphDesc deserializedDmlGraphDesc = {};
+        deserializedDmlGraphDesc.InputCount = deserializedGraphDesc.InputCount;
+        deserializedDmlGraphDesc.InputEdges = std::move(deserializedGraphDesc.InputEdges);
+        deserializedDmlGraphDesc.IntermediateEdges = std::move(deserializedGraphDesc.IntermediateEdges);
+        deserializedDmlGraphDesc.Nodes = std::move(deserializedGraphDesc.Nodes);
+        deserializedDmlGraphDesc.OutputCount = deserializedGraphDesc.OutputCount;
+        deserializedDmlGraphDesc.OutputEdges = std::move(deserializedGraphDesc.OutputEdges);
+        deserializedDmlGraphDesc.reuseCommandList = graphDesc.reuseCommandList;
+        deserializedDmlGraphDesc.outputShapes = graphDesc.outputShapes;
+
+        compiledExecutionPlanOperator = DmlGraphFusionHelper::TryCreateCompiledOperator(
+                        deserializedDmlGraphDesc,
+                        indexedSubGraph,
+                        providerImpl,
+                        serializedGraphInputIndexToSubgraphInputIndex,
+                        serializedGraphLargeConstantNameToSubgraphInputIndex);
+      }
+
         auto& fusedNode = graph.BeginFuseSubGraph(indexedSubGraph, indexedSubGraph.GetMetaDef()->name);
         fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider);
 
@@ -482,8 +641,10 @@ namespace DmlGraphFusionHelper
         std::vector<bool> inputsUsed;
         ProcessInputData(
             providerImpl,
+            graphSerializationEnabled,
             isInputsUploadedByDmlEP,
-            graphDesc.inputEdges,
+            serializedGraphInputIndexToSubgraphInputIndex,
+            serializedGraphLargeConstantNameToSubgraphInputIndex,
             indexedSubGraph.GetMetaDef()->inputs,
             initializerNameToInitializerMap,
             graph,
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h
index f8f6162aaa1e0..f1e9654021196 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h
@@ -45,12 +45,17 @@ namespace DmlGraphFusionHelper
         gsl::span<std::unique_ptr<GraphPartition>> partitions
     );
 
+    template <size_t AllocatorSize>
     void ConvertGraphDesc(
         const Dml::GraphDescBuilder::GraphDesc& graphDesc,
-        _Out_ DML_GRAPH_DESC& dmlGraphDesc,
         const uint32_t inputCount,
         const uint32_t outputCount,
-        _Inout_ std::vector<DML_OPERATOR_GRAPH_NODE_DESC>& dmlOperatorGraphNodes,
+        IDMLDevice* device,
+        StackAllocator<AllocatorSize>& allocator,
+        const std::unordered_map<uint32_t, uint32_t>* serializedGraphInputIndexToSubgraphInputIndex,
+        const std::unordered_map<std::string_view, uint32_t>* serializedGraphLargeConstantNameToSubgraphInputIndex,
+        _Out_ DML_GRAPH_DESC& dmlGraphDesc,
+        _Inout_ std::vector<ComPtr<IDMLOperator>>& dmlOperators,
         _Inout_ std::vector<DML_GRAPH_NODE_DESC>& dmlGraphNodes,
         _Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlInputEdges,
         _Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlOutputEdges,
@@ -69,9 +74,12 @@ namespace DmlGraphFusionHelper
     Microsoft::WRL::ComPtr<IDMLCompiledOperator> TryCreateCompiledOperator(
         const GraphDescBuilder::GraphDesc& graphDesc,
         const onnxruntime::IndexedSubGraph& indexedSubGraph,
-        const ExecutionProviderImpl* providerImpl);
+        const ExecutionProviderImpl* providerImpl,
+        const std::unordered_map<uint32_t, uint32_t>* serializedGraphInputIndexToSubgraphInputIndex,
+        const std::unordered_map<std::string_view, uint32_t>* serializedGraphLargeConstantNameToSubgraphInputIndex);
 
     void FusePartitionAndRegisterKernel(
+        const uint32_t partitionIndex,
         onnxruntime::Graph& graph,
         onnxruntime::KernelRegistry* registryForPartitionKernels,
         const std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>>& initializerNameToInitializerMap,
@@ -79,7 +87,10 @@ namespace DmlGraphFusionHelper
         const onnxruntime::IndexedSubGraph& indexedSubGraph,
         std::vector<uint8_t>&& isInputsUploadedByDmlEP,
         const GraphDescBuilder::GraphDesc& graphDesc,
-        Microsoft::WRL::ComPtr<IDMLCompiledOperator> compiledExecutionPlanOperator);
+        Microsoft::WRL::ComPtr<IDMLCompiledOperator> compiledExecutionPlanOperator,
+        const bool graphSerializationEnabled,
+        const std::unordered_map<uint32_t, uint32_t>* serializedGraphInputIndexToSubgraphInputIndex = nullptr,
+        const std::unordered_map<std::string_view, uint32_t>* serializedGraphLargeConstantNameToSubgraphInputIndex = nullptr);
 
     void RegisterDynamicKernel(
         onnxruntime::Graph& graph,
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp
index 679738b639ec9..35a2c451a49a5 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp
@@ -24,15 +24,20 @@ namespace Dml
             std::vector<uint8_t> isInputsUploadedByDmlEP;
             GraphDescBuilder::GraphDesc graphDesc;
             std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>> isInitializerTransferable;
+            std::vector<std::unique_ptr<std::byte[]>> smallConstantData; // Need to keep it alive for maintaining lifetime
+            std::unordered_map<uint32_t, uint32_t> serializedGraphInputIndexToSubgraphInputIndex;
+            std::unordered_map<std::string_view, uint32_t> serializedGraphLargeConstantNameToSubgraphInputIndex;
         };
     }
 
     DmlGraphFusionTransformer::DmlGraphFusionTransformer(
         const std::string& name,
-        const onnxruntime::IExecutionProvider* provider
+        const onnxruntime::IExecutionProvider* provider,
+        const bool graphSerializationEnabled
     )
         :onnxruntime::GraphTransformer(name),
-         m_providerImpl(static_cast<const ExecutionProvider*>(provider)->GetImpl())
+         m_providerImpl(static_cast<const ExecutionProvider*>(provider)->GetImpl()),
+         graphSerializationEnabled(graphSerializationEnabled)
     {
     }
 
@@ -227,23 +232,39 @@ namespace Dml
 
                     ComPtr<IDMLDevice> device;
                     ORT_THROW_IF_FAILED(m_providerImpl->GetDmlDevice(device.GetAddressOf()));
+                    // These maps will be used to transfer the initializers to D3D12 system heap memory.
+                    // The serialized graph desc represents constant inputs as intermediate edges, which is why
+                    // we need a mapping from the serialized graph's input index (and from a large constant's name)
+                    // to the input arg index of indexedSubGraph (a given partition).
+                    //   For ex: indexedSubGraphInputArgIdx = serializedGraphLargeConstantNameToSubgraphInputIndex[constantName];
+                    //           corresponding constant tensor = initializerNameToInitializerMap[indexedSubGraph.GetMetaDef()->inputs[indexedSubGraphInputArgIdx]]
+                    // We key on the serialized index/name rather than on a consuming node because the same
+                    // constant tensor can be used by multiple nodes.
+                    std::unordered_map<uint32_t, uint32_t> serializedGraphInputIndexToSubgraphInputIndex;
+                    std::unordered_map<std::string_view, uint32_t> serializedGraphLargeConstantNameToSubgraphInputIndex;
+                    std::vector<std::unique_ptr<std::byte[]>> smallConstantData;
                     GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc(
                         isInputsUploadedByDmlEP.data(),
                         isInputsUploadedByDmlEP.size(),
                         isInitializerTransferable,
                         partitionNodePropsMap,
-                        device.Get(),
                         m_providerImpl,
                         modelPath,
                         subgraphNodes,
                         subgraphInputs,
-                        subgraphOutputs);
+                        subgraphOutputs,
+                        serializedGraphInputIndexToSubgraphInputIndex,
+                        serializedGraphLargeConstantNameToSubgraphInputIndex,
+                        smallConstantData);
 
                     // Compile the operator
                     auto compiledPartition = DmlGraphFusionHelper::TryCreateCompiledOperator(
                         graphDesc,
                         indexedSubGraph,
-                        m_providerImpl);
+                        m_providerImpl,
+                        &serializedGraphInputIndexToSubgraphInputIndex,
+                        &serializedGraphLargeConstantNameToSubgraphInputIndex);
 
                     if (!compiledPartition)
                     {
@@ -264,6 +285,9 @@ namespace Dml
                         compiledPartitionInfo->isInputsUploadedByDmlEP = std::move(isInputsUploadedByDmlEP);
                         compiledPartitionInfo->graphDesc = std::move(graphDesc);
                         compiledPartitionInfo->isInitializerTransferable = std::move(isInitializerTransferable);
+                        compiledPartitionInfo->smallConstantData = std::move(smallConstantData);
+                        compiledPartitionInfo->serializedGraphInputIndexToSubgraphInputIndex = std::move(serializedGraphInputIndexToSubgraphInputIndex);
+                        compiledPartitionInfo->serializedGraphLargeConstantNameToSubgraphInputIndex = std::move(serializedGraphLargeConstantNameToSubgraphInputIndex);
                         compiledPartitionInfos[partitionIndex] = std::move(compiledPartitionInfo);
                     }
                 }
@@ -271,12 +295,14 @@ namespace Dml
         }
         while (!additionalSplittingNodes.empty());
 
+        uint32_t partitionIndex = 0;
         for (auto&& compiledPartitionInfo : compiledPartitionInfos)
         {
             // Null compiled operators were not DML partitions
             if (compiledPartitionInfo)
             {
                 DmlGraphFusionHelper::FusePartitionAndRegisterKernel(
+                    partitionIndex++,
                     graph,
                     m_providerImpl->GetKernelRegistry().get(),
                     compiledPartitionInfo->isInitializerTransferable,
@@ -284,7 +310,10 @@ namespace Dml
                     compiledPartitionInfo->indexedSubGraph,
                     std::move(compiledPartitionInfo->isInputsUploadedByDmlEP),
                     compiledPartitionInfo->graphDesc,
-                    compiledPartitionInfo->compiledOperator);
+                    compiledPartitionInfo->compiledOperator,
+                    graphSerializationEnabled,
+                    &compiledPartitionInfo->serializedGraphInputIndexToSubgraphInputIndex,
+                    &compiledPartitionInfo->serializedGraphLargeConstantNameToSubgraphInputIndex);
             }
         }
 
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h
index 19dab0c89943c..b370f3ef9043c 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h
@@ -16,7 +16,8 @@ class DmlGraphFusionTransformer : public onnxruntime::GraphTransformer
 public:
     DmlGraphFusionTransformer(
         const std::string& name,
-        const onnxruntime::IExecutionProvider* provider
+        const onnxruntime::IExecutionProvider* provider,
+        const bool graphSerializationEnabled
     );
 
 public:
@@ -38,5 +39,6 @@ class DmlGraphFusionTransformer : public onnxruntime::GraphTransformer
 
 private:
     const ExecutionProviderImpl* m_providerImpl = nullptr;
+    const bool graphSerializationEnabled = false;
 };
 }
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphSerialization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphSerialization.cpp
new file mode 100644
index 0000000000000..5355964e8db74
--- /dev/null
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphSerialization.cpp
@@ -0,0 +1,580 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+
+#pragma once
+#include "precomp.h"
+
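+// Reinterprets the bytes at 'base + byteOffset' as a T. Used below to read strings that have
+// already been written into a FlatBufferBuilder's buffer (flatbuffers offsets are relative to the
+// end of the buffer, which grows back-to-front).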
+template <typename T>
+T* ReadAs(uint8_t* base, size_t byteOffset)
+{
+    return reinterpret_cast<T*>(base + byteOffset);
+}
+
+void SerializeAttributeDescs(
+    flatbuffers::FlatBufferBuilder& builder,
+    const AbstractOperatorDesc& operatorDesc,
+    /*out*/ std::vector<flatbuffers::Offset<dml::ir::operatorFieldTypes::AttributeDesc>>& attributeDescs);
+
+flatbuffers::Offset<dml::ir::operatorFieldTypes::Activation> serializeActivation(
+    flatbuffers::FlatBufferBuilder& builder,
+    const AbstractOperatorDesc& activationOperatorDesc)
+{
+    std::vector<flatbuffers::Offset<dml::ir::operatorFieldTypes::AttributeDesc>> attributeDescs;
+    SerializeAttributeDescs(builder, activationOperatorDesc, attributeDescs);
+    
+    flatbuffers::Offset<dml::ir::operatorFieldTypes::Activation> offset = dml::ir::operatorFieldTypes::CreateActivationDirect(
+        builder,
+        activationOperatorDesc.schema->OperatorName,
+        &attributeDescs);
+    return offset;
+}
+
+void SerializeAttributeDescs(
+    flatbuffers::FlatBufferBuilder& builder,
+    const AbstractOperatorDesc& operatorDesc,
+    /*out*/ std::vector<flatbuffers::Offset<dml::ir::operatorFieldTypes::AttributeDesc>>& attributeDescs)
+{
+    for (const OperatorField& field : operatorDesc.fields)
+    {
+        if (field.GetSchema()->Kind == DML_SCHEMA_FIELD_KIND_INPUT_TENSOR || 
+            field.GetSchema()->Kind == DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR)
+        {
+            continue;
+        }
+
+        flatbuffers::Offset<dml::ir::operatorFieldTypes::AttributeDesc> offset;
+
+        if (std::holds_alternative<OperatorFieldTypes::FusedActivationOperatorDesc>(field.GetData()))
+        {
+            const OperatorFieldTypes::FusedActivationOperatorDesc& fusedActivation = field.AsFusedActivationOperatorDesc();
+            if (!fusedActivation.has_value())
+            {
+                offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect(
+                    builder,
+                    nullptr,
+                    dml::ir::operatorFieldTypes::AttributeFieldVariant_Activation);
+            }
+            else
+            {
+                offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect(
+                    builder,
+                    field.GetSchema()->Name,
+                    dml::ir::operatorFieldTypes::AttributeFieldVariant_Activation,
+                    serializeActivation(builder, fusedActivation.value()).Union());
+            }
+        }
+        else if (std::holds_alternative<OperatorFieldTypes::FusedActivationOperatorDescArray>(field.GetData()))
+        {
+            const OperatorFieldTypes::FusedActivationOperatorDescArray& fusedActivations = 
+                field.AsFusedActivationOperatorDescArray();
+            if (!fusedActivations.has_value())
+            {
+                offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect(
+                    builder,
+                    nullptr,
+                    dml::ir::operatorFieldTypes::AttributeFieldVariant_ActivationArray);
+            }
+            else
+            {
+                std::vector<flatbuffers::Offset<dml::ir::operatorFieldTypes::Activation>> fbActivations;
+
+                for (AbstractOperatorDesc activationOpDesc : fusedActivations.value())
+                {
+                    flatbuffers::Offset<dml::ir::operatorFieldTypes::Activation> fbActivation = 
+                        serializeActivation(builder, activationOpDesc);
+                    fbActivations.push_back(fbActivation);
+                }
+
+                flatbuffers::Offset<dml::ir::operatorFieldTypes::ActivationArray> activationOffset = 
+                    dml::ir::operatorFieldTypes::CreateActivationArrayDirect(builder, &fbActivations);
+                
+                offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect(
+                    builder,
+                    field.GetSchema()->Name,
+                    dml::ir::operatorFieldTypes::AttributeFieldVariant_ActivationArray,
+                    activationOffset.Union());
+            }
+        }
+        else if (std::holds_alternative<OperatorFieldTypes::UInt>(field.GetData()))
+        {
+            offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect(
+                builder,
+                field.GetSchema()->Name,
+                dml::ir::operatorFieldTypes::AttributeFieldVariant_UInt32,
+                builder.CreateStruct(dml::ir::operatorFieldTypes::UInt32(field.AsUInt())).Union());
+        }
+        else if (std::holds_alternative<OperatorFieldTypes::UInt64>(field.GetData()))
+        {
+            offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect(
+                builder,
+                field.GetSchema()->Name,
+                dml::ir::operatorFieldTypes::AttributeFieldVariant_UInt64,
+                builder.CreateStruct(dml::ir::operatorFieldTypes::UInt64(field.AsUInt64())).Union());
+        }
+        else if (std::holds_alternative<OperatorFieldTypes::Int>(field.GetData()))
+        {
+            offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect(
+                builder,
+                field.GetSchema()->Name,
+                dml::ir::operatorFieldTypes::AttributeFieldVariant_Int32,
+                builder.CreateStruct(dml::ir::operatorFieldTypes::Int32(field.AsInt())).Union());
+        }
+        else if (std::holds_alternative<OperatorFieldTypes::Float>(field.GetData()))
+        {
+            offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect(
+                builder,
+                field.GetSchema()->Name,
+                dml::ir::operatorFieldTypes::AttributeFieldVariant_Float32,
+                builder.CreateStruct(dml::ir::operatorFieldTypes::Float32(field.AsFloat())).Union());
+        }
+        else if (std::holds_alternative<OperatorFieldTypes::UIntArray>(field.GetData()))
+        {
+            offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect(
+                builder,
+                field.GetSchema()->Name,
+                dml::ir::operatorFieldTypes::AttributeFieldVariant_UIntArray,
+                dml::ir::operatorFieldTypes::CreateUIntArray(builder, builder.CreateVector(field.AsUIntArray())).Union());
+        }
+        else if (std::holds_alternative<OperatorFieldTypes::IntArray>(field.GetData()))
+        {
+            offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect(
+                builder,
+                field.GetSchema()->Name,
+                dml::ir::operatorFieldTypes::AttributeFieldVariant_IntArray,
+                dml::ir::operatorFieldTypes::CreateIntArray(builder, builder.CreateVector(field.AsIntArray())).Union());
+        }
+        else if (std::holds_alternative<OperatorFieldTypes::FloatArray>(field.GetData()))
+        {
+            offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect(
+                builder,
+                field.GetSchema()->Name,
+                dml::ir::operatorFieldTypes::AttributeFieldVariant_FloatArray,
+                dml::ir::operatorFieldTypes::CreateFloatArray(builder, builder.CreateVector(field.AsFloatArray())).Union());
+        }
+        else if (std::holds_alternative<OperatorFieldTypes::ScaleBias>(field.GetData()))
+        {
+            const OperatorFieldTypes::ScaleBias& scaleBias = field.AsScaleBias();
+            if (!scaleBias.has_value())
+            {
+                offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect(
+                    builder,
+                    nullptr,
+                    dml::ir::operatorFieldTypes::AttributeFieldVariant_ScaleBias);
+            }
+            else
+            {
+                dml::ir::operatorFieldTypes::ScaleBias fbScaleBias(scaleBias.value().Scale, scaleBias.value().Bias);
+                offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect(
+                    builder,
+                    field.GetSchema()->Name,
+                    dml::ir::operatorFieldTypes::AttributeFieldVariant_ScaleBias,
+                    builder.CreateStruct(fbScaleBias).Union());
+            }
+        }
+        else if (std::holds_alternative<OperatorFieldTypes::Size2D>(field.GetData()))
+        {
+            const DML_SIZE_2D size2d = field.AsSize2D();
+            offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect(
+                builder,
+                field.GetSchema()->Name,
+                dml::ir::operatorFieldTypes::AttributeFieldVariant_Size2D,
+                builder.CreateStruct(dml::ir::operatorFieldTypes::Size2D(size2d.Width, size2d.Height)).Union());
+        }
+        else if (std::holds_alternative<OperatorFieldTypes::ScalarUnion>(field.GetData()))
+        {
+            OperatorFieldTypes::ScalarUnion scalarUnion = field.AsScalarUnion();
+            dml::ir::operatorFieldTypes::ByteArray byteArr;
+            for (uint32_t index = 0; index < static_cast<uint32_t>(sizeof(scalarUnion.Bytes)); index++)
+            {
+                byteArr.mutable_data()->Mutate(index, scalarUnion.Bytes[index]);
+            }
+
+            flatbuffers::Offset<dml::ir::operatorFieldTypes::ScalarUnionData> scalarUnionOffset = 
+                dml::ir::operatorFieldTypes::CreateScalarUnionData(
+                    builder,
+                    dml::ir::operatorFieldTypes::ScalarVariant_ByteArray,
+                    builder.CreateStruct(byteArr).Union());
+
+            offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect(
+                builder,
+                field.GetSchema()->Name,
+                dml::ir::operatorFieldTypes::AttributeFieldVariant_ScalarUnionData,
+                scalarUnionOffset.Union());
+        }
+        else if (std::holds_alternative<OperatorFieldTypes::Bool>(field.GetData()))
+        {
+            offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect(
+                builder,
+                field.GetSchema()->Name,
+                dml::ir::operatorFieldTypes::AttributeFieldVariant_Bool,
+                builder.CreateStruct(dml::ir::operatorFieldTypes::Bool(field.AsBool())).Union());
+        }
+        else
+        {
+            continue;
+        }
+        
+        attributeDescs.push_back(offset);
+    }
+}
+
+flatbuffers::Offset<dml::ir::DmlBufferTensorDesc> SerializeDmlTensorDesc(
+    flatbuffers::FlatBufferBuilder& builder,
+    const DmlBufferTensorDesc* tensorDesc)
+{
+    const std::vector<uint32_t> *strides = nullptr;
+    if (tensorDesc->strides.has_value())
+    {
+        strides = &tensorDesc->strides.value();
+    }
+    
+    flatbuffers::Offset<dml::ir::DmlBufferTensorDesc> offset = dml::ir::CreateDmlBufferTensorDescDirect(
+        builder,
+        ApiTraits::StringifyHelpers::ToString(tensorDesc->dataType),
+        &tensorDesc->sizes,
+        strides,
+        tensorDesc->totalTensorSizeInBytes);
+    return offset;
+}
+
+flatbuffers::Offset<void> SerializeOperatorNodeDesc(
+    flatbuffers::FlatBufferBuilder& builder,
+    const AbstractOperatorDesc& operatorDesc)
+{
+    const DML_OPERATOR_SCHEMA* operatorSchema = operatorDesc.schema;
+
+    std::vector<flatbuffers::Offset<dml::ir::DmlBufferTensorDesc>> inputTensorDescs;
+    std::vector<flatbuffers::Offset<dml::ir::DmlBufferTensorDesc>> outputTensorDescs;
+    
+    for (const DmlBufferTensorDesc* tensorDesc : operatorDesc.GetInputTensors())
+    {
+        if (tensorDesc == nullptr)
+        {
+            continue;
+        }
+        flatbuffers::Offset<dml::ir::DmlBufferTensorDesc> serializedDmlTensorDesc = SerializeDmlTensorDesc(builder, tensorDesc);
+        inputTensorDescs.push_back(serializedDmlTensorDesc);
+    }
+    
+    for (const DmlBufferTensorDesc* tensorDesc : operatorDesc.GetOutputTensors())
+    {
+        if (tensorDesc == nullptr)
+        {
+            continue;
+        }
+        flatbuffers::Offset<dml::ir::DmlBufferTensorDesc> serializedDmlTensorDesc = SerializeDmlTensorDesc(builder, tensorDesc);
+        outputTensorDescs.push_back(serializedDmlTensorDesc);
+    }
+    
+    std::vector<flatbuffers::Offset<dml::ir::operatorFieldTypes::AttributeDesc>> attributeDescs;
+    SerializeAttributeDescs(builder, operatorDesc, attributeDescs);
+    
+    flatbuffers::Offset<dml::ir::OperatorNodeDesc> offset = dml::ir::CreateOperatorNodeDesc(
+        builder,
+        builder.CreateString(operatorSchema->OperatorName),
+        builder.CreateVector(inputTensorDescs),
+        builder.CreateVector(outputTensorDescs),
+        builder.CreateVector(attributeDescs));
+    return offset.Union();
+}
+
+flatbuffers::Offset<void> SerializeConstantNodeDesc(
+    flatbuffers::FlatBufferBuilder& builder,
+    uint32_t nodeIndex,
+    const DmlSerializedGraphNodeConstantVariant& constantNodeDesc)
+{
+    flatbuffers::Offset<dml::ir::ConstantNodeDesc> offset;
+    
+    if (std::holds_alternative<ConstantName>(constantNodeDesc))
+    {
+        auto& constantName = std::get<ConstantName>(constantNodeDesc);
+        if (constantName.name.empty())
+        {
+            throw std::invalid_argument("Graph constant node at index:" + std::to_string(nodeIndex) +
+                                        " doesn't have the constant data name.");
+        }
+
+        flatbuffers::Offset<dml::ir::ConstantName> constantNameOffset = dml::ir::CreateConstantName(
+            builder, 
+            builder.CreateString(constantName.name));
+
+        offset = dml::ir::CreateConstantNodeDesc(
+            builder,
+            dml::ir::ConstantNodeDescDetail_ConstantName,
+            constantNameOffset.Union());
+    }
+    else
+    {
+        auto& constantData = std::get<ConstantData>(constantNodeDesc);
+        std::vector<uint8_t> rawBytes;
+        std::transform(constantData.data, constantData.data + constantData.dataSize, 
+                       std::back_inserter(rawBytes), [](std::byte b) {return static_cast<uint8_t>(b); });
+        flatbuffers::Offset<dml::ir::ConstantRawData> constantDataOffset = dml::ir::CreateConstantRawDataDirect(
+            builder,
+            &rawBytes);
+
+        offset = dml::ir::CreateConstantNodeDesc(
+            builder,
+            dml::ir::ConstantNodeDescDetail_ConstantRawData,
+            constantDataOffset.Union());
+    }
+    
+    return offset.Union();
+}
+
+flatbuffers::Offset<dml::ir::DmlGraphNode> SerializeNode(
+    flatbuffers::FlatBufferBuilder& builder,
+    const uint32_t nodeIndex,
+    const DmlSerializedGraphNode& graphNode,
+    const std::vector<flatbuffers::Offset<flatbuffers::String>>& nodeInputNames,
+    const std::vector<flatbuffers::Offset<flatbuffers::String>>& nodeOutputNames)
+{
+    if (graphNode.Name.empty())
+    {
+        throw std::invalid_argument("Graph node at index:" + std::to_string(nodeIndex) +
+                                    " does not have a name.");
+    }
+
+    flatbuffers::Offset<dml::ir::DmlGraphNode> offset;
+    if (std::holds_alternative<AbstractOperatorDesc>(graphNode.Desc))
+    {
+        auto& operatorNode = std::get<AbstractOperatorDesc>(graphNode.Desc);
+        offset = dml::ir::CreateDmlGraphNode(
+            builder,
+            dml::ir::NodeDesc_OperatorNodeDesc,
+            SerializeOperatorNodeDesc(builder, operatorNode),
+            builder.CreateString(graphNode.Name),
+            builder.CreateVector(nodeInputNames),
+            builder.CreateVector(nodeOutputNames));
+    }
+    else
+    {
+        auto& constantNodeVariant = std::get<DmlSerializedGraphNodeConstantVariant>(graphNode.Desc);
+        offset = dml::ir::CreateDmlGraphNode(
+            builder,
+            dml::ir::NodeDesc_ConstantNodeDesc,
+            SerializeConstantNodeDesc(builder, nodeIndex, constantNodeVariant),
+            builder.CreateString(graphNode.Name),
+            builder.CreateVector(nodeInputNames),
+            builder.CreateVector(nodeOutputNames));
+    }
+    return offset;
+}
+
+/*
+* Validates input/output edges and throws an exception if an edge
+* does not have a name or has more than one name.
+*/
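+// For example (illustrative): two DmlOutputSerializedGraphEdge entries that both target
+// GraphOutputIndex 2 must carry the same Name; if their names differ, std::invalid_argument is thrown.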
+template <typename Edge>
+std::unordered_map<uint32_t, flatbuffers::Offset<flatbuffers::String>> ConvertToEdgeIndexToNameMap(
+    const std::vector<Edge>& edges,
+    flatbuffers::FlatBufferBuilder& builder)
+{
+    std::unordered_map<uint32_t, flatbuffers::Offset<flatbuffers::String>> edgeIndexToNameMap;
+    for (auto& edge : edges)
+    {
+        uint32_t index;
+        if constexpr (std::is_same_v<Edge, DmlInputSerializedGraphEdge>)
+        {
+            index = edge.GraphInputIndex;
+        }
+        else if constexpr (std::is_same_v<Edge, DmlOutputSerializedGraphEdge>)
+        {
+            index = edge.GraphOutputIndex;
+        }
+        
+        if (edge.Name.empty())
+        {
+            throw std::invalid_argument("Graph input or output edge at index " + std::to_string(index) + " does not have name.");
+        }
+
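+        // If this index was seen before, read the previously serialized name back out of the
+        // builder's buffer (flatbuffers offsets are measured from the end of the buffer) and make
+        // sure the new edge uses the same name.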
+        if (edgeIndexToNameMap.find(index) != edgeIndexToNameMap.end())
+        {
+            flatbuffers::String* edgeName = ReadAs<flatbuffers::String>(
+                builder.GetCurrentBufferPointer(),
+                builder.GetSize() - edgeIndexToNameMap[index].o);
+            if (edge.Name != edgeName->str())
+            {
+                throw std::invalid_argument("Graph input or output edge at index " + std::to_string(index) + " has more than 1 names.");
+            }
+        }
+
+        edgeIndexToNameMap[index] = builder.CreateString(edge.Name);
+    }
+    return edgeIndexToNameMap; // Returned by value; copy elision (NRVO) applies, so std::move is not needed.
+}
+
+void PopulateNonConstantNodeInputOutputCount(
+    const std::vector<DmlSerializedGraphNode>& nodes,
+    /*out*/ std::vector<uint32_t>& nodeInputCounts,
+    /*out*/ std::vector<uint32_t>& nodeOutputCounts)
+{
+    for (uint32_t nodeIndex = 0; nodeIndex < static_cast<uint32_t>(nodes.size()); nodeIndex++)
+    {
+        auto& node = nodes[nodeIndex];
+        if (std::holds_alternative<AbstractOperatorDesc>(node.Desc))
+        {
+            auto& operatorNode = std::get<AbstractOperatorDesc>(node.Desc);
+            nodeInputCounts[nodeIndex] = std::max(
+                nodeInputCounts[nodeIndex], 
+                static_cast<uint32_t>(operatorNode.GetInputTensors().size()));
+
+            nodeOutputCounts[nodeIndex] = std::max(
+                nodeOutputCounts[nodeIndex], 
+                static_cast<uint32_t>(operatorNode.GetOutputTensors().size()));
+        }
+    }
+}
+
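+/*
+* Derives input/output counts from the intermediate edges (largest referenced index + 1).
+* This is what gives constant nodes their counts, since they have no AbstractOperatorDesc and
+* their only edges are outgoing intermediate edges.
+*/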
+void PopulateConstantNodeInputOutputCount(
+    const std::vector<DmlIntermediateSerializedGraphEdge>& edges,
+    /*out*/std::vector<uint32_t>& maxInputIndexForNodes,
+    /*out*/std::vector<uint32_t>& maxOutputIndexForNodes)
+{
+    for (auto& edge : edges)
+    {
+        maxInputIndexForNodes[edge.ToNodeIndex] = std::max(maxInputIndexForNodes[edge.ToNodeIndex], edge.ToNodeInputIndex + 1);
+        maxOutputIndexForNodes[edge.FromNodeIndex] = std::max(maxOutputIndexForNodes[edge.FromNodeIndex], edge.FromNodeOutputIndex + 1);
+    }
+}
+
+/*
+* Validates intermediate edges and throws an exception if an edge
+* does not have a name or has more than one name.
+*/
+void PopulateNodeInputOutputNames(
+    flatbuffers::FlatBufferBuilder& builder,
+    const DmlSerializedGraphDesc& graphDesc,
+    const std::unordered_map<uint32_t, flatbuffers::Offset<flatbuffers::String>>& graphInputIndexToNameMap,
+    const std::unordered_map<uint32_t, flatbuffers::Offset<flatbuffers::String>>& graphOutputIndexToNameMap,
+    /*out*/std::vector<std::vector<flatbuffers::Offset<flatbuffers::String>>>& nodeToInputNames, 
+    /*out*/std::vector<std::vector<flatbuffers::Offset<flatbuffers::String>>>& nodeToOutputNames)
+{
+    for (auto& edge : graphDesc.InputEdges)
+    {
+        nodeToInputNames[edge.ToNodeIndex][edge.ToNodeInputIndex] = graphInputIndexToNameMap.at(edge.GraphInputIndex);
+    }
+
+    for (auto& edge : graphDesc.OutputEdges)
+    {
+        nodeToOutputNames[edge.FromNodeIndex][edge.FromNodeOutputIndex] = graphOutputIndexToNameMap.at(edge.GraphOutputIndex);
+    }
+
+    std::unordered_map<uint32_t, std::unordered_map<uint32_t, flatbuffers::Offset<flatbuffers::String>>> intermediateEdgeNames;
+    for (uint32_t edgeIndex = 0; edgeIndex < static_cast<uint32_t>(graphDesc.IntermediateEdges.size()); edgeIndex++)
+    {
+        auto& edge = graphDesc.IntermediateEdges[edgeIndex];
+        if (edge.Name.empty())
+        {
+            throw std::invalid_argument(
+                    "Graph intermediate edge from nodeIndex:" + std::to_string(edge.FromNodeIndex) + 
+                    " & nodeOutputIndex:" + std::to_string(edge.FromNodeOutputIndex) + " doesn't have name.");
+        }
+        
+        if (intermediateEdgeNames.find(edge.FromNodeIndex) != intermediateEdgeNames.end() &&
+            intermediateEdgeNames[edge.FromNodeIndex].find(edge.FromNodeOutputIndex) != intermediateEdgeNames[edge.FromNodeIndex].end())
+        {
+            flatbuffers::Offset edgeNameOffset = intermediateEdgeNames[edge.FromNodeIndex][edge.FromNodeOutputIndex];
+            flatbuffers::String* edgeName = ReadAs<flatbuffers::String>(
+                builder.GetCurrentBufferPointer(),
+                builder.GetSize() - edgeNameOffset.o);
+
+            if (edgeName->str() != edge.Name)
+            {
+                throw std::invalid_argument(
+                    "Graph intermediate edge from nodeIndex:" + std::to_string(edge.FromNodeIndex) +
+                    " & nodeOutputIndex:" + std::to_string(edge.FromNodeOutputIndex) + " has more than one name.");
+            }
+        }
+        else
+        {
+            intermediateEdgeNames[edge.FromNodeIndex][edge.FromNodeOutputIndex] = builder.CreateString(edge.Name.c_str());
+        }
+        nodeToInputNames[edge.ToNodeIndex][edge.ToNodeInputIndex] = intermediateEdgeNames[edge.FromNodeIndex][edge.FromNodeOutputIndex];
+        nodeToOutputNames[edge.FromNodeIndex][edge.FromNodeOutputIndex] = intermediateEdgeNames[edge.FromNodeIndex][edge.FromNodeOutputIndex];
+    }
+}
+
+
+/*
+* - If an edge is connected to multiple nodes, then there will be multiple instances
+*   of input or intermediate edges, all with the same name.
+* - The input <graphDesc> will be validated incrementally throughout the execution
+*   of the method.
+* - Handling of an empty optional input/output/attribute for a non-constant node:
+*   input/output
+*   - <DmlGraphNode.inputNames> and <DmlGraphNode.outputNames> will have a null entry,
+*      but the actual OperatorNodeDesc variant's <OperatorNodeDesc.inputs>
+*      and <OperatorNodeDesc.outputs> will not have any entry.
+*   attribute
+*   - <OperatorNodeDesc.attributes> will have a null entry.
+*/
+flatbuffers::DetachedBuffer SerializeDmlGraph(const DmlSerializedGraphDesc& graphDesc)
+{
+
+    flatbuffers::FlatBufferBuilder builder(1024);
+    if (graphDesc.Nodes.empty())
+    {
+        return builder.Release();
+    }
+
+    // create input/output edge index to name map
+    std::unordered_map<uint32_t, flatbuffers::Offset<flatbuffers::String>> graphInputIndexToNameMap = 
+        ConvertToEdgeIndexToNameMap<DmlInputSerializedGraphEdge>(graphDesc.InputEdges, builder);
+    std::unordered_map<uint32_t, flatbuffers::Offset<flatbuffers::String>> graphOutputIndexToNameMap = 
+        ConvertToEdgeIndexToNameMap<DmlOutputSerializedGraphEdge>(graphDesc.OutputEdges, builder);
+
+    /*
+    * - Calculate the number of inputs/outputs for each operator to allocate
+    *   an appropriate amount of memory per node to store input/output names.
+    * - A non-constant node's input/output count can be determined from its
+    *   AbstractOperatorDesc.
+    * - A constant node will only have outgoing edges, and those outgoing edges
+    *   will be intermediate edges.
+    */
+    std::vector<uint32_t> nodeInputCounts(graphDesc.Nodes.size(), 0);
+    std::vector<uint32_t> nodeOutputCounts(graphDesc.Nodes.size(), 0);
+    PopulateNonConstantNodeInputOutputCount(graphDesc.Nodes, nodeInputCounts, nodeOutputCounts);
+    PopulateConstantNodeInputOutputCount(graphDesc.IntermediateEdges, nodeInputCounts, nodeOutputCounts);
+    
+    // populate node input/output names.
+    std::vector<std::vector<flatbuffers::Offset<flatbuffers::String>>> nodeToInputNames(graphDesc.Nodes.size());
+    std::vector<std::vector<flatbuffers::Offset<flatbuffers::String>>> nodeToOutputNames(graphDesc.Nodes.size());
+    for (uint32_t nodeIndex = 0; nodeIndex < static_cast<uint32_t>(graphDesc.Nodes.size()); nodeIndex++)
+    {
+        nodeToInputNames[nodeIndex].assign(nodeInputCounts[nodeIndex], builder.CreateString(nullptr, 0));
+        nodeToOutputNames[nodeIndex].assign(nodeOutputCounts[nodeIndex], builder.CreateString(nullptr, 0));
+    }
+    PopulateNodeInputOutputNames(builder, graphDesc, graphInputIndexToNameMap, graphOutputIndexToNameMap, nodeToInputNames, nodeToOutputNames);
+
+    // Create flatbuffer node objects
+    std::vector<flatbuffers::Offset<dml::ir::DmlGraphNode>> nodes(graphDesc.Nodes.size());
+    for (uint32_t nodeIndex = 0; nodeIndex < static_cast<uint32_t>(graphDesc.Nodes.size()); nodeIndex++)
+    {
+        nodes[nodeIndex] = SerializeNode(
+                            builder,
+                            nodeIndex,
+                            graphDesc.Nodes[nodeIndex],
+                            nodeToInputNames[nodeIndex],
+                            nodeToOutputNames[nodeIndex]);
+    }
+
+    // Convert the index-to-name maps to std::vector to create the <dml::ir::DmlGraphDesc> object.
+    std::vector<flatbuffers::Offset<flatbuffers::String>> graphInputNames(graphDesc.InputCount, builder.CreateString(nullptr, 0));
+    std::vector<flatbuffers::Offset<flatbuffers::String>> graphOutputNames(graphDesc.OutputCount, builder.CreateString(nullptr, 0));
+    for (const auto& [key, value] : graphInputIndexToNameMap)
+    {
+        graphInputNames[key] = value;
+    }
+    for (const auto& [key, value] : graphOutputIndexToNameMap)
+    {
+        graphOutputNames[key] = value;
+    }
+
+    flatbuffers::Offset<dml::ir::DmlGraphDesc> dmlGraphDescOffset = dml::ir::CreateDmlGraphDescDirect(
+        builder,
+        &nodes,
+        &graphInputNames,
+        &graphOutputNames);
+    builder.Finish(dmlGraphDescOffset);
+    return builder.Release();
+}
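
A minimal usage sketch for the serializer above (editorial, not part of this patch): the field names follow the loops in SerializeDmlGraph, while constructing the node's AbstractOperatorDesc is elided and assumed to happen elsewhere.

    // Hypothetical sketch: serialize a one-node graph whose single input and output are named.
    DmlSerializedGraphDesc graphDesc = {};
    graphDesc.InputCount = 1;
    graphDesc.OutputCount = 1;

    DmlSerializedGraphNode node = {};
    // node.Desc must hold an AbstractOperatorDesc with one input and one output tensor,
    // built elsewhere; constructing it is out of scope for this sketch.
    graphDesc.Nodes.push_back(node);

    DmlInputSerializedGraphEdge inputEdge = {};
    inputEdge.GraphInputIndex = 0;
    inputEdge.ToNodeIndex = 0;
    inputEdge.ToNodeInputIndex = 0;
    inputEdge.Name = "input0";
    graphDesc.InputEdges.push_back(inputEdge);

    DmlOutputSerializedGraphEdge outputEdge = {};
    outputEdge.GraphOutputIndex = 0;
    outputEdge.FromNodeIndex = 0;
    outputEdge.FromNodeOutputIndex = 0;
    outputEdge.Name = "output0";
    graphDesc.OutputEdges.push_back(outputEdge);

    // The returned DetachedBuffer owns the serialized dml::ir::DmlGraphDesc flatbuffer.
    flatbuffers::DetachedBuffer buffer = SerializeDmlGraph(graphDesc);
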
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp
index 5c7b7bff1e370..0f0d445a95bae 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp
@@ -180,32 +180,50 @@ namespace Dml
                 // Convert partitionONNXGraph into DML EP GraphDesc
                 ComPtr<IDMLDevice> device;
                 ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf()));
+                // These maps are used to transfer the initializers to D3D12 system heap memory.
+                // 'serializedDmlGraphDesc' represents constant inputs as intermediate edges, which is why
+                // we need a mapping between the intermediate edge index and the indexedSubGraph's
+                // (a given partition's) input arg index.
+                //   For example: given an intermediate edge index idx,
+                //           indexedSubGraphInputArgIdx = constantEdgeIdxToSubgraphInputArgIdxMap[idx];
+                //           corresponding constant tensor = initializerNameToInitializerMap[indexedSubGraph.GetMetaDef()->inputs[indexedSubGraphInputArgIdx]]
+                // The intermediate edge index is used as the key because the same constant tensor can be
+                // used by multiple nodes. (See the lookup sketch after this hunk.)
+                std::unordered_map<uint32_t, uint32_t> serializedGraphInputIndexToSubgraphInputIndex;
+                std::unordered_map<std::string_view, uint32_t> serializedGraphLargeConstantNameToSubgraphInputIndex;
+                std::vector<std::unique_ptr<std::byte[]>> smallConstantData;
                 GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc(
                     isInputsUploadedByDmlEP.data(),
                     isInputsUploadedByDmlEP.size(),
                     m_isInitializerTransferable,
                     m_partitionNodePropsMap,
-                    device.Get(),
                     providerImpl,
                     m_modelPath,
                     m_subgraphNodePointers,
                     m_subgraphInputs,
-                    m_subgraphOutputs);
+                    m_subgraphOutputs,
+                    serializedGraphInputIndexToSubgraphInputIndex,
+                    serializedGraphLargeConstantNameToSubgraphInputIndex,
+                    smallConstantData);
 
                 m_outputShapes = graphDesc.outputShapes;
 
                 // Walk through each graph edge and mark used inputs
                 m_inputsUsed.resize(fusedNodeInputCount, false);
-                for (const DML_INPUT_GRAPH_EDGE_DESC& edge : graphDesc.inputEdges)
-                {
-                    m_inputsUsed[edge.GraphInputIndex] = true;
+                for (auto it = serializedGraphInputIndexToSubgraphInputIndex.begin(); it != serializedGraphInputIndexToSubgraphInputIndex.end(); it++) {
+                    m_inputsUsed[it->second] = true;
+                }
+                for (auto it = serializedGraphLargeConstantNameToSubgraphInputIndex.begin(); it != serializedGraphLargeConstantNameToSubgraphInputIndex.end(); it++) {
+                    m_inputsUsed[it->second] = true;
                 }
 
                 // Compile the operator
                 m_compiledExecutionPlanOperator = DmlGraphFusionHelper::TryCreateCompiledOperator(
                     graphDesc,
                     *m_indexedSubGraph,
-                    providerImpl);
+                    providerImpl,
+                    &serializedGraphInputIndexToSubgraphInputIndex,
+                    &serializedGraphLargeConstantNameToSubgraphInputIndex);
 
                 // Queue references to objects which must be kept alive until resulting GPU work completes
                 m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get());
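
A short, hypothetical sketch of the lookup described in the comment at the top of this hunk (editorial, not part of this patch; the names constantEdgeIdxToSubgraphInputArgIdxMap and initializerNameToInitializerMap follow that comment and are illustrative only):

    // Illustrative only: resolve the constant tensor for an intermediate (constant) edge.
    uint32_t idx = /* intermediate edge index of a constant input */ 0;
    uint32_t indexedSubGraphInputArgIdx = constantEdgeIdxToSubgraphInputArgIdxMap[idx];
    const std::string& initializerName = indexedSubGraph.GetMetaDef()->inputs[indexedSubGraphInputArgIdx];
    const auto& constantTensor = initializerNameToInitializerMap[initializerName];
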
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp
index 8a32d06534dda..6c347ebdca7c1 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp
@@ -771,8 +771,14 @@ namespace Dml
                 !native16BitShaderOpsSupported &&
                 IsCustomOpShader(node))
             {
-                nodeContainsSupportedDataTypes = false;
-                return;
+                // STFT is a special case: the DML EP registers a graph transformation
+                // that decomposes fp16 STFT into convolution, so it is OK to register
+                // the operator for fp16.
+                if (strcmp("STFT", node.OpType().c_str()) != 0)
+                {
+                    nodeContainsSupportedDataTypes = false;
+                    return;
+                }
             }
 
             // Allow nodeArgs that are SequenceTensor when they are actually implemented by CPU Kernels.
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h
index 5617bc7bdcac6..841d6244a983e 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h
@@ -270,7 +270,7 @@ namespace Dml
             return m_impl->OnSessionInitializationEnd();
         }
 
-        virtual onnxruntime::Status Sync() const final override
+        onnxruntime::Status Sync() const final override
         {
             // Completely wait until the device has completed all preceding tasks.
             // The application could have called SynchronizeBoundOutputs().
@@ -278,7 +278,7 @@ namespace Dml
             return Status::OK();
         }
 
-        virtual onnxruntime::Status OnRunEnd(bool /*sync_stream*/) final override
+        onnxruntime::Status OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) final override
         {
             // Flush any pending work to the GPU, but don't block for completion, permitting it
             // to overlap other work.
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
index a5415ba85f3d3..7c25755a7d09e 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
@@ -24,8 +24,8 @@ struct EnumTraits<DML_TENSOR_TYPE>
 template <>
 struct EnumTraits<DML_OPERATOR_TYPE>
 {
-    static constexpr auto ValueCount = 161;
-    static constexpr size_t ActivationFunctionCount = 24;
+    static constexpr auto ValueCount = 168;
+    static constexpr size_t ActivationFunctionCount = 26;
 };
 
 template <>
@@ -62,7 +62,7 @@ struct EnumTraits<DML_CONVOLUTION_DIRECTION>
 template <>
 struct EnumTraits<DML_PADDING_MODE>
 {
-    static constexpr auto ValueCount = 4;
+    static constexpr auto ValueCount = 5;
 };
 
 template <>
@@ -86,7 +86,7 @@ struct EnumTraits<DML_FEATURE>
 template <>
 struct EnumTraits<DML_FEATURE_LEVEL>
 {
-    static constexpr auto ValueCount = 8;
+    static constexpr auto ValueCount = 13;
 };
 
 template <>
@@ -119,6 +119,12 @@ struct EnumTraits<DML_RANDOM_GENERATOR_TYPE>
     static constexpr auto ValueCount = 1;
 };
 
+template <>
+struct EnumTraits<DML_MULTIHEAD_ATTENTION_MASK_TYPE>
+{
+    static constexpr auto ValueCount = 5;
+};
+
 template <typename T>
 constexpr auto EnumValueCount = EnumTraits<T>::ValueCount;
 
@@ -495,12 +501,6 @@ struct OperatorDescTraits<DML_ROI_POOLING_OPERATOR_DESC>
     static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_POOLING;
 };
 
-template <>
-struct OperatorDescTraits<DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC>
-{
-    static constexpr DML_OPERATOR_TYPE Type = (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING;
-};
-
 template <>
 struct OperatorDescTraits<DML_SLICE_OPERATOR_DESC>
 {
@@ -879,6 +879,12 @@ struct OperatorDescTraits<DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC>
     static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY;
 };
 
+template <>
+struct OperatorDescTraits<DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC>
+{
+    static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT;
+};
+
 template <>
 struct OperatorDescTraits<DML_CONVOLUTION_INTEGER_OPERATOR_DESC>
 {
@@ -1029,6 +1035,18 @@ struct OperatorDescTraits<DML_DIAGONAL_MATRIX1_OPERATOR_DESC>
     static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_DIAGONAL_MATRIX1;
 };
 
+template <>
+struct OperatorDescTraits<DML_MULTIHEAD_ATTENTION_OPERATOR_DESC>
+{
+    static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MULTIHEAD_ATTENTION;
+};
+
+template <>
+struct OperatorDescTraits<DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC>
+{
+    static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING;
+};
+
 template <>
 struct OperatorDescTraits<DML_ACTIVATION_ELU_OPERATOR_DESC>
 {
@@ -1174,9 +1192,15 @@ struct OperatorDescTraits<DML_ACTIVATION_GELU_OPERATOR_DESC>
 };
 
 template <>
-struct OperatorDescTraits<DML_MULTIHEAD_ATTENTION_OPERATOR_DESC>
+struct OperatorDescTraits<DML_ACTIVATION_SWISH_OPERATOR_DESC>
 {
-    static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MULTIHEAD_ATTENTION;
+    static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_SWISH;
+};
+
+template <>
+struct OperatorDescTraits<DML_ACTIVATION_HARD_SWISH_OPERATOR_DESC>
+{
+    static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_HARD_SWISH;
 };
 
 template <DML_OPERATOR_TYPE Type>
@@ -1502,12 +1526,6 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_POOLING>
     using DescType = DML_ROI_POOLING_OPERATOR_DESC;
 };
 
-template <>
-struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING>
-{
-    using DescType = DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC;
-};
-
 template <>
 struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SLICE>
 {
@@ -2036,6 +2054,24 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_DIAGONAL_MATRIX1>
     using DescType = DML_DIAGONAL_MATRIX1_OPERATOR_DESC;
 };
 
+template <>
+struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MULTIHEAD_ATTENTION>
+{
+    using DescType = DML_MULTIHEAD_ATTENTION_OPERATOR_DESC;
+};
+
+template <>
+struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING>
+{
+    using DescType = DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC;
+};
+
+template <>
+struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT>
+{
+    using DescType = DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC;
+};
+
 template <>
 struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_ELU>
 {
@@ -2181,14 +2217,20 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_GELU>
 };
 
 template <>
-struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MULTIHEAD_ATTENTION>
+struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SWISH>
 {
-    using DescType = DML_MULTIHEAD_ATTENTION_OPERATOR_DESC;
+    using DescType = DML_ACTIVATION_SWISH_OPERATOR_DESC;
+};
+
+template <>
+struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_HARD_SWISH>
+{
+    using DescType = DML_ACTIVATION_HARD_SWISH_OPERATOR_DESC;
 };
 
 // Calls a visitor functor, supplying an empty operator desc corresponding to the given DML_OPERATOR_TYPE as
 // the first argument.
-//
+// 
 // For example:
 //   Visit(DML_OPERATOR_ELEMENT_WISE_IDENTITY, [](auto tag) {
 //       using T = decltype(tag); // T is one of the DML_*_OPERATOR_DESC structs
@@ -2485,6 +2527,10 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
         return std::invoke(std::forward<Visitor>(visitor), DML_DIAGONAL_MATRIX1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
     case DML_OPERATOR_MULTIHEAD_ATTENTION:
         return std::invoke(std::forward<Visitor>(visitor), DML_MULTIHEAD_ATTENTION_OPERATOR_DESC{}, std::forward<Ts>(args)...);
+    case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING:
+        return std::invoke(std::forward<Visitor>(visitor), DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward<Ts>(args)...);
+    case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT:
+        return std::invoke(std::forward<Visitor>(visitor), DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC{}, std::forward<Ts>(args)...);
     case DML_OPERATOR_ACTIVATION_ELU:
         return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_ELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
     case DML_OPERATOR_ACTIVATION_CELU:
@@ -2533,13 +2579,10 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
         return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_SHRINK_OPERATOR_DESC{}, std::forward<Ts>(args)...);
     case DML_OPERATOR_ACTIVATION_GELU:
         return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_GELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
-
-#pragma warning(push)
-#pragma warning(disable: 4063)
-    case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING:
-        return std::invoke(std::forward<Visitor>(visitor), DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward<Ts>(args)...);
-#pragma warning(pop)
-
+    case DML_OPERATOR_ACTIVATION_SWISH:
+        return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_SWISH_OPERATOR_DESC{}, std::forward<Ts>(args)...);
+    case DML_OPERATOR_ACTIVATION_HARD_SWISH:
+        return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_HARD_SWISH_OPERATOR_DESC{}, std::forward<Ts>(args)...);
     default:
         ORT_THROW_HR(E_INVALIDARG);
         return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_RELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
@@ -2547,7 +2590,55 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
 }
 #pragma warning(pop)
 
+namespace StringifyHelpers
+{
+template <typename T>
+inline gsl::czstring ToString(T value)
+{
+#ifndef WAI_BUILD_LINUX
+    // Clang diagnoses a static_assert(false) in this primary template even when it is
+    // never instantiated, which would always break the Linux build, so the assert is
+    // limited to non-Linux builds.
+    static_assert(false, "Not implemented for this type");
+#endif
+}
+
+template <>
+inline gsl::czstring ToString(DML_TENSOR_DATA_TYPE value)
+{
+    switch (value)
+    {
+    case DML_TENSOR_DATA_TYPE_UNKNOWN: return "DML_TENSOR_DATA_TYPE_UNKNOWN";
+    case DML_TENSOR_DATA_TYPE_FLOAT32: return "DML_TENSOR_DATA_TYPE_FLOAT32";
+    case DML_TENSOR_DATA_TYPE_FLOAT16: return "DML_TENSOR_DATA_TYPE_FLOAT16";
+    case DML_TENSOR_DATA_TYPE_UINT32: return "DML_TENSOR_DATA_TYPE_UINT32";
+    case DML_TENSOR_DATA_TYPE_UINT16: return "DML_TENSOR_DATA_TYPE_UINT16";
+    case DML_TENSOR_DATA_TYPE_UINT8: return "DML_TENSOR_DATA_TYPE_UINT8";
+    case DML_TENSOR_DATA_TYPE_INT32: return "DML_TENSOR_DATA_TYPE_INT32";
+    case DML_TENSOR_DATA_TYPE_INT16: return "DML_TENSOR_DATA_TYPE_INT16";
+    case DML_TENSOR_DATA_TYPE_INT8: return "DML_TENSOR_DATA_TYPE_INT8";
+    case DML_TENSOR_DATA_TYPE_FLOAT64: return "DML_TENSOR_DATA_TYPE_FLOAT64";
+    case DML_TENSOR_DATA_TYPE_UINT64: return "DML_TENSOR_DATA_TYPE_UINT64";
+    case DML_TENSOR_DATA_TYPE_INT64: return "DML_TENSOR_DATA_TYPE_INT64";
+    default:
+        assert(false);
+        return "<unknown>";
+    }
+}
+
+template <>
+inline gsl::czstring ToString(DML_TENSOR_TYPE value)
+{
+    switch (value)
+    {
+    case DML_TENSOR_TYPE_INVALID: return "DML_TENSOR_TYPE_INVALID";
+    case DML_TENSOR_TYPE_BUFFER: return "DML_TENSOR_TYPE_BUFFER";
+    default:
+        assert(false);
+        return "<unknown>";
+    }
+}
 
+template <>
 inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
 {
     switch (value)
@@ -2561,9 +2652,6 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
     case DML_OPERATOR_ELEMENT_WISE_ATAN: return "DML_OPERATOR_ELEMENT_WISE_ATAN";
     case DML_OPERATOR_ELEMENT_WISE_CEIL: return "DML_OPERATOR_ELEMENT_WISE_CEIL";
     case DML_OPERATOR_ELEMENT_WISE_CLIP: return "DML_OPERATOR_ELEMENT_WISE_CLIP";
-    case DML_OPERATOR_ELEMENT_WISE_CLIP1: return "DML_OPERATOR_ELEMENT_WISE_CLIP1";
-    case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD: return "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD";
-    case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1: return "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1";
     case DML_OPERATOR_ELEMENT_WISE_COS: return "DML_OPERATOR_ELEMENT_WISE_COS";
     case DML_OPERATOR_ELEMENT_WISE_DIVIDE: return "DML_OPERATOR_ELEMENT_WISE_DIVIDE";
     case DML_OPERATOR_ELEMENT_WISE_EXP: return "DML_OPERATOR_ELEMENT_WISE_EXP";
@@ -2587,24 +2675,41 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
     case DML_OPERATOR_ELEMENT_WISE_RECIP: return "DML_OPERATOR_ELEMENT_WISE_RECIP";
     case DML_OPERATOR_ELEMENT_WISE_SIN: return "DML_OPERATOR_ELEMENT_WISE_SIN";
     case DML_OPERATOR_ELEMENT_WISE_SQRT: return "DML_OPERATOR_ELEMENT_WISE_SQRT";
-    case DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE: return "DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE";
-    case DML_OPERATOR_ELEMENT_WISE_ATAN_YX: return "DML_OPERATOR_ELEMENT_WISE_ATAN_YX";
     case DML_OPERATOR_ELEMENT_WISE_SUBTRACT: return "DML_OPERATOR_ELEMENT_WISE_SUBTRACT";
     case DML_OPERATOR_ELEMENT_WISE_TAN: return "DML_OPERATOR_ELEMENT_WISE_TAN";
     case DML_OPERATOR_ELEMENT_WISE_THRESHOLD: return "DML_OPERATOR_ELEMENT_WISE_THRESHOLD";
     case DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR: return "DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR";
     case DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR: return "DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR";
+    case DML_OPERATOR_ACTIVATION_ELU: return "DML_OPERATOR_ACTIVATION_ELU";
+    case DML_OPERATOR_ACTIVATION_CELU: return "DML_OPERATOR_ACTIVATION_CELU";
+    case DML_OPERATOR_ACTIVATION_HARDMAX: return "DML_OPERATOR_ACTIVATION_HARDMAX";
+    case DML_OPERATOR_ACTIVATION_HARDMAX1: return "DML_OPERATOR_ACTIVATION_HARDMAX1";
+    case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return "DML_OPERATOR_ACTIVATION_HARD_SIGMOID";
+    case DML_OPERATOR_ACTIVATION_IDENTITY: return "DML_OPERATOR_ACTIVATION_IDENTITY";
+    case DML_OPERATOR_ACTIVATION_LEAKY_RELU: return "DML_OPERATOR_ACTIVATION_LEAKY_RELU";
+    case DML_OPERATOR_ACTIVATION_LINEAR: return "DML_OPERATOR_ACTIVATION_LINEAR";
+    case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX: return "DML_OPERATOR_ACTIVATION_LOG_SOFTMAX";
+    case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1: return "DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1";
+    case DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU: return "DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU";
+    case DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS: return "DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS";
+    case DML_OPERATOR_ACTIVATION_RELU: return "DML_OPERATOR_ACTIVATION_RELU";
+    case DML_OPERATOR_ACTIVATION_SCALED_ELU: return "DML_OPERATOR_ACTIVATION_SCALED_ELU";
+    case DML_OPERATOR_ACTIVATION_SCALED_TANH: return "DML_OPERATOR_ACTIVATION_SCALED_TANH";
+    case DML_OPERATOR_ACTIVATION_SIGMOID: return "DML_OPERATOR_ACTIVATION_SIGMOID";
+    case DML_OPERATOR_ACTIVATION_SOFTMAX: return "DML_OPERATOR_ACTIVATION_SOFTMAX";
+    case DML_OPERATOR_ACTIVATION_SOFTMAX1: return "DML_OPERATOR_ACTIVATION_SOFTMAX1";
+    case DML_OPERATOR_ACTIVATION_SOFTPLUS: return "DML_OPERATOR_ACTIVATION_SOFTPLUS";
+    case DML_OPERATOR_ACTIVATION_SOFTSIGN: return "DML_OPERATOR_ACTIVATION_SOFTSIGN";
+    case DML_OPERATOR_ACTIVATION_TANH: return "DML_OPERATOR_ACTIVATION_TANH";
+    case DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU: return "DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU";
     case DML_OPERATOR_CONVOLUTION: return "DML_OPERATOR_CONVOLUTION";
     case DML_OPERATOR_GEMM: return "DML_OPERATOR_GEMM";
     case DML_OPERATOR_REDUCE: return "DML_OPERATOR_REDUCE";
-    case DML_OPERATOR_ARGMIN: return "DML_OPERATOR_ARGMIN";
-    case DML_OPERATOR_ARGMAX: return "DML_OPERATOR_ARGMAX";
     case DML_OPERATOR_AVERAGE_POOLING: return "DML_OPERATOR_AVERAGE_POOLING";
     case DML_OPERATOR_AVERAGE_POOLING1: return "DML_OPERATOR_AVERAGE_POOLING1";
     case DML_OPERATOR_LP_POOLING: return "DML_OPERATOR_LP_POOLING";
     case DML_OPERATOR_LP_POOLING1: return "DML_OPERATOR_LP_POOLING1";
     case DML_OPERATOR_MAX_POOLING: return "DML_OPERATOR_MAX_POOLING";
-    case DML_OPERATOR_MAX_POOLING1: return "DML_OPERATOR_MAX_POOLING1";
     case DML_OPERATOR_ROI_POOLING: return "DML_OPERATOR_ROI_POOLING";
     case DML_OPERATOR_SLICE: return "DML_OPERATOR_SLICE";
     case DML_OPERATOR_CAST: return "DML_OPERATOR_CAST";
@@ -2620,18 +2725,15 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
     case DML_OPERATOR_TILE: return "DML_OPERATOR_TILE";
     case DML_OPERATOR_TOP_K: return "DML_OPERATOR_TOP_K";
     case DML_OPERATOR_BATCH_NORMALIZATION: return "DML_OPERATOR_BATCH_NORMALIZATION";
-    case DML_OPERATOR_BATCH_NORMALIZATION_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_GRAD";
-    case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD";
+    case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING";
     case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION: return "DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION";
     case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION: return "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION";
-    case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD: return "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD";
     case DML_OPERATOR_LP_NORMALIZATION: return "DML_OPERATOR_LP_NORMALIZATION";
     case DML_OPERATOR_RNN: return "DML_OPERATOR_RNN";
     case DML_OPERATOR_LSTM: return "DML_OPERATOR_LSTM";
     case DML_OPERATOR_GRU: return "DML_OPERATOR_GRU";
     case DML_OPERATOR_ELEMENT_WISE_SIGN: return "DML_OPERATOR_ELEMENT_WISE_SIGN";
     case DML_OPERATOR_ELEMENT_WISE_IS_NAN: return "DML_OPERATOR_ELEMENT_WISE_IS_NAN";
-    case DML_OPERATOR_ELEMENT_WISE_NEGATE: return "DML_OPERATOR_ELEMENT_WISE_NEGATE";
     case DML_OPERATOR_ELEMENT_WISE_ERF: return "DML_OPERATOR_ELEMENT_WISE_ERF";
     case DML_OPERATOR_ELEMENT_WISE_SINH: return "DML_OPERATOR_ELEMENT_WISE_SINH";
     case DML_OPERATOR_ELEMENT_WISE_COSH: return "DML_OPERATOR_ELEMENT_WISE_COSH";
@@ -2641,6 +2743,8 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
     case DML_OPERATOR_ELEMENT_WISE_ATANH: return "DML_OPERATOR_ELEMENT_WISE_ATANH";
     case DML_OPERATOR_ELEMENT_WISE_IF: return "DML_OPERATOR_ELEMENT_WISE_IF";
     case DML_OPERATOR_ELEMENT_WISE_ADD1: return "DML_OPERATOR_ELEMENT_WISE_ADD1";
+    case DML_OPERATOR_ACTIVATION_SHRINK: return "DML_OPERATOR_ACTIVATION_SHRINK";
+    case DML_OPERATOR_MAX_POOLING1: return "DML_OPERATOR_MAX_POOLING1";
     case DML_OPERATOR_MAX_UNPOOLING: return "DML_OPERATOR_MAX_UNPOOLING";
     case DML_OPERATOR_DIAGONAL_MATRIX: return "DML_OPERATOR_DIAGONAL_MATRIX";
     case DML_OPERATOR_SCATTER: return "DML_OPERATOR_SCATTER";
@@ -2652,10 +2756,9 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
     case DML_OPERATOR_ELEMENT_WISE_IS_INFINITY: return "DML_OPERATOR_ELEMENT_WISE_IS_INFINITY";
     case DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE: return "DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE";
     case DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR: return "DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR";
-    case DML_OPERATOR_FILL_VALUE_CONSTANT: return "DML_OPERATOR_FILL_VALUE_CONSTANT";
     case DML_OPERATOR_FILL_VALUE_SEQUENCE: return "DML_OPERATOR_FILL_VALUE_SEQUENCE";
+    case DML_OPERATOR_FILL_VALUE_CONSTANT: return "DML_OPERATOR_FILL_VALUE_CONSTANT";
     case DML_OPERATOR_CUMULATIVE_SUMMATION: return "DML_OPERATOR_CUMULATIVE_SUMMATION";
-    case DML_OPERATOR_CUMULATIVE_PRODUCT: return "DML_OPERATOR_CUMULATIVE_PRODUCT";
     case DML_OPERATOR_REVERSE_SUBSEQUENCES: return "DML_OPERATOR_REVERSE_SUBSEQUENCES";
     case DML_OPERATOR_GATHER_ELEMENTS: return "DML_OPERATOR_GATHER_ELEMENTS";
     case DML_OPERATOR_GATHER_ND: return "DML_OPERATOR_GATHER_ND";
@@ -2684,20 +2787,278 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
     case DML_OPERATOR_RESAMPLE_GRAD: return "DML_OPERATOR_RESAMPLE_GRAD";
     case DML_OPERATOR_SLICE_GRAD: return "DML_OPERATOR_SLICE_GRAD";
     case DML_OPERATOR_ADAM_OPTIMIZER: return "DML_OPERATOR_ADAM_OPTIMIZER";
+    case DML_OPERATOR_ARGMIN: return "DML_OPERATOR_ARGMIN";
+    case DML_OPERATOR_ARGMAX: return "DML_OPERATOR_ARGMAX";
     case DML_OPERATOR_ROI_ALIGN: return "DML_OPERATOR_ROI_ALIGN";
-    case DML_OPERATOR_ROI_ALIGN1: return "DML_OPERATOR_ROI_ALIGN1";
     case DML_OPERATOR_GATHER_ND1: return "DML_OPERATOR_GATHER_ND1";
-    case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: return "DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR";
+    case DML_OPERATOR_ELEMENT_WISE_ATAN_YX: return "DML_OPERATOR_ELEMENT_WISE_ATAN_YX";
+    case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD: return "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD";
+    case DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE: return "DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE";
+    case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD: return "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD";
+    case DML_OPERATOR_CUMULATIVE_PRODUCT: return "DML_OPERATOR_CUMULATIVE_PRODUCT";
+    case DML_OPERATOR_BATCH_NORMALIZATION_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_GRAD";
+    case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD";
     case DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD: return "DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD";
-    case DML_OPERATOR_ROI_ALIGN_GRAD: return "DML_OPERATOR_ROI_ALIGN_GRAD";
-    case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING";
+    case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: return "DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR";
+    case DML_OPERATOR_ROI_ALIGN1: return "DML_OPERATOR_ROI_ALIGN1";
+    case DML_OPERATOR_ELEMENT_WISE_CLIP1: return "DML_OPERATOR_ELEMENT_WISE_CLIP1";
+    case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1: return "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1";
+    case DML_OPERATOR_ELEMENT_WISE_NEGATE: return "DML_OPERATOR_ELEMENT_WISE_NEGATE";
+    case DML_OPERATOR_ACTIVATION_GELU: return "DML_OPERATOR_ACTIVATION_GELU";
+    case DML_OPERATOR_ACTIVATION_SWISH: return "DML_OPERATOR_ACTIVATION_SWISH";
+    case DML_OPERATOR_ACTIVATION_HARD_SWISH: return "DML_OPERATOR_ACTIVATION_HARD_SWISH";
     case DML_OPERATOR_RESAMPLE2: return "DML_OPERATOR_RESAMPLE2";
     case DML_OPERATOR_RESAMPLE_GRAD1: return "DML_OPERATOR_RESAMPLE_GRAD1";
     case DML_OPERATOR_DIAGONAL_MATRIX1: return "DML_OPERATOR_DIAGONAL_MATRIX1";
     case DML_OPERATOR_MULTIHEAD_ATTENTION: return "DML_OPERATOR_MULTIHEAD_ATTENTION";
+    case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: return "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING";
+    case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT";
+    default:
+        assert(false);
+        return "<unknown>";
+    }
+}
+
+template <>
+inline gsl::czstring ToString(DML_BINDING_TYPE value)
+{
+    switch (value)
+    {
+    case DML_BINDING_TYPE_NONE: return "DML_BINDING_TYPE_NONE";
+    case DML_BINDING_TYPE_BUFFER: return "DML_BINDING_TYPE_BUFFER";
+    case DML_BINDING_TYPE_BUFFER_ARRAY: return "DML_BINDING_TYPE_BUFFER_ARRAY";
+    default:
+        assert(false);
+        return "<unknown>";
+    }
+}
+
+template <>
+inline gsl::czstring ToString(DML_REDUCE_FUNCTION value)
+{
+    switch (value)
+    {
+    case DML_REDUCE_FUNCTION_ARGMAX: return "DML_REDUCE_FUNCTION_ARGMAX";
+    case DML_REDUCE_FUNCTION_ARGMIN: return "DML_REDUCE_FUNCTION_ARGMIN";
+    case DML_REDUCE_FUNCTION_AVERAGE: return "DML_REDUCE_FUNCTION_AVERAGE";
+    case DML_REDUCE_FUNCTION_L1: return "DML_REDUCE_FUNCTION_L1";
+    case DML_REDUCE_FUNCTION_L2: return "DML_REDUCE_FUNCTION_L2";
+    case DML_REDUCE_FUNCTION_LOG_SUM: return "DML_REDUCE_FUNCTION_LOG_SUM";
+    case DML_REDUCE_FUNCTION_LOG_SUM_EXP: return "DML_REDUCE_FUNCTION_LOG_SUM_EXP";
+    case DML_REDUCE_FUNCTION_MAX: return "DML_REDUCE_FUNCTION_MAX";
+    case DML_REDUCE_FUNCTION_MIN: return "DML_REDUCE_FUNCTION_MIN";
+    case DML_REDUCE_FUNCTION_MULTIPLY: return "DML_REDUCE_FUNCTION_MULTIPLY";
+    case DML_REDUCE_FUNCTION_SUM: return "DML_REDUCE_FUNCTION_SUM";
+    case DML_REDUCE_FUNCTION_SUM_SQUARE: return "DML_REDUCE_FUNCTION_SUM_SQUARE";
     default:
         assert(false);
         return "<unknown>";
     }
 }
+
+template <>
+inline gsl::czstring ToString(DML_MATRIX_TRANSFORM value)
+{
+    switch (value)
+    {
+    case DML_MATRIX_TRANSFORM_NONE: return "DML_MATRIX_TRANSFORM_NONE";
+    case DML_MATRIX_TRANSFORM_TRANSPOSE: return "DML_MATRIX_TRANSFORM_TRANSPOSE";
+    default:
+        assert(false);
+        return "<unknown>";
+    }
+}
+
+template <>
+inline gsl::czstring ToString(DML_CONVOLUTION_MODE value)
+{
+    switch (value)
+    {
+    case DML_CONVOLUTION_MODE_CONVOLUTION: return "DML_CONVOLUTION_MODE_CONVOLUTION";
+    case DML_CONVOLUTION_MODE_CROSS_CORRELATION: return "DML_CONVOLUTION_MODE_CROSS_CORRELATION";
+    default:
+        assert(false);
+        return "<unknown>";
+    }
+}
+
+template <>
+inline gsl::czstring ToString(DML_CONVOLUTION_DIRECTION value)
+{
+    switch (value)
+    {
+    case DML_CONVOLUTION_DIRECTION_FORWARD: return "DML_CONVOLUTION_DIRECTION_FORWARD";
+    case DML_CONVOLUTION_DIRECTION_BACKWARD: return "DML_CONVOLUTION_DIRECTION_BACKWARD";
+    default:
+        assert(false);
+        return "<unknown>";
+    }
+}
+
+template <>
+inline gsl::czstring ToString(DML_PADDING_MODE value)
+{
+    switch (value)
+    {
+    case DML_PADDING_MODE_CONSTANT: return "DML_PADDING_MODE_CONSTANT";
+    case DML_PADDING_MODE_EDGE: return "DML_PADDING_MODE_EDGE";
+    case DML_PADDING_MODE_REFLECTION: return "DML_PADDING_MODE_REFLECTION";
+    case DML_PADDING_MODE_SYMMETRIC: return "DML_PADDING_MODE_SYMMETRIC";
+    default:
+        assert(false);
+        return "<unknown>";
+    }
+}
+
+template <>
+inline gsl::czstring ToString(DML_INTERPOLATION_MODE value)
+{
+    switch (value)
+    {
+    case DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR: return "DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR";
+    case DML_INTERPOLATION_MODE_LINEAR: return "DML_INTERPOLATION_MODE_LINEAR";
+    default:
+        assert(false);
+        return "<unknown>";
+    }
+}
+
+template <>
+inline gsl::czstring ToString(DML_RECURRENT_NETWORK_DIRECTION value)
+{
+    switch (value)
+    {
+    case DML_RECURRENT_NETWORK_DIRECTION_FORWARD: return "DML_RECURRENT_NETWORK_DIRECTION_FORWARD";
+    case DML_RECURRENT_NETWORK_DIRECTION_BACKWARD: return "DML_RECURRENT_NETWORK_DIRECTION_BACKWARD";
+    case DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL: return "DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL";
+    default:
+        assert(false);
+        return "<unknown>";
+    }
+}
+
+template <>
+inline gsl::czstring ToString(DML_FEATURE value)
+{
+    switch (value)
+    {
+    case DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT: return "DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT";
+    case DML_FEATURE_FEATURE_LEVELS: return "DML_FEATURE_FEATURE_LEVELS";
+    default:
+        assert(false);
+        return "<unknown>";
+    }
+}
+
+template <>
+inline gsl::czstring ToString(DML_FEATURE_LEVEL value)
+{
+    switch (value)
+    {
+    case DML_FEATURE_LEVEL_1_0: return "DML_FEATURE_LEVEL_1_0";
+    case DML_FEATURE_LEVEL_2_0: return "DML_FEATURE_LEVEL_2_0";
+    case DML_FEATURE_LEVEL_2_1: return "DML_FEATURE_LEVEL_2_1";
+    case DML_FEATURE_LEVEL_3_0: return "DML_FEATURE_LEVEL_3_0";
+    case DML_FEATURE_LEVEL_3_1: return "DML_FEATURE_LEVEL_3_1";
+    case DML_FEATURE_LEVEL_4_0: return "DML_FEATURE_LEVEL_4_0";
+    case DML_FEATURE_LEVEL_4_1: return "DML_FEATURE_LEVEL_4_1";
+    case DML_FEATURE_LEVEL_5_0: return "DML_FEATURE_LEVEL_5_0";
+    case DML_FEATURE_LEVEL_5_1: return "DML_FEATURE_LEVEL_5_1";
+    case DML_FEATURE_LEVEL_5_2: return "DML_FEATURE_LEVEL_5_2";
+    case DML_FEATURE_LEVEL_6_0: return "DML_FEATURE_LEVEL_6_0";
+    case DML_FEATURE_LEVEL_6_1: return "DML_FEATURE_LEVEL_6_1";
+    case DML_FEATURE_LEVEL_6_2: return "DML_FEATURE_LEVEL_6_2";
+    default:
+        assert(false);
+        return "<unknown>";
+    }
+}
+
+template <>
+inline gsl::czstring ToString(DML_IS_INFINITY_MODE value)
+{
+    switch (value)
+    {
+    case DML_IS_INFINITY_MODE_EITHER: return "DML_IS_INFINITY_MODE_EITHER";
+    case DML_IS_INFINITY_MODE_POSITIVE: return "DML_IS_INFINITY_MODE_POSITIVE";
+    case DML_IS_INFINITY_MODE_NEGATIVE: return "DML_IS_INFINITY_MODE_NEGATIVE";
+    default:
+        assert(false);
+        return "<unknown>";
+    }
+}
+
+template <>
+inline gsl::czstring ToString(DML_DEPTH_SPACE_ORDER value)
+{
+    switch (value)
+    {
+    case DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW: return "DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW";
+    case DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH: return "DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH";
+    default:
+        assert(false);
+        return "<unknown>";
+    }
+}
+
+template <>
+inline gsl::czstring ToString(DML_AXIS_DIRECTION value)
+{
+    switch (value)
+    {
+    case DML_AXIS_DIRECTION_INCREASING: return "DML_AXIS_DIRECTION_INCREASING";
+    case DML_AXIS_DIRECTION_DECREASING: return "DML_AXIS_DIRECTION_DECREASING";
+    default:
+        assert(false);
+        return "<unknown>";
+    }
+}
+
+template <>
+inline gsl::czstring ToString(DML_ROUNDING_MODE value)
+{
+    switch (value)
+    {
+    case DML_ROUNDING_MODE_HALVES_TO_NEAREST_EVEN: return "DML_ROUNDING_MODE_HALVES_TO_NEAREST_EVEN";
+    case DML_ROUNDING_MODE_TOWARD_ZERO: return "DML_ROUNDING_MODE_TOWARD_ZERO";
+    case DML_ROUNDING_MODE_TOWARD_INFINITY: return "DML_ROUNDING_MODE_TOWARD_INFINITY";
+    default:
+        assert(false);
+        return "<unknown>";
+    }
+}
+
+template <>
+inline gsl::czstring ToString(DML_RANDOM_GENERATOR_TYPE value)
+{
+    switch (value)
+    {
+    case DML_RANDOM_GENERATOR_TYPE_PHILOX_4X32_10: return "DML_RANDOM_GENERATOR_TYPE_PHILOX_4X32_10";
+    default:
+        assert(false);
+        return "<unknown>";
+    }
+}
+
+template <>
+inline gsl::czstring ToString(DML_MULTIHEAD_ATTENTION_MASK_TYPE value)
+{
+    switch (value)
+    {
+    case DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE: return "DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE";
+    case DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH: return "DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH";
+    case DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START: return "DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START";
+    case DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END: return "DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END";
+    case DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN: return "DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN";
+    default:
+        assert(false);
+        return "<unknown>";
+    }
+}
+
+
+template <typename T>
+T FromString(std::string_view value);
+
+}
 }
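
A hedged usage sketch of the ToString specializations added above (editorial, not part of this patch; the wrapper function is illustrative and assumes the StringifyHelpers namespace is visible at the call site):

    // Illustrative only: log enum values by name using the new specializations.
    void LogTensorInfo(DML_TENSOR_DATA_TYPE dataType, DML_TENSOR_TYPE tensorType)
    {
        printf("data type: %s, tensor type: %s\n",
               StringifyHelpers::ToString(dataType),
               StringifyHelpers::ToString(tensorType));
    }
    // LogTensorInfo(DML_TENSOR_DATA_TYPE_FLOAT16, DML_TENSOR_TYPE_BUFFER) prints
    // "data type: DML_TENSOR_DATA_TYPE_FLOAT16, tensor type: DML_TENSOR_TYPE_BUFFER".
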
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h
index 2a82c12872a72..64ea5b7801a84 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h
@@ -618,7 +618,7 @@ constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA {
 constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS[4] {
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", false },
-    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ZeroPointTensor", false },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ZeroPointTensor", true },
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
 };
 
@@ -633,7 +633,7 @@ constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA {
 constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS[4] {
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", false },
-    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ZeroPointTensor", false },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ZeroPointTensor", true },
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
 };
 
@@ -869,31 +869,6 @@ constexpr DML_OPERATOR_SCHEMA DML_ROI_POOLING_OPERATOR_SCHEMA {
     DML_ROI_POOLING_OPERATOR_SCHEMA_FIELDS,
 };
 
-
-constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS[13] {
-    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
-    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputScaleTensor", false },
-    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true },
-    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false },
-    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true },
-    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
-    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
-    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false },
-    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false },
-    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false },
-    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false },
-    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false },
-    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false },
-};
-
-constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA {
-    "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING",
-    static_cast<DML_OPERATOR_TYPE>(DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING),
-    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
-    13,
-    DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS,
-};
-
 constexpr DML_SCHEMA_FIELD DML_SLICE_OPERATOR_SCHEMA_FIELDS[6] {
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
@@ -1146,7 +1121,7 @@ constexpr DML_SCHEMA_FIELD DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false },
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleGradientTensor", false },
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputBiasGradientTensor", false },
-    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", true },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", false },
 };
 
 constexpr DML_OPERATOR_SCHEMA DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA {
@@ -1890,6 +1865,25 @@ constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_OPERATOR_SCHEMA {
     DML_MATRIX_MULTIPLY_INTEGER_OPERATOR_SCHEMA_FIELDS,
 };
 
+constexpr DML_SCHEMA_FIELD DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS[8] {
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", true },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor", false },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", true },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
+};
+
+constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA {
+    "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT",
+    static_cast<DML_OPERATOR_TYPE>(DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT),
+    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
+    8,
+    DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS,
+};
+
 constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA_FIELDS[9] {
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false },
@@ -2312,7 +2306,7 @@ constexpr DML_OPERATOR_SCHEMA DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA {
     DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA_FIELDS,
 };
 
-constexpr DML_SCHEMA_FIELD DML_RESAMPLE2_OPERATOR_SCHEMA_FIELDS[8]{
+constexpr DML_SCHEMA_FIELD DML_RESAMPLE2_OPERATOR_SCHEMA_FIELDS[8] {
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InterpolationMode", false },
@@ -2323,7 +2317,7 @@ constexpr DML_SCHEMA_FIELD DML_RESAMPLE2_OPERATOR_SCHEMA_FIELDS[8]{
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, "OutputPixelOffsets", false },
 };
 
-constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE2_OPERATOR_SCHEMA{
+constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE2_OPERATOR_SCHEMA {
     "DML_OPERATOR_RESAMPLE2",
     DML_OPERATOR_RESAMPLE2,
     DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
@@ -2342,7 +2336,7 @@ constexpr DML_SCHEMA_FIELD DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA_FIELDS[8]{
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, "OutputPixelOffsets", false },
 };
 
-constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA{
+constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA {
     "DML_OPERATOR_RESAMPLE_GRAD1",
     DML_OPERATOR_RESAMPLE_GRAD1,
     DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
@@ -2350,7 +2344,7 @@ constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA{
     DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA_FIELDS,
 };
 
-constexpr DML_SCHEMA_FIELD DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA_FIELDS[6]{
+constexpr DML_SCHEMA_FIELD DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA_FIELDS[6] {
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", true },
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "ValueDataType", false },
@@ -2359,7 +2353,7 @@ constexpr DML_SCHEMA_FIELD DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA_FIELDS[6]{
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_INT, "DiagonalFillEnd", false },
 };
 
-constexpr DML_OPERATOR_SCHEMA DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA{
+constexpr DML_OPERATOR_SCHEMA DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA {
     "DML_OPERATOR_DIAGONAL_MATRIX1",
     DML_OPERATOR_DIAGONAL_MATRIX1,
     DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
@@ -2396,6 +2390,30 @@ constexpr DML_OPERATOR_SCHEMA DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA {
     DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA_FIELDS,
 };
 
+constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS[13] {
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputScaleTensor", false },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false },
+};
+
+constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA {
+    "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING",
+    DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING,
+    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
+    13,
+    DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS,
+};
+
 constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] {
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
@@ -2732,6 +2750,35 @@ constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_GELU_OPERATOR_SCHEMA {
     DML_ACTIVATION_GELU_OPERATOR_SCHEMA_FIELDS,
 };
 
+constexpr DML_SCHEMA_FIELD DML_ACTIVATION_SWISH_OPERATOR_SCHEMA_FIELDS[3] {
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "SigmoidInputScale", false },
+};
+
+constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_SWISH_OPERATOR_SCHEMA {
+    "DML_OPERATOR_ACTIVATION_SWISH",
+    DML_OPERATOR_ACTIVATION_SWISH,
+    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
+    3,
+    DML_ACTIVATION_SWISH_OPERATOR_SCHEMA_FIELDS,
+};
+
+constexpr DML_SCHEMA_FIELD DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA_FIELDS[4] {
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false },
+    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Beta", false },
+};
+
+constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA {
+    "DML_OPERATOR_ACTIVATION_HARD_SWISH",
+    DML_OPERATOR_ACTIVATION_HARD_SWISH,
+    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
+    4,
+    DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA_FIELDS,
+};
+
 constexpr DML_SCHEMA_FIELD DML_RNN_ZERO_OPERATOR_SCHEMA_FIELDS[3] {
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
     DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "SequenceLengthsTensor", false },
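
For reference, a hedged sketch of how the new HARD_SWISH schema above lines up with its operator desc (editorial, not part of this patch; the tensor descs are assumed to be prepared by the caller, and the Alpha/Beta values are example coefficients only):

    // Illustrative only: the 4 schema fields map one-to-one onto the desc members.
    DML_ACTIVATION_HARD_SWISH_OPERATOR_DESC hardSwishDesc = {};
    hardSwishDesc.InputTensor = &inputTensorDesc;    // DML_TENSOR_DESC prepared elsewhere
    hardSwishDesc.OutputTensor = &outputTensorDesc;
    hardSwishDesc.Alpha = 1.0f / 6.0f;               // example coefficients only
    hardSwishDesc.Beta = 0.5f;

    DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ACTIVATION_HARD_SWISH, &hardSwishDesc };
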
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDesc_generated.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDesc_generated.h
new file mode 100644
index 0000000000000..df485396f1e47
--- /dev/null
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDesc_generated.h
@@ -0,0 +1,850 @@
+// automatically generated by the FlatBuffers compiler, do not modify
+
+
+#ifndef FLATBUFFERS_GENERATED_DMLGRAPHDESC_DML_IR_H_
+#define FLATBUFFERS_GENERATED_DMLGRAPHDESC_DML_IR_H_
+
+#include "core/common/flatbuffers.h"
+
+// Ensure the included flatbuffers.h is the same version as when this file was
+// generated, otherwise it may not be compatible.
+static_assert(FLATBUFFERS_VERSION_MAJOR == 23 &&
+              FLATBUFFERS_VERSION_MINOR == 5 &&
+              FLATBUFFERS_VERSION_REVISION == 26,
+             "Non-compatible flatbuffers version included");
+
+#include "OperatorFieldTypes_generated.h"
+
+namespace dml {
+namespace ir {
+
+struct ConstantRawData;
+struct ConstantRawDataBuilder;
+
+struct ConstantName;
+struct ConstantNameBuilder;
+
+struct ConstantNodeDesc;
+struct ConstantNodeDescBuilder;
+
+struct DmlBufferTensorDesc;
+struct DmlBufferTensorDescBuilder;
+
+struct OperatorNodeDesc;
+struct OperatorNodeDescBuilder;
+
+struct DmlGraphNode;
+struct DmlGraphNodeBuilder;
+
+struct DmlGraphDesc;
+struct DmlGraphDescBuilder;
+
+enum ConstantNodeDescDetail : uint8_t {
+  ConstantNodeDescDetail_NONE = 0,
+  ConstantNodeDescDetail_ConstantName = 1,
+  ConstantNodeDescDetail_ConstantRawData = 2,
+  ConstantNodeDescDetail_MIN = ConstantNodeDescDetail_NONE,
+  ConstantNodeDescDetail_MAX = ConstantNodeDescDetail_ConstantRawData
+};
+
+inline const ConstantNodeDescDetail (&EnumValuesConstantNodeDescDetail())[3] {
+  static const ConstantNodeDescDetail values[] = {
+    ConstantNodeDescDetail_NONE,
+    ConstantNodeDescDetail_ConstantName,
+    ConstantNodeDescDetail_ConstantRawData
+  };
+  return values;
+}
+
+inline const char * const *EnumNamesConstantNodeDescDetail() {
+  static const char * const names[4] = {
+    "NONE",
+    "ConstantName",
+    "ConstantRawData",
+    nullptr
+  };
+  return names;
+}
+
+inline const char *EnumNameConstantNodeDescDetail(ConstantNodeDescDetail e) {
+  if (::flatbuffers::IsOutRange(e, ConstantNodeDescDetail_NONE, ConstantNodeDescDetail_ConstantRawData)) return "";
+  const size_t index = static_cast<size_t>(e);
+  return EnumNamesConstantNodeDescDetail()[index];
+}
+
+template<typename T> struct ConstantNodeDescDetailTraits {
+  static const ConstantNodeDescDetail enum_value = ConstantNodeDescDetail_NONE;
+};
+
+template<> struct ConstantNodeDescDetailTraits<dml::ir::ConstantName> {
+  static const ConstantNodeDescDetail enum_value = ConstantNodeDescDetail_ConstantName;
+};
+
+template<> struct ConstantNodeDescDetailTraits<dml::ir::ConstantRawData> {
+  static const ConstantNodeDescDetail enum_value = ConstantNodeDescDetail_ConstantRawData;
+};
+
+bool VerifyConstantNodeDescDetail(::flatbuffers::Verifier &verifier, const void *obj, ConstantNodeDescDetail type);
+bool VerifyConstantNodeDescDetailVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset<void>> *values, const ::flatbuffers::Vector<uint8_t> *types);
+
+enum NodeDesc : uint8_t {
+  NodeDesc_NONE = 0,
+  NodeDesc_OperatorNodeDesc = 1,
+  NodeDesc_ConstantNodeDesc = 2,
+  NodeDesc_MIN = NodeDesc_NONE,
+  NodeDesc_MAX = NodeDesc_ConstantNodeDesc
+};
+
+inline const NodeDesc (&EnumValuesNodeDesc())[3] {
+  static const NodeDesc values[] = {
+    NodeDesc_NONE,
+    NodeDesc_OperatorNodeDesc,
+    NodeDesc_ConstantNodeDesc
+  };
+  return values;
+}
+
+inline const char * const *EnumNamesNodeDesc() {
+  static const char * const names[4] = {
+    "NONE",
+    "OperatorNodeDesc",
+    "ConstantNodeDesc",
+    nullptr
+  };
+  return names;
+}
+
+inline const char *EnumNameNodeDesc(NodeDesc e) {
+  if (::flatbuffers::IsOutRange(e, NodeDesc_NONE, NodeDesc_ConstantNodeDesc)) return "";
+  const size_t index = static_cast<size_t>(e);
+  return EnumNamesNodeDesc()[index];
+}
+
+template<typename T> struct NodeDescTraits {
+  static const NodeDesc enum_value = NodeDesc_NONE;
+};
+
+template<> struct NodeDescTraits<dml::ir::OperatorNodeDesc> {
+  static const NodeDesc enum_value = NodeDesc_OperatorNodeDesc;
+};
+
+template<> struct NodeDescTraits<dml::ir::ConstantNodeDesc> {
+  static const NodeDesc enum_value = NodeDesc_ConstantNodeDesc;
+};
+
+bool VerifyNodeDesc(::flatbuffers::Verifier &verifier, const void *obj, NodeDesc type);
+bool VerifyNodeDescVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset<void>> *values, const ::flatbuffers::Vector<uint8_t> *types);
+
+struct ConstantRawData FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+  typedef ConstantRawDataBuilder Builder;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_DATA = 4
+  };
+  const ::flatbuffers::Vector<uint8_t> *data() const {
+    return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_DATA);
+  }
+  ::flatbuffers::Vector<uint8_t> *mutable_data() {
+    return GetPointer<::flatbuffers::Vector<uint8_t> *>(VT_DATA);
+  }
+  bool Verify(::flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyOffset(verifier, VT_DATA) &&
+           verifier.VerifyVector(data()) &&
+           verifier.EndTable();
+  }
+};
+
+struct ConstantRawDataBuilder {
+  typedef ConstantRawData Table;
+  ::flatbuffers::FlatBufferBuilder &fbb_;
+  ::flatbuffers::uoffset_t start_;
+  void add_data(::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> data) {
+    fbb_.AddOffset(ConstantRawData::VT_DATA, data);
+  }
+  explicit ConstantRawDataBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ::flatbuffers::Offset<ConstantRawData> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = ::flatbuffers::Offset<ConstantRawData>(end);
+    return o;
+  }
+};
+
+inline ::flatbuffers::Offset<ConstantRawData> CreateConstantRawData(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> data = 0) {
+  ConstantRawDataBuilder builder_(_fbb);
+  builder_.add_data(data);
+  return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<ConstantRawData> CreateConstantRawDataDirect(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    const std::vector<uint8_t> *data = nullptr) {
+  auto data__ = data ? _fbb.CreateVector<uint8_t>(*data) : 0;
+  return dml::ir::CreateConstantRawData(
+      _fbb,
+      data__);
+}
+
+struct ConstantName FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+  typedef ConstantNameBuilder Builder;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_NAME = 4
+  };
+  const ::flatbuffers::String *name() const {
+    return GetPointer<const ::flatbuffers::String *>(VT_NAME);
+  }
+  ::flatbuffers::String *mutable_name() {
+    return GetPointer<::flatbuffers::String *>(VT_NAME);
+  }
+  bool Verify(::flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyOffset(verifier, VT_NAME) &&
+           verifier.VerifyString(name()) &&
+           verifier.EndTable();
+  }
+};
+
+struct ConstantNameBuilder {
+  typedef ConstantName Table;
+  ::flatbuffers::FlatBufferBuilder &fbb_;
+  ::flatbuffers::uoffset_t start_;
+  void add_name(::flatbuffers::Offset<::flatbuffers::String> name) {
+    fbb_.AddOffset(ConstantName::VT_NAME, name);
+  }
+  explicit ConstantNameBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ::flatbuffers::Offset<ConstantName> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = ::flatbuffers::Offset<ConstantName>(end);
+    return o;
+  }
+};
+
+inline ::flatbuffers::Offset<ConstantName> CreateConstantName(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    ::flatbuffers::Offset<::flatbuffers::String> name = 0) {
+  ConstantNameBuilder builder_(_fbb);
+  builder_.add_name(name);
+  return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<ConstantName> CreateConstantNameDirect(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    const char *name = nullptr) {
+  auto name__ = name ? _fbb.CreateString(name) : 0;
+  return dml::ir::CreateConstantName(
+      _fbb,
+      name__);
+}
+
+struct ConstantNodeDesc FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+  typedef ConstantNodeDescBuilder Builder;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_DATA_TYPE = 4,
+    VT_DATA = 6
+  };
+  dml::ir::ConstantNodeDescDetail data_type() const {
+    return static_cast<dml::ir::ConstantNodeDescDetail>(GetField<uint8_t>(VT_DATA_TYPE, 0));
+  }
+  const void *data() const {
+    return GetPointer<const void *>(VT_DATA);
+  }
+  template<typename T> const T *data_as() const;
+  const dml::ir::ConstantName *data_as_ConstantName() const {
+    return data_type() == dml::ir::ConstantNodeDescDetail_ConstantName ? static_cast<const dml::ir::ConstantName *>(data()) : nullptr;
+  }
+  const dml::ir::ConstantRawData *data_as_ConstantRawData() const {
+    return data_type() == dml::ir::ConstantNodeDescDetail_ConstantRawData ? static_cast<const dml::ir::ConstantRawData *>(data()) : nullptr;
+  }
+  void *mutable_data() {
+    return GetPointer<void *>(VT_DATA);
+  }
+  bool Verify(::flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyField<uint8_t>(verifier, VT_DATA_TYPE, 1) &&
+           VerifyOffset(verifier, VT_DATA) &&
+           VerifyConstantNodeDescDetail(verifier, data(), data_type()) &&
+           verifier.EndTable();
+  }
+};
+
+template<> inline const dml::ir::ConstantName *ConstantNodeDesc::data_as<dml::ir::ConstantName>() const {
+  return data_as_ConstantName();
+}
+
+template<> inline const dml::ir::ConstantRawData *ConstantNodeDesc::data_as<dml::ir::ConstantRawData>() const {
+  return data_as_ConstantRawData();
+}
+
+struct ConstantNodeDescBuilder {
+  typedef ConstantNodeDesc Table;
+  ::flatbuffers::FlatBufferBuilder &fbb_;
+  ::flatbuffers::uoffset_t start_;
+  void add_data_type(dml::ir::ConstantNodeDescDetail data_type) {
+    fbb_.AddElement<uint8_t>(ConstantNodeDesc::VT_DATA_TYPE, static_cast<uint8_t>(data_type), 0);
+  }
+  void add_data(::flatbuffers::Offset<void> data) {
+    fbb_.AddOffset(ConstantNodeDesc::VT_DATA, data);
+  }
+  explicit ConstantNodeDescBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ::flatbuffers::Offset<ConstantNodeDesc> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = ::flatbuffers::Offset<ConstantNodeDesc>(end);
+    return o;
+  }
+};
+
+inline ::flatbuffers::Offset<ConstantNodeDesc> CreateConstantNodeDesc(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    dml::ir::ConstantNodeDescDetail data_type = dml::ir::ConstantNodeDescDetail_NONE,
+    ::flatbuffers::Offset<void> data = 0) {
+  ConstantNodeDescBuilder builder_(_fbb);
+  builder_.add_data(data);
+  builder_.add_data_type(data_type);
+  return builder_.Finish();
+}
+
+struct DmlBufferTensorDesc FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+  typedef DmlBufferTensorDescBuilder Builder;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_DATATYPE = 4,
+    VT_SIZES = 6,
+    VT_STRIDES = 8,
+    VT_TOTALTENSORSIZEINBYTES = 10
+  };
+  const ::flatbuffers::String *dataType() const {
+    return GetPointer<const ::flatbuffers::String *>(VT_DATATYPE);
+  }
+  ::flatbuffers::String *mutable_dataType() {
+    return GetPointer<::flatbuffers::String *>(VT_DATATYPE);
+  }
+  const ::flatbuffers::Vector<uint32_t> *sizes() const {
+    return GetPointer<const ::flatbuffers::Vector<uint32_t> *>(VT_SIZES);
+  }
+  ::flatbuffers::Vector<uint32_t> *mutable_sizes() {
+    return GetPointer<::flatbuffers::Vector<uint32_t> *>(VT_SIZES);
+  }
+  const ::flatbuffers::Vector<uint32_t> *strides() const {
+    return GetPointer<const ::flatbuffers::Vector<uint32_t> *>(VT_STRIDES);
+  }
+  ::flatbuffers::Vector<uint32_t> *mutable_strides() {
+    return GetPointer<::flatbuffers::Vector<uint32_t> *>(VT_STRIDES);
+  }
+  uint64_t totalTensorSizeInBytes() const {
+    return GetField<uint64_t>(VT_TOTALTENSORSIZEINBYTES, 0);
+  }
+  bool mutate_totalTensorSizeInBytes(uint64_t _totalTensorSizeInBytes = 0) {
+    return SetField<uint64_t>(VT_TOTALTENSORSIZEINBYTES, _totalTensorSizeInBytes, 0);
+  }
+  bool Verify(::flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyOffset(verifier, VT_DATATYPE) &&
+           verifier.VerifyString(dataType()) &&
+           VerifyOffset(verifier, VT_SIZES) &&
+           verifier.VerifyVector(sizes()) &&
+           VerifyOffset(verifier, VT_STRIDES) &&
+           verifier.VerifyVector(strides()) &&
+           VerifyField<uint64_t>(verifier, VT_TOTALTENSORSIZEINBYTES, 8) &&
+           verifier.EndTable();
+  }
+};
+
+struct DmlBufferTensorDescBuilder {
+  typedef DmlBufferTensorDesc Table;
+  ::flatbuffers::FlatBufferBuilder &fbb_;
+  ::flatbuffers::uoffset_t start_;
+  void add_dataType(::flatbuffers::Offset<::flatbuffers::String> dataType) {
+    fbb_.AddOffset(DmlBufferTensorDesc::VT_DATATYPE, dataType);
+  }
+  void add_sizes(::flatbuffers::Offset<::flatbuffers::Vector<uint32_t>> sizes) {
+    fbb_.AddOffset(DmlBufferTensorDesc::VT_SIZES, sizes);
+  }
+  void add_strides(::flatbuffers::Offset<::flatbuffers::Vector<uint32_t>> strides) {
+    fbb_.AddOffset(DmlBufferTensorDesc::VT_STRIDES, strides);
+  }
+  void add_totalTensorSizeInBytes(uint64_t totalTensorSizeInBytes) {
+    fbb_.AddElement<uint64_t>(DmlBufferTensorDesc::VT_TOTALTENSORSIZEINBYTES, totalTensorSizeInBytes, 0);
+  }
+  explicit DmlBufferTensorDescBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ::flatbuffers::Offset<DmlBufferTensorDesc> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = ::flatbuffers::Offset<DmlBufferTensorDesc>(end);
+    return o;
+  }
+};
+
+inline ::flatbuffers::Offset<DmlBufferTensorDesc> CreateDmlBufferTensorDesc(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    ::flatbuffers::Offset<::flatbuffers::String> dataType = 0,
+    ::flatbuffers::Offset<::flatbuffers::Vector<uint32_t>> sizes = 0,
+    ::flatbuffers::Offset<::flatbuffers::Vector<uint32_t>> strides = 0,
+    uint64_t totalTensorSizeInBytes = 0) {
+  DmlBufferTensorDescBuilder builder_(_fbb);
+  builder_.add_totalTensorSizeInBytes(totalTensorSizeInBytes);
+  builder_.add_strides(strides);
+  builder_.add_sizes(sizes);
+  builder_.add_dataType(dataType);
+  return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<DmlBufferTensorDesc> CreateDmlBufferTensorDescDirect(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    const char *dataType = nullptr,
+    const std::vector<uint32_t> *sizes = nullptr,
+    const std::vector<uint32_t> *strides = nullptr,
+    uint64_t totalTensorSizeInBytes = 0) {
+  auto dataType__ = dataType ? _fbb.CreateString(dataType) : 0;
+  auto sizes__ = sizes ? _fbb.CreateVector<uint32_t>(*sizes) : 0;
+  auto strides__ = strides ? _fbb.CreateVector<uint32_t>(*strides) : 0;
+  return dml::ir::CreateDmlBufferTensorDesc(
+      _fbb,
+      dataType__,
+      sizes__,
+      strides__,
+      totalTensorSizeInBytes);
+}
+
+struct OperatorNodeDesc FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+  typedef OperatorNodeDescBuilder Builder;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_TYPE = 4,
+    VT_INPUTS = 6,
+    VT_OUTPUTS = 8,
+    VT_ATTRIBUTES = 10
+  };
+  const ::flatbuffers::String *type() const {
+    return GetPointer<const ::flatbuffers::String *>(VT_TYPE);
+  }
+  ::flatbuffers::String *mutable_type() {
+    return GetPointer<::flatbuffers::String *>(VT_TYPE);
+  }
+  const ::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::DmlBufferTensorDesc>> *inputs() const {
+    return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::DmlBufferTensorDesc>> *>(VT_INPUTS);
+  }
+  ::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::DmlBufferTensorDesc>> *mutable_inputs() {
+    return GetPointer<::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::DmlBufferTensorDesc>> *>(VT_INPUTS);
+  }
+  const ::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::DmlBufferTensorDesc>> *outputs() const {
+    return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::DmlBufferTensorDesc>> *>(VT_OUTPUTS);
+  }
+  ::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::DmlBufferTensorDesc>> *mutable_outputs() {
+    return GetPointer<::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::DmlBufferTensorDesc>> *>(VT_OUTPUTS);
+  }
+  const ::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::operatorFieldTypes::AttributeDesc>> *attributes() const {
+    return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::operatorFieldTypes::AttributeDesc>> *>(VT_ATTRIBUTES);
+  }
+  ::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::operatorFieldTypes::AttributeDesc>> *mutable_attributes() {
+    return GetPointer<::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::operatorFieldTypes::AttributeDesc>> *>(VT_ATTRIBUTES);
+  }
+  bool Verify(::flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyOffset(verifier, VT_TYPE) &&
+           verifier.VerifyString(type()) &&
+           VerifyOffset(verifier, VT_INPUTS) &&
+           verifier.VerifyVector(inputs()) &&
+           verifier.VerifyVectorOfTables(inputs()) &&
+           VerifyOffset(verifier, VT_OUTPUTS) &&
+           verifier.VerifyVector(outputs()) &&
+           verifier.VerifyVectorOfTables(outputs()) &&
+           VerifyOffset(verifier, VT_ATTRIBUTES) &&
+           verifier.VerifyVector(attributes()) &&
+           verifier.VerifyVectorOfTables(attributes()) &&
+           verifier.EndTable();
+  }
+};
+
+struct OperatorNodeDescBuilder {
+  typedef OperatorNodeDesc Table;
+  ::flatbuffers::FlatBufferBuilder &fbb_;
+  ::flatbuffers::uoffset_t start_;
+  void add_type(::flatbuffers::Offset<::flatbuffers::String> type) {
+    fbb_.AddOffset(OperatorNodeDesc::VT_TYPE, type);
+  }
+  void add_inputs(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::DmlBufferTensorDesc>>> inputs) {
+    fbb_.AddOffset(OperatorNodeDesc::VT_INPUTS, inputs);
+  }
+  void add_outputs(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::DmlBufferTensorDesc>>> outputs) {
+    fbb_.AddOffset(OperatorNodeDesc::VT_OUTPUTS, outputs);
+  }
+  void add_attributes(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::operatorFieldTypes::AttributeDesc>>> attributes) {
+    fbb_.AddOffset(OperatorNodeDesc::VT_ATTRIBUTES, attributes);
+  }
+  explicit OperatorNodeDescBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ::flatbuffers::Offset<OperatorNodeDesc> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = ::flatbuffers::Offset<OperatorNodeDesc>(end);
+    return o;
+  }
+};
+
+inline ::flatbuffers::Offset<OperatorNodeDesc> CreateOperatorNodeDesc(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    ::flatbuffers::Offset<::flatbuffers::String> type = 0,
+    ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::DmlBufferTensorDesc>>> inputs = 0,
+    ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::DmlBufferTensorDesc>>> outputs = 0,
+    ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::operatorFieldTypes::AttributeDesc>>> attributes = 0) {
+  OperatorNodeDescBuilder builder_(_fbb);
+  builder_.add_attributes(attributes);
+  builder_.add_outputs(outputs);
+  builder_.add_inputs(inputs);
+  builder_.add_type(type);
+  return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<OperatorNodeDesc> CreateOperatorNodeDescDirect(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    const char *type = nullptr,
+    const std::vector<::flatbuffers::Offset<dml::ir::DmlBufferTensorDesc>> *inputs = nullptr,
+    const std::vector<::flatbuffers::Offset<dml::ir::DmlBufferTensorDesc>> *outputs = nullptr,
+    const std::vector<::flatbuffers::Offset<dml::ir::operatorFieldTypes::AttributeDesc>> *attributes = nullptr) {
+  auto type__ = type ? _fbb.CreateString(type) : 0;
+  auto inputs__ = inputs ? _fbb.CreateVector<::flatbuffers::Offset<dml::ir::DmlBufferTensorDesc>>(*inputs) : 0;
+  auto outputs__ = outputs ? _fbb.CreateVector<::flatbuffers::Offset<dml::ir::DmlBufferTensorDesc>>(*outputs) : 0;
+  auto attributes__ = attributes ? _fbb.CreateVector<::flatbuffers::Offset<dml::ir::operatorFieldTypes::AttributeDesc>>(*attributes) : 0;
+  return dml::ir::CreateOperatorNodeDesc(
+      _fbb,
+      type__,
+      inputs__,
+      outputs__,
+      attributes__);
+}
+
+struct DmlGraphNode FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+  typedef DmlGraphNodeBuilder Builder;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_DESC_TYPE = 4,
+    VT_DESC = 6,
+    VT_NAME = 8,
+    VT_INPUTNAMES = 10,
+    VT_OUTPUTNAMES = 12
+  };
+  dml::ir::NodeDesc desc_type() const {
+    return static_cast<dml::ir::NodeDesc>(GetField<uint8_t>(VT_DESC_TYPE, 0));
+  }
+  const void *desc() const {
+    return GetPointer<const void *>(VT_DESC);
+  }
+  template<typename T> const T *desc_as() const;
+  const dml::ir::OperatorNodeDesc *desc_as_OperatorNodeDesc() const {
+    return desc_type() == dml::ir::NodeDesc_OperatorNodeDesc ? static_cast<const dml::ir::OperatorNodeDesc *>(desc()) : nullptr;
+  }
+  const dml::ir::ConstantNodeDesc *desc_as_ConstantNodeDesc() const {
+    return desc_type() == dml::ir::NodeDesc_ConstantNodeDesc ? static_cast<const dml::ir::ConstantNodeDesc *>(desc()) : nullptr;
+  }
+  void *mutable_desc() {
+    return GetPointer<void *>(VT_DESC);
+  }
+  const ::flatbuffers::String *name() const {
+    return GetPointer<const ::flatbuffers::String *>(VT_NAME);
+  }
+  ::flatbuffers::String *mutable_name() {
+    return GetPointer<::flatbuffers::String *>(VT_NAME);
+  }
+  const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *inputNames() const {
+    return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *>(VT_INPUTNAMES);
+  }
+  ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *mutable_inputNames() {
+    return GetPointer<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *>(VT_INPUTNAMES);
+  }
+  const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *outputNames() const {
+    return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *>(VT_OUTPUTNAMES);
+  }
+  ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *mutable_outputNames() {
+    return GetPointer<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *>(VT_OUTPUTNAMES);
+  }
+  bool Verify(::flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyField<uint8_t>(verifier, VT_DESC_TYPE, 1) &&
+           VerifyOffset(verifier, VT_DESC) &&
+           VerifyNodeDesc(verifier, desc(), desc_type()) &&
+           VerifyOffset(verifier, VT_NAME) &&
+           verifier.VerifyString(name()) &&
+           VerifyOffset(verifier, VT_INPUTNAMES) &&
+           verifier.VerifyVector(inputNames()) &&
+           verifier.VerifyVectorOfStrings(inputNames()) &&
+           VerifyOffset(verifier, VT_OUTPUTNAMES) &&
+           verifier.VerifyVector(outputNames()) &&
+           verifier.VerifyVectorOfStrings(outputNames()) &&
+           verifier.EndTable();
+  }
+};
+
+template<> inline const dml::ir::OperatorNodeDesc *DmlGraphNode::desc_as<dml::ir::OperatorNodeDesc>() const {
+  return desc_as_OperatorNodeDesc();
+}
+
+template<> inline const dml::ir::ConstantNodeDesc *DmlGraphNode::desc_as<dml::ir::ConstantNodeDesc>() const {
+  return desc_as_ConstantNodeDesc();
+}
+
+struct DmlGraphNodeBuilder {
+  typedef DmlGraphNode Table;
+  ::flatbuffers::FlatBufferBuilder &fbb_;
+  ::flatbuffers::uoffset_t start_;
+  void add_desc_type(dml::ir::NodeDesc desc_type) {
+    fbb_.AddElement<uint8_t>(DmlGraphNode::VT_DESC_TYPE, static_cast<uint8_t>(desc_type), 0);
+  }
+  void add_desc(::flatbuffers::Offset<void> desc) {
+    fbb_.AddOffset(DmlGraphNode::VT_DESC, desc);
+  }
+  void add_name(::flatbuffers::Offset<::flatbuffers::String> name) {
+    fbb_.AddOffset(DmlGraphNode::VT_NAME, name);
+  }
+  void add_inputNames(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> inputNames) {
+    fbb_.AddOffset(DmlGraphNode::VT_INPUTNAMES, inputNames);
+  }
+  void add_outputNames(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> outputNames) {
+    fbb_.AddOffset(DmlGraphNode::VT_OUTPUTNAMES, outputNames);
+  }
+  explicit DmlGraphNodeBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ::flatbuffers::Offset<DmlGraphNode> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = ::flatbuffers::Offset<DmlGraphNode>(end);
+    return o;
+  }
+};
+
+inline ::flatbuffers::Offset<DmlGraphNode> CreateDmlGraphNode(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    dml::ir::NodeDesc desc_type = dml::ir::NodeDesc_NONE,
+    ::flatbuffers::Offset<void> desc = 0,
+    ::flatbuffers::Offset<::flatbuffers::String> name = 0,
+    ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> inputNames = 0,
+    ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> outputNames = 0) {
+  DmlGraphNodeBuilder builder_(_fbb);
+  builder_.add_outputNames(outputNames);
+  builder_.add_inputNames(inputNames);
+  builder_.add_name(name);
+  builder_.add_desc(desc);
+  builder_.add_desc_type(desc_type);
+  return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<DmlGraphNode> CreateDmlGraphNodeDirect(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    dml::ir::NodeDesc desc_type = dml::ir::NodeDesc_NONE,
+    ::flatbuffers::Offset<void> desc = 0,
+    const char *name = nullptr,
+    const std::vector<::flatbuffers::Offset<::flatbuffers::String>> *inputNames = nullptr,
+    const std::vector<::flatbuffers::Offset<::flatbuffers::String>> *outputNames = nullptr) {
+  auto name__ = name ? _fbb.CreateString(name) : 0;
+  auto inputNames__ = inputNames ? _fbb.CreateVector<::flatbuffers::Offset<::flatbuffers::String>>(*inputNames) : 0;
+  auto outputNames__ = outputNames ? _fbb.CreateVector<::flatbuffers::Offset<::flatbuffers::String>>(*outputNames) : 0;
+  return dml::ir::CreateDmlGraphNode(
+      _fbb,
+      desc_type,
+      desc,
+      name__,
+      inputNames__,
+      outputNames__);
+}
+
+struct DmlGraphDesc FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+  typedef DmlGraphDescBuilder Builder;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_NODES = 4,
+    VT_GRAPHINPUTNAMES = 6,
+    VT_GRAPHOUTPUTNAMES = 8
+  };
+  const ::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::DmlGraphNode>> *nodes() const {
+    return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::DmlGraphNode>> *>(VT_NODES);
+  }
+  ::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::DmlGraphNode>> *mutable_nodes() {
+    return GetPointer<::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::DmlGraphNode>> *>(VT_NODES);
+  }
+  const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *graphInputNames() const {
+    return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *>(VT_GRAPHINPUTNAMES);
+  }
+  ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *mutable_graphInputNames() {
+    return GetPointer<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *>(VT_GRAPHINPUTNAMES);
+  }
+  const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *graphOutputNames() const {
+    return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *>(VT_GRAPHOUTPUTNAMES);
+  }
+  ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *mutable_graphOutputNames() {
+    return GetPointer<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>> *>(VT_GRAPHOUTPUTNAMES);
+  }
+  bool Verify(::flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyOffset(verifier, VT_NODES) &&
+           verifier.VerifyVector(nodes()) &&
+           verifier.VerifyVectorOfTables(nodes()) &&
+           VerifyOffset(verifier, VT_GRAPHINPUTNAMES) &&
+           verifier.VerifyVector(graphInputNames()) &&
+           verifier.VerifyVectorOfStrings(graphInputNames()) &&
+           VerifyOffset(verifier, VT_GRAPHOUTPUTNAMES) &&
+           verifier.VerifyVector(graphOutputNames()) &&
+           verifier.VerifyVectorOfStrings(graphOutputNames()) &&
+           verifier.EndTable();
+  }
+};
+
+struct DmlGraphDescBuilder {
+  typedef DmlGraphDesc Table;
+  ::flatbuffers::FlatBufferBuilder &fbb_;
+  ::flatbuffers::uoffset_t start_;
+  void add_nodes(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::DmlGraphNode>>> nodes) {
+    fbb_.AddOffset(DmlGraphDesc::VT_NODES, nodes);
+  }
+  void add_graphInputNames(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> graphInputNames) {
+    fbb_.AddOffset(DmlGraphDesc::VT_GRAPHINPUTNAMES, graphInputNames);
+  }
+  void add_graphOutputNames(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> graphOutputNames) {
+    fbb_.AddOffset(DmlGraphDesc::VT_GRAPHOUTPUTNAMES, graphOutputNames);
+  }
+  explicit DmlGraphDescBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ::flatbuffers::Offset<DmlGraphDesc> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = ::flatbuffers::Offset<DmlGraphDesc>(end);
+    return o;
+  }
+};
+
+inline ::flatbuffers::Offset<DmlGraphDesc> CreateDmlGraphDesc(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::DmlGraphNode>>> nodes = 0,
+    ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> graphInputNames = 0,
+    ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> graphOutputNames = 0) {
+  DmlGraphDescBuilder builder_(_fbb);
+  builder_.add_graphOutputNames(graphOutputNames);
+  builder_.add_graphInputNames(graphInputNames);
+  builder_.add_nodes(nodes);
+  return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<DmlGraphDesc> CreateDmlGraphDescDirect(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    const std::vector<::flatbuffers::Offset<dml::ir::DmlGraphNode>> *nodes = nullptr,
+    const std::vector<::flatbuffers::Offset<::flatbuffers::String>> *graphInputNames = nullptr,
+    const std::vector<::flatbuffers::Offset<::flatbuffers::String>> *graphOutputNames = nullptr) {
+  auto nodes__ = nodes ? _fbb.CreateVector<::flatbuffers::Offset<dml::ir::DmlGraphNode>>(*nodes) : 0;
+  auto graphInputNames__ = graphInputNames ? _fbb.CreateVector<::flatbuffers::Offset<::flatbuffers::String>>(*graphInputNames) : 0;
+  auto graphOutputNames__ = graphOutputNames ? _fbb.CreateVector<::flatbuffers::Offset<::flatbuffers::String>>(*graphOutputNames) : 0;
+  return dml::ir::CreateDmlGraphDesc(
+      _fbb,
+      nodes__,
+      graphInputNames__,
+      graphOutputNames__);
+}
+
+inline bool VerifyConstantNodeDescDetail(::flatbuffers::Verifier &verifier, const void *obj, ConstantNodeDescDetail type) {
+  switch (type) {
+    case ConstantNodeDescDetail_NONE: {
+      return true;
+    }
+    case ConstantNodeDescDetail_ConstantName: {
+      auto ptr = reinterpret_cast<const dml::ir::ConstantName *>(obj);
+      return verifier.VerifyTable(ptr);
+    }
+    case ConstantNodeDescDetail_ConstantRawData: {
+      auto ptr = reinterpret_cast<const dml::ir::ConstantRawData *>(obj);
+      return verifier.VerifyTable(ptr);
+    }
+    default: return true;
+  }
+}
+
+inline bool VerifyConstantNodeDescDetailVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset<void>> *values, const ::flatbuffers::Vector<uint8_t> *types) {
+  if (!values || !types) return !values && !types;
+  if (values->size() != types->size()) return false;
+  for (::flatbuffers::uoffset_t i = 0; i < values->size(); ++i) {
+    if (!VerifyConstantNodeDescDetail(
+        verifier,  values->Get(i), types->GetEnum<ConstantNodeDescDetail>(i))) {
+      return false;
+    }
+  }
+  return true;
+}
+
+inline bool VerifyNodeDesc(::flatbuffers::Verifier &verifier, const void *obj, NodeDesc type) {
+  switch (type) {
+    case NodeDesc_NONE: {
+      return true;
+    }
+    case NodeDesc_OperatorNodeDesc: {
+      auto ptr = reinterpret_cast<const dml::ir::OperatorNodeDesc *>(obj);
+      return verifier.VerifyTable(ptr);
+    }
+    case NodeDesc_ConstantNodeDesc: {
+      auto ptr = reinterpret_cast<const dml::ir::ConstantNodeDesc *>(obj);
+      return verifier.VerifyTable(ptr);
+    }
+    default: return true;
+  }
+}
+
+inline bool VerifyNodeDescVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset<void>> *values, const ::flatbuffers::Vector<uint8_t> *types) {
+  if (!values || !types) return !values && !types;
+  if (values->size() != types->size()) return false;
+  for (::flatbuffers::uoffset_t i = 0; i < values->size(); ++i) {
+    if (!VerifyNodeDesc(
+        verifier,  values->Get(i), types->GetEnum<NodeDesc>(i))) {
+      return false;
+    }
+  }
+  return true;
+}
+
+inline const dml::ir::DmlGraphDesc *GetDmlGraphDesc(const void *buf) {
+  return ::flatbuffers::GetRoot<dml::ir::DmlGraphDesc>(buf);
+}
+
+inline const dml::ir::DmlGraphDesc *GetSizePrefixedDmlGraphDesc(const void *buf) {
+  return ::flatbuffers::GetSizePrefixedRoot<dml::ir::DmlGraphDesc>(buf);
+}
+
+inline DmlGraphDesc *GetMutableDmlGraphDesc(void *buf) {
+  return ::flatbuffers::GetMutableRoot<DmlGraphDesc>(buf);
+}
+
+inline dml::ir::DmlGraphDesc *GetMutableSizePrefixedDmlGraphDesc(void *buf) {
+  return ::flatbuffers::GetMutableSizePrefixedRoot<dml::ir::DmlGraphDesc>(buf);
+}
+
+inline bool VerifyDmlGraphDescBuffer(
+    ::flatbuffers::Verifier &verifier) {
+  return verifier.VerifyBuffer<dml::ir::DmlGraphDesc>(nullptr);
+}
+
+inline bool VerifySizePrefixedDmlGraphDescBuffer(
+    ::flatbuffers::Verifier &verifier) {
+  return verifier.VerifySizePrefixedBuffer<dml::ir::DmlGraphDesc>(nullptr);
+}
+
+inline void FinishDmlGraphDescBuffer(
+    ::flatbuffers::FlatBufferBuilder &fbb,
+    ::flatbuffers::Offset<dml::ir::DmlGraphDesc> root) {
+  fbb.Finish(root);
+}
+
+inline void FinishSizePrefixedDmlGraphDescBuffer(
+    ::flatbuffers::FlatBufferBuilder &fbb,
+    ::flatbuffers::Offset<dml::ir::DmlGraphDesc> root) {
+  fbb.FinishSizePrefixed(root);
+}
+
+}  // namespace ir
+}  // namespace dml
+
+#endif  // FLATBUFFERS_GENERATED_DMLGRAPHDESC_DML_IR_H_
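For orientation, here is a minimal sketch of how the generated dml::ir API above can be exercised: build a one-node graph with the *Direct helpers, finish the buffer, then verify and read it back. Everything in it (the node/tensor names, the "FLOAT32" dataType string, the operator type string) is illustrative and not taken from this change; it only assumes the generated header and the vendored flatbuffers.

```cpp
// Illustrative only: builds a one-node dml::ir::DmlGraphDesc with the
// generated API from DmlGraphDesc_generated.h, then verifies and reads it back.
#include "DmlGraphDesc_generated.h"

#include <cstdint>
#include <vector>

inline std::vector<uint8_t> BuildExampleGraphBlob()
{
    ::flatbuffers::FlatBufferBuilder fbb;

    // Tensor descs for the single operator node (names and values are hypothetical).
    std::vector<uint32_t> sizes = {1, 1, 2, 2};
    auto inputTensor = dml::ir::CreateDmlBufferTensorDescDirect(
        fbb, "FLOAT32", &sizes, /*strides*/ nullptr, /*totalTensorSizeInBytes*/ 16);
    auto outputTensor = dml::ir::CreateDmlBufferTensorDescDirect(
        fbb, "FLOAT32", &sizes, nullptr, 16);

    std::vector<::flatbuffers::Offset<dml::ir::DmlBufferTensorDesc>> inputs = {inputTensor};
    std::vector<::flatbuffers::Offset<dml::ir::DmlBufferTensorDesc>> outputs = {outputTensor};
    auto opDesc = dml::ir::CreateOperatorNodeDescDirect(
        fbb, "DML_OPERATOR_ACTIVATION_SWISH", &inputs, &outputs, /*attributes*/ nullptr);

    std::vector<::flatbuffers::Offset<::flatbuffers::String>> nodeInputs = {fbb.CreateString("X")};
    std::vector<::flatbuffers::Offset<::flatbuffers::String>> nodeOutputs = {fbb.CreateString("Y")};
    auto node = dml::ir::CreateDmlGraphNodeDirect(
        fbb,
        dml::ir::NodeDesc_OperatorNodeDesc,
        opDesc.Union(),
        "swish0",
        &nodeInputs,
        &nodeOutputs);

    std::vector<::flatbuffers::Offset<dml::ir::DmlGraphNode>> nodes = {node};
    std::vector<::flatbuffers::Offset<::flatbuffers::String>> graphInputs = {fbb.CreateString("X")};
    std::vector<::flatbuffers::Offset<::flatbuffers::String>> graphOutputs = {fbb.CreateString("Y")};
    auto graph = dml::ir::CreateDmlGraphDescDirect(fbb, &nodes, &graphInputs, &graphOutputs);

    dml::ir::FinishDmlGraphDescBuffer(fbb, graph);
    return std::vector<uint8_t>(fbb.GetBufferPointer(), fbb.GetBufferPointer() + fbb.GetSize());
}

inline bool ReadBackExample(const std::vector<uint8_t>& blob)
{
    ::flatbuffers::Verifier verifier(blob.data(), blob.size());
    if (!dml::ir::VerifyDmlGraphDescBuffer(verifier))
    {
        return false;
    }
    const dml::ir::DmlGraphDesc* graph = dml::ir::GetDmlGraphDesc(blob.data());
    return graph->nodes() != nullptr && graph->nodes()->size() == 1;
}
```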
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDeserialization.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDeserialization.h
new file mode 100644
index 0000000000000..9decf0dce1bb2
--- /dev/null
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDeserialization.h
@@ -0,0 +1,14 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+
+#pragma once
+#include "DmlSerializedGraphDesc.h"
+
+struct NodeIndex
+{
+    uint32_t nodeIndex;
+    uint32_t nodeOutputIndex;
+};
+
+DmlSerializedGraphDesc DeserializeDmlGraph(
+    const uint8_t* flatbufferGraphDescBlob,
+    /*out*/ std::vector<std::unique_ptr<std::byte[]>>& rawData);
\ No newline at end of file
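A hedged sketch of a call site for the declared entry point follows; the helper function and its surrounding logic are illustrative only, and it assumes, as the header itself does, that the provider's usual precompiled headers supply AbstractOperatorDesc and the standard-library types used by DmlSerializedGraphDesc.h.

```cpp
// Illustrative call site (not part of this change).
#include "DmlGraphDeserialization.h"

#include <cstddef>
#include <cstdint>
#include <memory>
#include <variant>
#include <vector>

inline size_t CountOperatorNodes(const std::vector<uint8_t>& flatbufferBlob)
{
    // Owns any constant-data buffers surfaced during deserialization; it must
    // stay alive for as long as the returned graph desc is in use.
    std::vector<std::unique_ptr<std::byte[]>> rawData;
    DmlSerializedGraphDesc graph = DeserializeDmlGraph(flatbufferBlob.data(), rawData);

    size_t operatorNodes = 0;
    for (const DmlSerializedGraphNode& node : graph.Nodes)
    {
        if (std::holds_alternative<AbstractOperatorDesc>(node.Desc))
        {
            ++operatorNodes;
        }
    }
    return operatorNodes;
}
```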
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphSerialization.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphSerialization.h
new file mode 100644
index 0000000000000..d8d069da906b7
--- /dev/null
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphSerialization.h
@@ -0,0 +1,8 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+
+#pragma once
+#include "DmlGraphDesc_generated.h"
+
+struct DmlSerializedGraphDesc;
+
+flatbuffers::DetachedBuffer SerializeDmlGraph(const DmlSerializedGraphDesc& graphDesc);
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlSerializedGraphDesc.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlSerializedGraphDesc.h
new file mode 100644
index 0000000000000..51c3d6c81244b
--- /dev/null
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlSerializedGraphDesc.h
@@ -0,0 +1,73 @@
+//-----------------------------------------------------------------------------
+//
+//  Copyright (c) Microsoft Corporation. All rights reserved.
+//
+//-----------------------------------------------------------------------------
+
+#pragma once
+
+struct ConstantName
+{
+    std::string name;
+};
+
+struct ConstantData
+{
+    std::byte* data;
+    uint64_t dataSize;
+};
+
+using DmlSerializedGraphNodeConstantVariant = std::variant<
+    ConstantName,
+    ConstantData
+>;
+
+using DmlSerializedGraphNodeDescVariant = std::variant<
+    AbstractOperatorDesc,
+    DmlSerializedGraphNodeConstantVariant
+>;
+
+struct DmlSerializedGraphNode
+{
+    DmlSerializedGraphNodeDescVariant Desc;
+    std::string Name;
+};
+
+struct DmlInputSerializedGraphEdge
+{
+    uint32_t GraphInputIndex;
+    uint32_t ToNodeIndex;
+    uint32_t ToNodeInputIndex;
+    std::string Name;
+};
+
+struct DmlOutputSerializedGraphEdge
+{
+    uint32_t FromNodeIndex;
+    uint32_t FromNodeOutputIndex;
+    uint32_t GraphOutputIndex;
+    std::string Name;
+};
+
+struct DmlIntermediateSerializedGraphEdge
+{
+    uint32_t FromNodeIndex;
+    uint32_t FromNodeOutputIndex;
+    uint32_t ToNodeIndex;
+    uint32_t ToNodeInputIndex;
+    std::string Name;
+};
+
+struct DmlSerializedGraphDesc
+{
+    uint32_t InputCount;
+    uint32_t OutputCount;
+    // Nodes must be stored in topological order for deserialization to work:
+    // when an intermediate edge is created during deserialization, the node that
+    // produces the edge must already have been visited before the node that
+    // consumes it.
+    std::vector<DmlSerializedGraphNode> Nodes;
+    std::vector<DmlInputSerializedGraphEdge> InputEdges;
+    std::vector<DmlOutputSerializedGraphEdge> OutputEdges;
+    std::vector<DmlIntermediateSerializedGraphEdge> IntermediateEdges;
+};
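The structs above are plain aggregates, so a serializer round trip can be sketched directly. The sketch below is purely structural (a single named-constant node feeding one graph output, with hypothetical names); whether such a graph is semantically meaningful to the runtime is beside the point. It assumes the entry points declared in DmlGraphSerialization.h and DmlGraphDeserialization.h plus the provider's usual precompiled headers.

```cpp
// Illustrative only: one named-constant node wired to one graph output,
// round-tripped through SerializeDmlGraph/DeserializeDmlGraph.
#include "DmlGraphDeserialization.h"
#include "DmlGraphSerialization.h"

#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

inline void RoundTripSketch()
{
    DmlSerializedGraphDesc desc = {};
    desc.InputCount = 0;
    desc.OutputCount = 1;

    // Constant referenced by name; ConstantData could be used instead to embed raw bytes.
    DmlSerializedGraphNode constantNode{
        DmlSerializedGraphNodeConstantVariant{ConstantName{"weights_blob"}},
        "weights"};
    desc.Nodes.push_back(std::move(constantNode));

    DmlOutputSerializedGraphEdge outputEdge = {};
    outputEdge.FromNodeIndex = 0;
    outputEdge.FromNodeOutputIndex = 0;
    outputEdge.GraphOutputIndex = 0;
    outputEdge.Name = "weights";
    desc.OutputEdges.push_back(outputEdge);

    flatbuffers::DetachedBuffer blob = SerializeDmlGraph(desc);

    std::vector<std::unique_ptr<std::byte[]>> rawData;
    DmlSerializedGraphDesc roundTripped = DeserializeDmlGraph(blob.data(), rawData);
    (void)roundTripped;
}
```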
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h
index 99218c135f058..86c66d8cca26c 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h
@@ -425,7 +425,6 @@ inline std::vector<OperatorField> GetFields(const DML_AVERAGE_POOLING_OPERATOR_D
         OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<UINT>(desc.IncludePadding))),
     };
 }
-
 inline std::vector<OperatorField> GetFields(const DML_AVERAGE_POOLING1_OPERATOR_DESC& desc)
 {
     return {
@@ -502,24 +501,6 @@ inline std::vector<OperatorField> GetFields(const DML_ROI_POOLING_OPERATOR_DESC&
         OperatorField(&DML_ROI_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<DML_SIZE_2D>(desc.PooledSize))),
     };
 }
-inline std::vector<OperatorField> GetFields(const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC& desc)
-{
-    return {
-        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
-        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputScaleTensor))),
-        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputZeroPointTensor))),
-        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputScaleTensor))),
-        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputZeroPointTensor))),
-        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
-        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<UINT>(desc.DimensionCount))),
-        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const UINT*>(desc.Strides), desc.DimensionCount)),
-        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<const UINT*>(desc.WindowSize), desc.DimensionCount)),
-        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast<const UINT*>(desc.StartPadding), desc.DimensionCount)),
-        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast<const UINT*>(desc.EndPadding), desc.DimensionCount)),
-        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast<const UINT*>(desc.Dilations), desc.DimensionCount)),
-        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast<UINT>(desc.IncludePadding))),
-    };
-}
 inline std::vector<OperatorField> GetFields(const DML_SLICE_OPERATOR_DESC& desc)
 {
     return {
@@ -1158,6 +1139,19 @@ inline std::vector<OperatorField> GetFields(const DML_QUANTIZED_LINEAR_MATRIX_MU
         OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
     };
 }
+inline std::vector<OperatorField> GetFields(const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC& desc)
+{
+    return {
+        OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
+        OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.AScaleTensor))),
+        OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.AZeroPointTensor))),
+        OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
+        OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BScaleTensor))),
+        OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BZeroPointTensor))),
+        OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BiasTensor))),
+        OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
+    };
+}
 inline std::vector<OperatorField> GetFields(const DML_CONVOLUTION_INTEGER_OPERATOR_DESC& desc)
 {
     return {
@@ -1488,6 +1482,24 @@ inline std::vector<OperatorField> GetFields(const DML_MULTIHEAD_ATTENTION_OPERAT
         OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[17], ToOperatorFieldType(static_cast<UINT>(desc.MaskType))),
     };
 }
+inline std::vector<OperatorField> GetFields(const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC& desc)
+{
+    return {
+        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
+        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputScaleTensor))),
+        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputZeroPointTensor))),
+        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputScaleTensor))),
+        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputZeroPointTensor))),
+        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
+        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<UINT>(desc.DimensionCount))),
+        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const UINT*>(desc.Strides), desc.DimensionCount)),
+        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<const UINT*>(desc.WindowSize), desc.DimensionCount)),
+        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast<const UINT*>(desc.StartPadding), desc.DimensionCount)),
+        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast<const UINT*>(desc.EndPadding), desc.DimensionCount)),
+        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast<const UINT*>(desc.Dilations), desc.DimensionCount)),
+        OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast<UINT>(desc.IncludePadding))),
+    };
+}
 inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc)
 {
     return {
@@ -1680,6 +1692,23 @@ inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_GELU_OPERATOR_D
         OperatorField(&DML_ACTIVATION_GELU_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
     };
 }
+inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_SWISH_OPERATOR_DESC& desc)
+{
+    return {
+        OperatorField(&DML_ACTIVATION_SWISH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
+        OperatorField(&DML_ACTIVATION_SWISH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
+        OperatorField(&DML_ACTIVATION_SWISH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<FLOAT>(desc.SigmoidInputScale))),
+    };
+}
+inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_HARD_SWISH_OPERATOR_DESC& desc)
+{
+    return {
+        OperatorField(&DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
+        OperatorField(&DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
+        OperatorField(&DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<FLOAT>(desc.Alpha))),
+        OperatorField(&DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<FLOAT>(desc.Beta))),
+    };
+}
 inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
 {
     switch (operatorType)
@@ -1800,6 +1829,7 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
     case DML_OPERATOR_RESAMPLE1: return DML_RESAMPLE1_OPERATOR_SCHEMA;
     case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER: return DML_MATRIX_MULTIPLY_INTEGER_OPERATOR_SCHEMA;
     case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY: return DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA;
+    case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA;
     case DML_OPERATOR_CONVOLUTION_INTEGER: return DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA;
     case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: return DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA;
     case DML_OPERATOR_ELEMENT_WISE_BIT_AND: return DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA;
@@ -1826,6 +1856,7 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
     case DML_OPERATOR_RESAMPLE_GRAD1: return DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA;
     case DML_OPERATOR_DIAGONAL_MATRIX1: return DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA;
     case DML_OPERATOR_MULTIHEAD_ATTENTION: return DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA;
+    case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: return DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA;
     case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA;
     case DML_OPERATOR_ACTIVATION_CELU: return DML_ACTIVATION_CELU_OPERATOR_SCHEMA;
     case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA;
@@ -1850,6 +1881,8 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
     case DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU: return DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_SCHEMA;
     case DML_OPERATOR_ACTIVATION_SHRINK: return DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA;
     case DML_OPERATOR_ACTIVATION_GELU: return DML_ACTIVATION_GELU_OPERATOR_SCHEMA;
+    case DML_OPERATOR_ACTIVATION_SWISH: return DML_ACTIVATION_SWISH_OPERATOR_SCHEMA;
+    case DML_OPERATOR_ACTIVATION_HARD_SWISH: return DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA;
 
     default:
         ORT_THROW_HR(E_INVALIDARG);
@@ -2327,6 +2360,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
         return AbstractOperatorDesc(
             &DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA,
             GetFields(*static_cast<const DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC*>(opDesc.Desc)));
+    case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT:
+        return AbstractOperatorDesc(
+            &DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA,
+            GetFields(*static_cast<const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC*>(opDesc.Desc)));
     case DML_OPERATOR_CONVOLUTION_INTEGER:
         return AbstractOperatorDesc(
             &DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA,
@@ -2431,6 +2468,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
         return AbstractOperatorDesc(
             &DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA,
             GetFields(*static_cast<const DML_MULTIHEAD_ATTENTION_OPERATOR_DESC*>(opDesc.Desc)));
+    case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING:
+        return AbstractOperatorDesc(
+            &DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA,
+            GetFields(*static_cast<const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC*>(opDesc.Desc)));
     case DML_OPERATOR_ACTIVATION_ELU:
         return AbstractOperatorDesc(
             &DML_ACTIVATION_ELU_OPERATOR_SCHEMA,
@@ -2527,13 +2568,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
         return AbstractOperatorDesc(
             &DML_ACTIVATION_GELU_OPERATOR_SCHEMA,
             GetFields(*static_cast<const DML_ACTIVATION_GELU_OPERATOR_DESC*>(opDesc.Desc)));
-#pragma warning(push)
-#pragma warning(disable: 4063)
-    case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING:
+    case DML_OPERATOR_ACTIVATION_SWISH:
         return AbstractOperatorDesc(
-            &DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA,
-            GetFields(*static_cast<const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC*>(opDesc.Desc)));
-#pragma warning(pop)
+            &DML_ACTIVATION_SWISH_OPERATOR_SCHEMA,
+            GetFields(*static_cast<const DML_ACTIVATION_SWISH_OPERATOR_DESC*>(opDesc.Desc)));
+    case DML_OPERATOR_ACTIVATION_HARD_SWISH:
+        return AbstractOperatorDesc(
+            &DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA,
+            GetFields(*static_cast<const DML_ACTIVATION_HARD_SWISH_OPERATOR_DESC*>(opDesc.Desc)));
 
     default:
         ORT_THROW_HR(E_INVALIDARG);
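For completeness, converting one of the newly handled activations through this helper looks like the sketch below. The wrapper function name and the tensor descs are illustrative, and the DirectML structs are assumed to come from the same headers the provider already uses.

```cpp
// Illustrative only: route a swish desc through the new switch case above.
inline AbstractOperatorDesc MakeSwishAbstractDesc(
    const DML_TENSOR_DESC& inputTensorDesc,
    const DML_TENSOR_DESC& outputTensorDesc,
    float sigmoidInputScale)
{
    DML_ACTIVATION_SWISH_OPERATOR_DESC swishDesc = {};
    swishDesc.InputTensor = &inputTensorDesc;
    swishDesc.OutputTensor = &outputTensorDesc;
    swishDesc.SigmoidInputScale = sigmoidInputScale;

    DML_OPERATOR_DESC opDesc = {};
    opDesc.Type = DML_OPERATOR_ACTIVATION_SWISH;
    opDesc.Desc = &swishDesc;

    // Dispatches to GetFields(const DML_ACTIVATION_SWISH_OPERATOR_DESC&) via the
    // DML_OPERATOR_ACTIVATION_SWISH case added to ConvertOperatorDesc.
    return ConvertOperatorDesc(opDesc);
}
```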
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h
index 25f0dd26c6067..a94bb67b68d36 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h
@@ -15,32 +15,34 @@ using ApiAttributeVariant = std::variant<
     const FLOAT*, 
     const DML_SCALE_BIAS*, 
     DML_SIZE_2D, 
-    DML_SCALAR_UNION
+    DML_SCALAR_UNION, 
+    BOOL
     >;
 
 namespace OperatorFieldTypes
 {
     using TensorDesc = std::optional<DmlBufferTensorDesc>; // DML_SCHEMA_FIELD_TYPE_TENSOR_DESC
     using TensorDescArray = std::optional<std::vector<DmlBufferTensorDesc>>; // DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY
-    using OperatorDesc = std::optional<AbstractOperatorDesc>; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC
-    using OperatorDescArray = std::optional<std::vector<AbstractOperatorDesc>>; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY
+    using FusedActivationOperatorDesc = std::optional<AbstractOperatorDesc>; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC
+    using FusedActivationOperatorDescArray = std::optional<std::vector<AbstractOperatorDesc>>; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY
     using UInt = uint32_t; // DML_SCHEMA_FIELD_TYPE_UINT
     using UInt64 = uint64_t; // DML_SCHEMA_FIELD_TYPE_UINT64
     using Int = int32_t; // DML_SCHEMA_FIELD_TYPE_INT
     using Float = float; // DML_SCHEMA_FIELD_TYPE_FLOAT
-    using UIntArray = std::optional<std::vector<uint32_t>>; // DML_SCHEMA_FIELD_TYPE_UINT_ARRAY
-    using IntArray = std::optional<std::vector<int32_t>>; // DML_SCHEMA_FIELD_TYPE_INT_ARRAY
-    using FloatArray = std::optional<std::vector<float>>; // DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY
+    using UIntArray = std::vector<uint32_t>; // DML_SCHEMA_FIELD_TYPE_UINT_ARRAY
+    using IntArray = std::vector<int32_t>; // DML_SCHEMA_FIELD_TYPE_INT_ARRAY
+    using FloatArray = std::vector<float>; // DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY
     using ScaleBias = std::optional<DML_SCALE_BIAS>; // DML_SCHEMA_FIELD_TYPE_SCALE_BIAS
     using Size2D = DML_SIZE_2D; // DML_SCHEMA_FIELD_TYPE_SIZE_2D
     using ScalarUnion = DML_SCALAR_UNION; // DML_SCHEMA_FIELD_TYPE_SCALAR_UNION
+    using Bool = bool; // DML_SCHEMA_FIELD_TYPE_BOOL
 }
 
 using OperatorFieldVariant = std::variant<
     OperatorFieldTypes::TensorDesc, 
     OperatorFieldTypes::TensorDescArray, 
-    OperatorFieldTypes::OperatorDesc, 
-    OperatorFieldTypes::OperatorDescArray, 
+    OperatorFieldTypes::FusedActivationOperatorDesc, 
+    OperatorFieldTypes::FusedActivationOperatorDescArray, 
     OperatorFieldTypes::UInt, 
     OperatorFieldTypes::UInt64, 
     OperatorFieldTypes::Int, 
@@ -50,7 +52,8 @@ using OperatorFieldVariant = std::variant<
     OperatorFieldTypes::FloatArray, 
     OperatorFieldTypes::ScaleBias, 
     OperatorFieldTypes::Size2D, 
-    OperatorFieldTypes::ScalarUnion
+    OperatorFieldTypes::ScalarUnion, 
+    OperatorFieldTypes::Bool
     >;
 
 class OperatorField
@@ -80,11 +83,11 @@ class OperatorField
     const OperatorFieldTypes::TensorDescArray& AsTensorDescArray() const { return std::get<OperatorFieldTypes::TensorDescArray>(m_data); }
     OperatorFieldTypes::TensorDescArray& AsTensorDescArray() { return std::get<OperatorFieldTypes::TensorDescArray>(m_data); }
 
-    const OperatorFieldTypes::OperatorDesc& AsOperatorDesc() const { return std::get<OperatorFieldTypes::OperatorDesc>(m_data); }
-    OperatorFieldTypes::OperatorDesc& AsOperatorDesc() { return std::get<OperatorFieldTypes::OperatorDesc>(m_data); }
+    const OperatorFieldTypes::FusedActivationOperatorDesc& AsFusedActivationOperatorDesc() const { return std::get<OperatorFieldTypes::FusedActivationOperatorDesc>(m_data); }
+    OperatorFieldTypes::FusedActivationOperatorDesc& AsFusedActivationOperatorDesc() { return std::get<OperatorFieldTypes::FusedActivationOperatorDesc>(m_data); }
 
-    const OperatorFieldTypes::OperatorDescArray& AsOperatorDescArray() const { return std::get<OperatorFieldTypes::OperatorDescArray>(m_data); }
-    OperatorFieldTypes::OperatorDescArray& AsOperatorDescArray() { return std::get<OperatorFieldTypes::OperatorDescArray>(m_data); }
+    const OperatorFieldTypes::FusedActivationOperatorDescArray& AsFusedActivationOperatorDescArray() const { return std::get<OperatorFieldTypes::FusedActivationOperatorDescArray>(m_data); }
+    OperatorFieldTypes::FusedActivationOperatorDescArray& AsFusedActivationOperatorDescArray() { return std::get<OperatorFieldTypes::FusedActivationOperatorDescArray>(m_data); }
 
     const OperatorFieldTypes::UInt& AsUInt() const { return std::get<OperatorFieldTypes::UInt>(m_data); }
     OperatorFieldTypes::UInt& AsUInt() { return std::get<OperatorFieldTypes::UInt>(m_data); }
@@ -116,6 +119,9 @@ class OperatorField
     const OperatorFieldTypes::ScalarUnion& AsScalarUnion() const { return std::get<OperatorFieldTypes::ScalarUnion>(m_data); }
     OperatorFieldTypes::ScalarUnion& AsScalarUnion() { return std::get<OperatorFieldTypes::ScalarUnion>(m_data); }
 
+    const OperatorFieldTypes::Bool& AsBool() const { return std::get<OperatorFieldTypes::Bool>(m_data); }
+    OperatorFieldTypes::Bool& AsBool() { return std::get<OperatorFieldTypes::Bool>(m_data); }
+
 private:
     const DML_SCHEMA_FIELD* m_schema;
     OperatorFieldVariant m_data;
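Taken together, the alias changes above mean the array-typed fields are no longer wrapped in std::optional (an absent array is simply an empty vector), while the fused-activation fields stay optional and Bool is a brand-new alternative. A minimal consumer sketch, assuming the OperatorField declarations above and the DML_SCHEMA_FIELD_TYPE enumerators referenced in the alias comments (InspectField is an illustrative name):

#include <cstdio>

inline void InspectField(const OperatorField& field, DML_SCHEMA_FIELD_TYPE type)
{
    switch (type)
    {
    case DML_SCHEMA_FIELD_TYPE_UINT_ARRAY:
        // Plain std::vector now: an absent array is simply empty, no optional to dereference.
        std::printf("uint array with %zu elements\n", field.AsUIntArray().size());
        break;
    case DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC:
        // Still std::optional: has_value() distinguishes "no fused activation".
        std::printf("fused activation present: %d\n", field.AsFusedActivationOperatorDesc().has_value() ? 1 : 0);
        break;
    case DML_SCHEMA_FIELD_TYPE_BOOL:
        // New alternative introduced by this change.
        std::printf("bool field = %s\n", field.AsBool() ? "true" : "false");
        break;
    default:
        break;
    }
}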
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/OperatorFieldTypes_generated.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/OperatorFieldTypes_generated.h
new file mode 100644
index 0000000000000..639c31f0dc5c8
--- /dev/null
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/OperatorFieldTypes_generated.h
@@ -0,0 +1,1323 @@
+// automatically generated by the FlatBuffers compiler, do not modify
+
+
+#ifndef FLATBUFFERS_GENERATED_OPERATORFIELDTYPES_DML_IR_OPERATORFIELDTYPES_H_
+#define FLATBUFFERS_GENERATED_OPERATORFIELDTYPES_DML_IR_OPERATORFIELDTYPES_H_
+
+#include "core/common/flatbuffers.h"
+
+// Ensure the included flatbuffers.h is the same version as when this file was
+// generated, otherwise it may not be compatible.
+static_assert(FLATBUFFERS_VERSION_MAJOR == 23 &&
+              FLATBUFFERS_VERSION_MINOR == 5 &&
+              FLATBUFFERS_VERSION_REVISION == 26,
+             "Non-compatible flatbuffers version included");
+
+namespace dml {
+namespace ir {
+namespace operatorFieldTypes {
+
+struct AttributeDesc;
+struct AttributeDescBuilder;
+
+struct Activation;
+struct ActivationBuilder;
+
+struct ActivationArray;
+struct ActivationArrayBuilder;
+
+struct UInt8;
+
+struct UInt16;
+
+struct UInt32;
+
+struct UInt64;
+
+struct Int8;
+
+struct Int16;
+
+struct Int32;
+
+struct Int64;
+
+struct Float32;
+
+struct Float64;
+
+struct UIntArray;
+struct UIntArrayBuilder;
+
+struct IntArray;
+struct IntArrayBuilder;
+
+struct FloatArray;
+struct FloatArrayBuilder;
+
+struct ScaleBias;
+
+struct Size2D;
+
+struct ByteArray;
+
+struct ScalarUnionData;
+struct ScalarUnionDataBuilder;
+
+struct Bool;
+
+enum AttributeFieldVariant : uint8_t {
+  AttributeFieldVariant_NONE = 0,
+  AttributeFieldVariant_Activation = 1,
+  AttributeFieldVariant_ActivationArray = 2,
+  AttributeFieldVariant_UInt32 = 3,
+  AttributeFieldVariant_UInt64 = 4,
+  AttributeFieldVariant_Int32 = 5,
+  AttributeFieldVariant_Float32 = 6,
+  AttributeFieldVariant_UIntArray = 7,
+  AttributeFieldVariant_IntArray = 8,
+  AttributeFieldVariant_FloatArray = 9,
+  AttributeFieldVariant_ScaleBias = 10,
+  AttributeFieldVariant_Size2D = 11,
+  AttributeFieldVariant_ScalarUnionData = 12,
+  AttributeFieldVariant_Bool = 13,
+  AttributeFieldVariant_MIN = AttributeFieldVariant_NONE,
+  AttributeFieldVariant_MAX = AttributeFieldVariant_Bool
+};
+
+inline const AttributeFieldVariant (&EnumValuesAttributeFieldVariant())[14] {
+  static const AttributeFieldVariant values[] = {
+    AttributeFieldVariant_NONE,
+    AttributeFieldVariant_Activation,
+    AttributeFieldVariant_ActivationArray,
+    AttributeFieldVariant_UInt32,
+    AttributeFieldVariant_UInt64,
+    AttributeFieldVariant_Int32,
+    AttributeFieldVariant_Float32,
+    AttributeFieldVariant_UIntArray,
+    AttributeFieldVariant_IntArray,
+    AttributeFieldVariant_FloatArray,
+    AttributeFieldVariant_ScaleBias,
+    AttributeFieldVariant_Size2D,
+    AttributeFieldVariant_ScalarUnionData,
+    AttributeFieldVariant_Bool
+  };
+  return values;
+}
+
+inline const char * const *EnumNamesAttributeFieldVariant() {
+  static const char * const names[15] = {
+    "NONE",
+    "Activation",
+    "ActivationArray",
+    "UInt32",
+    "UInt64",
+    "Int32",
+    "Float32",
+    "UIntArray",
+    "IntArray",
+    "FloatArray",
+    "ScaleBias",
+    "Size2D",
+    "ScalarUnionData",
+    "Bool",
+    nullptr
+  };
+  return names;
+}
+
+inline const char *EnumNameAttributeFieldVariant(AttributeFieldVariant e) {
+  if (::flatbuffers::IsOutRange(e, AttributeFieldVariant_NONE, AttributeFieldVariant_Bool)) return "";
+  const size_t index = static_cast<size_t>(e);
+  return EnumNamesAttributeFieldVariant()[index];
+}
+
+template<typename T> struct AttributeFieldVariantTraits {
+  static const AttributeFieldVariant enum_value = AttributeFieldVariant_NONE;
+};
+
+template<> struct AttributeFieldVariantTraits<dml::ir::operatorFieldTypes::Activation> {
+  static const AttributeFieldVariant enum_value = AttributeFieldVariant_Activation;
+};
+
+template<> struct AttributeFieldVariantTraits<dml::ir::operatorFieldTypes::ActivationArray> {
+  static const AttributeFieldVariant enum_value = AttributeFieldVariant_ActivationArray;
+};
+
+template<> struct AttributeFieldVariantTraits<dml::ir::operatorFieldTypes::UInt32> {
+  static const AttributeFieldVariant enum_value = AttributeFieldVariant_UInt32;
+};
+
+template<> struct AttributeFieldVariantTraits<dml::ir::operatorFieldTypes::UInt64> {
+  static const AttributeFieldVariant enum_value = AttributeFieldVariant_UInt64;
+};
+
+template<> struct AttributeFieldVariantTraits<dml::ir::operatorFieldTypes::Int32> {
+  static const AttributeFieldVariant enum_value = AttributeFieldVariant_Int32;
+};
+
+template<> struct AttributeFieldVariantTraits<dml::ir::operatorFieldTypes::Float32> {
+  static const AttributeFieldVariant enum_value = AttributeFieldVariant_Float32;
+};
+
+template<> struct AttributeFieldVariantTraits<dml::ir::operatorFieldTypes::UIntArray> {
+  static const AttributeFieldVariant enum_value = AttributeFieldVariant_UIntArray;
+};
+
+template<> struct AttributeFieldVariantTraits<dml::ir::operatorFieldTypes::IntArray> {
+  static const AttributeFieldVariant enum_value = AttributeFieldVariant_IntArray;
+};
+
+template<> struct AttributeFieldVariantTraits<dml::ir::operatorFieldTypes::FloatArray> {
+  static const AttributeFieldVariant enum_value = AttributeFieldVariant_FloatArray;
+};
+
+template<> struct AttributeFieldVariantTraits<dml::ir::operatorFieldTypes::ScaleBias> {
+  static const AttributeFieldVariant enum_value = AttributeFieldVariant_ScaleBias;
+};
+
+template<> struct AttributeFieldVariantTraits<dml::ir::operatorFieldTypes::Size2D> {
+  static const AttributeFieldVariant enum_value = AttributeFieldVariant_Size2D;
+};
+
+template<> struct AttributeFieldVariantTraits<dml::ir::operatorFieldTypes::ScalarUnionData> {
+  static const AttributeFieldVariant enum_value = AttributeFieldVariant_ScalarUnionData;
+};
+
+template<> struct AttributeFieldVariantTraits<dml::ir::operatorFieldTypes::Bool> {
+  static const AttributeFieldVariant enum_value = AttributeFieldVariant_Bool;
+};
+
+bool VerifyAttributeFieldVariant(::flatbuffers::Verifier &verifier, const void *obj, AttributeFieldVariant type);
+bool VerifyAttributeFieldVariantVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset<void>> *values, const ::flatbuffers::Vector<uint8_t> *types);
+
+enum ScalarVariant : uint8_t {
+  ScalarVariant_NONE = 0,
+  ScalarVariant_ByteArray = 1,
+  ScalarVariant_Int8 = 2,
+  ScalarVariant_UInt8 = 3,
+  ScalarVariant_Int16 = 4,
+  ScalarVariant_UInt16 = 5,
+  ScalarVariant_Int32 = 6,
+  ScalarVariant_UInt32 = 7,
+  ScalarVariant_Int64 = 8,
+  ScalarVariant_UInt64 = 9,
+  ScalarVariant_Float32 = 10,
+  ScalarVariant_Float64 = 11,
+  ScalarVariant_MIN = ScalarVariant_NONE,
+  ScalarVariant_MAX = ScalarVariant_Float64
+};
+
+inline const ScalarVariant (&EnumValuesScalarVariant())[12] {
+  static const ScalarVariant values[] = {
+    ScalarVariant_NONE,
+    ScalarVariant_ByteArray,
+    ScalarVariant_Int8,
+    ScalarVariant_UInt8,
+    ScalarVariant_Int16,
+    ScalarVariant_UInt16,
+    ScalarVariant_Int32,
+    ScalarVariant_UInt32,
+    ScalarVariant_Int64,
+    ScalarVariant_UInt64,
+    ScalarVariant_Float32,
+    ScalarVariant_Float64
+  };
+  return values;
+}
+
+inline const char * const *EnumNamesScalarVariant() {
+  static const char * const names[13] = {
+    "NONE",
+    "ByteArray",
+    "Int8",
+    "UInt8",
+    "Int16",
+    "UInt16",
+    "Int32",
+    "UInt32",
+    "Int64",
+    "UInt64",
+    "Float32",
+    "Float64",
+    nullptr
+  };
+  return names;
+}
+
+inline const char *EnumNameScalarVariant(ScalarVariant e) {
+  if (::flatbuffers::IsOutRange(e, ScalarVariant_NONE, ScalarVariant_Float64)) return "";
+  const size_t index = static_cast<size_t>(e);
+  return EnumNamesScalarVariant()[index];
+}
+
+template<typename T> struct ScalarVariantTraits {
+  static const ScalarVariant enum_value = ScalarVariant_NONE;
+};
+
+template<> struct ScalarVariantTraits<dml::ir::operatorFieldTypes::ByteArray> {
+  static const ScalarVariant enum_value = ScalarVariant_ByteArray;
+};
+
+template<> struct ScalarVariantTraits<dml::ir::operatorFieldTypes::Int8> {
+  static const ScalarVariant enum_value = ScalarVariant_Int8;
+};
+
+template<> struct ScalarVariantTraits<dml::ir::operatorFieldTypes::UInt8> {
+  static const ScalarVariant enum_value = ScalarVariant_UInt8;
+};
+
+template<> struct ScalarVariantTraits<dml::ir::operatorFieldTypes::Int16> {
+  static const ScalarVariant enum_value = ScalarVariant_Int16;
+};
+
+template<> struct ScalarVariantTraits<dml::ir::operatorFieldTypes::UInt16> {
+  static const ScalarVariant enum_value = ScalarVariant_UInt16;
+};
+
+template<> struct ScalarVariantTraits<dml::ir::operatorFieldTypes::Int32> {
+  static const ScalarVariant enum_value = ScalarVariant_Int32;
+};
+
+template<> struct ScalarVariantTraits<dml::ir::operatorFieldTypes::UInt32> {
+  static const ScalarVariant enum_value = ScalarVariant_UInt32;
+};
+
+template<> struct ScalarVariantTraits<dml::ir::operatorFieldTypes::Int64> {
+  static const ScalarVariant enum_value = ScalarVariant_Int64;
+};
+
+template<> struct ScalarVariantTraits<dml::ir::operatorFieldTypes::UInt64> {
+  static const ScalarVariant enum_value = ScalarVariant_UInt64;
+};
+
+template<> struct ScalarVariantTraits<dml::ir::operatorFieldTypes::Float32> {
+  static const ScalarVariant enum_value = ScalarVariant_Float32;
+};
+
+template<> struct ScalarVariantTraits<dml::ir::operatorFieldTypes::Float64> {
+  static const ScalarVariant enum_value = ScalarVariant_Float64;
+};
+
+bool VerifyScalarVariant(::flatbuffers::Verifier &verifier, const void *obj, ScalarVariant type);
+bool VerifyScalarVariantVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset<void>> *values, const ::flatbuffers::Vector<uint8_t> *types);
+
+FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(1) UInt8 FLATBUFFERS_FINAL_CLASS {
+ private:
+  uint8_t data_;
+
+ public:
+  UInt8()
+      : data_(0) {
+  }
+  UInt8(uint8_t _data)
+      : data_(::flatbuffers::EndianScalar(_data)) {
+  }
+  uint8_t data() const {
+    return ::flatbuffers::EndianScalar(data_);
+  }
+  void mutate_data(uint8_t _data) {
+    ::flatbuffers::WriteScalar(&data_, _data);
+  }
+};
+FLATBUFFERS_STRUCT_END(UInt8, 1);
+
+FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(2) UInt16 FLATBUFFERS_FINAL_CLASS {
+ private:
+  uint16_t data_;
+
+ public:
+  UInt16()
+      : data_(0) {
+  }
+  UInt16(uint16_t _data)
+      : data_(::flatbuffers::EndianScalar(_data)) {
+  }
+  uint16_t data() const {
+    return ::flatbuffers::EndianScalar(data_);
+  }
+  void mutate_data(uint16_t _data) {
+    ::flatbuffers::WriteScalar(&data_, _data);
+  }
+};
+FLATBUFFERS_STRUCT_END(UInt16, 2);
+
+FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) UInt32 FLATBUFFERS_FINAL_CLASS {
+ private:
+  uint32_t data_;
+
+ public:
+  UInt32()
+      : data_(0) {
+  }
+  UInt32(uint32_t _data)
+      : data_(::flatbuffers::EndianScalar(_data)) {
+  }
+  uint32_t data() const {
+    return ::flatbuffers::EndianScalar(data_);
+  }
+  void mutate_data(uint32_t _data) {
+    ::flatbuffers::WriteScalar(&data_, _data);
+  }
+};
+FLATBUFFERS_STRUCT_END(UInt32, 4);
+
+FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) UInt64 FLATBUFFERS_FINAL_CLASS {
+ private:
+  uint64_t data_;
+
+ public:
+  UInt64()
+      : data_(0) {
+  }
+  UInt64(uint64_t _data)
+      : data_(::flatbuffers::EndianScalar(_data)) {
+  }
+  uint64_t data() const {
+    return ::flatbuffers::EndianScalar(data_);
+  }
+  void mutate_data(uint64_t _data) {
+    ::flatbuffers::WriteScalar(&data_, _data);
+  }
+};
+FLATBUFFERS_STRUCT_END(UInt64, 8);
+
+FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(1) Int8 FLATBUFFERS_FINAL_CLASS {
+ private:
+  int8_t data_;
+
+ public:
+  Int8()
+      : data_(0) {
+  }
+  Int8(int8_t _data)
+      : data_(::flatbuffers::EndianScalar(_data)) {
+  }
+  int8_t data() const {
+    return ::flatbuffers::EndianScalar(data_);
+  }
+  void mutate_data(int8_t _data) {
+    ::flatbuffers::WriteScalar(&data_, _data);
+  }
+};
+FLATBUFFERS_STRUCT_END(Int8, 1);
+
+FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(2) Int16 FLATBUFFERS_FINAL_CLASS {
+ private:
+  int16_t data_;
+
+ public:
+  Int16()
+      : data_(0) {
+  }
+  Int16(int16_t _data)
+      : data_(::flatbuffers::EndianScalar(_data)) {
+  }
+  int16_t data() const {
+    return ::flatbuffers::EndianScalar(data_);
+  }
+  void mutate_data(int16_t _data) {
+    ::flatbuffers::WriteScalar(&data_, _data);
+  }
+};
+FLATBUFFERS_STRUCT_END(Int16, 2);
+
+FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) Int32 FLATBUFFERS_FINAL_CLASS {
+ private:
+  int32_t data_;
+
+ public:
+  Int32()
+      : data_(0) {
+  }
+  Int32(int32_t _data)
+      : data_(::flatbuffers::EndianScalar(_data)) {
+  }
+  int32_t data() const {
+    return ::flatbuffers::EndianScalar(data_);
+  }
+  void mutate_data(int32_t _data) {
+    ::flatbuffers::WriteScalar(&data_, _data);
+  }
+};
+FLATBUFFERS_STRUCT_END(Int32, 4);
+
+FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) Int64 FLATBUFFERS_FINAL_CLASS {
+ private:
+  int64_t data_;
+
+ public:
+  Int64()
+      : data_(0) {
+  }
+  Int64(int64_t _data)
+      : data_(::flatbuffers::EndianScalar(_data)) {
+  }
+  int64_t data() const {
+    return ::flatbuffers::EndianScalar(data_);
+  }
+  void mutate_data(int64_t _data) {
+    ::flatbuffers::WriteScalar(&data_, _data);
+  }
+};
+FLATBUFFERS_STRUCT_END(Int64, 8);
+
+FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) Float32 FLATBUFFERS_FINAL_CLASS {
+ private:
+  float data_;
+
+ public:
+  Float32()
+      : data_(0) {
+  }
+  Float32(float _data)
+      : data_(::flatbuffers::EndianScalar(_data)) {
+  }
+  float data() const {
+    return ::flatbuffers::EndianScalar(data_);
+  }
+  void mutate_data(float _data) {
+    ::flatbuffers::WriteScalar(&data_, _data);
+  }
+};
+FLATBUFFERS_STRUCT_END(Float32, 4);
+
+FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) Float64 FLATBUFFERS_FINAL_CLASS {
+ private:
+  double data_;
+
+ public:
+  Float64()
+      : data_(0) {
+  }
+  Float64(double _data)
+      : data_(::flatbuffers::EndianScalar(_data)) {
+  }
+  double data() const {
+    return ::flatbuffers::EndianScalar(data_);
+  }
+  void mutate_data(double _data) {
+    ::flatbuffers::WriteScalar(&data_, _data);
+  }
+};
+FLATBUFFERS_STRUCT_END(Float64, 8);
+
+FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) ScaleBias FLATBUFFERS_FINAL_CLASS {
+ private:
+  float scale_;
+  float bias_;
+
+ public:
+  ScaleBias()
+      : scale_(0),
+        bias_(0) {
+  }
+  ScaleBias(float _scale, float _bias)
+      : scale_(::flatbuffers::EndianScalar(_scale)),
+        bias_(::flatbuffers::EndianScalar(_bias)) {
+  }
+  float scale() const {
+    return ::flatbuffers::EndianScalar(scale_);
+  }
+  void mutate_scale(float _scale) {
+    ::flatbuffers::WriteScalar(&scale_, _scale);
+  }
+  float bias() const {
+    return ::flatbuffers::EndianScalar(bias_);
+  }
+  void mutate_bias(float _bias) {
+    ::flatbuffers::WriteScalar(&bias_, _bias);
+  }
+};
+FLATBUFFERS_STRUCT_END(ScaleBias, 8);
+
+FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) Size2D FLATBUFFERS_FINAL_CLASS {
+ private:
+  uint32_t width_;
+  uint32_t height_;
+
+ public:
+  Size2D()
+      : width_(0),
+        height_(0) {
+  }
+  Size2D(uint32_t _width, uint32_t _height)
+      : width_(::flatbuffers::EndianScalar(_width)),
+        height_(::flatbuffers::EndianScalar(_height)) {
+  }
+  uint32_t width() const {
+    return ::flatbuffers::EndianScalar(width_);
+  }
+  void mutate_width(uint32_t _width) {
+    ::flatbuffers::WriteScalar(&width_, _width);
+  }
+  uint32_t height() const {
+    return ::flatbuffers::EndianScalar(height_);
+  }
+  void mutate_height(uint32_t _height) {
+    ::flatbuffers::WriteScalar(&height_, _height);
+  }
+};
+FLATBUFFERS_STRUCT_END(Size2D, 8);
+
+FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(1) ByteArray FLATBUFFERS_FINAL_CLASS {
+ private:
+  uint8_t data_[8];
+
+ public:
+  ByteArray()
+      : data_() {
+  }
+  ByteArray(::flatbuffers::span<const uint8_t, 8> _data) {
+    ::flatbuffers::CastToArray(data_).CopyFromSpan(_data);
+  }
+  const ::flatbuffers::Array<uint8_t, 8> *data() const {
+    return &::flatbuffers::CastToArray(data_);
+  }
+  ::flatbuffers::Array<uint8_t, 8> *mutable_data() {
+    return &::flatbuffers::CastToArray(data_);
+  }
+};
+FLATBUFFERS_STRUCT_END(ByteArray, 8);
+
+FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(1) Bool FLATBUFFERS_FINAL_CLASS {
+ private:
+  uint8_t data_;
+
+ public:
+  Bool()
+      : data_(0) {
+  }
+  Bool(bool _data)
+      : data_(::flatbuffers::EndianScalar(static_cast<uint8_t>(_data))) {
+  }
+  bool data() const {
+    return ::flatbuffers::EndianScalar(data_) != 0;
+  }
+  void mutate_data(bool _data) {
+    ::flatbuffers::WriteScalar(&data_, static_cast<uint8_t>(_data));
+  }
+};
+FLATBUFFERS_STRUCT_END(Bool, 1);
+
+struct AttributeDesc FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+  typedef AttributeDescBuilder Builder;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_NAME = 4,
+    VT_VAL_TYPE = 6,
+    VT_VAL = 8
+  };
+  const ::flatbuffers::String *name() const {
+    return GetPointer<const ::flatbuffers::String *>(VT_NAME);
+  }
+  ::flatbuffers::String *mutable_name() {
+    return GetPointer<::flatbuffers::String *>(VT_NAME);
+  }
+  dml::ir::operatorFieldTypes::AttributeFieldVariant val_type() const {
+    return static_cast<dml::ir::operatorFieldTypes::AttributeFieldVariant>(GetField<uint8_t>(VT_VAL_TYPE, 0));
+  }
+  const void *val() const {
+    return GetPointer<const void *>(VT_VAL);
+  }
+  template<typename T> const T *val_as() const;
+  const dml::ir::operatorFieldTypes::Activation *val_as_Activation() const {
+    return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_Activation ? static_cast<const dml::ir::operatorFieldTypes::Activation *>(val()) : nullptr;
+  }
+  const dml::ir::operatorFieldTypes::ActivationArray *val_as_ActivationArray() const {
+    return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_ActivationArray ? static_cast<const dml::ir::operatorFieldTypes::ActivationArray *>(val()) : nullptr;
+  }
+  const dml::ir::operatorFieldTypes::UInt32 *val_as_UInt32() const {
+    return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_UInt32 ? static_cast<const dml::ir::operatorFieldTypes::UInt32 *>(val()) : nullptr;
+  }
+  const dml::ir::operatorFieldTypes::UInt64 *val_as_UInt64() const {
+    return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_UInt64 ? static_cast<const dml::ir::operatorFieldTypes::UInt64 *>(val()) : nullptr;
+  }
+  const dml::ir::operatorFieldTypes::Int32 *val_as_Int32() const {
+    return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_Int32 ? static_cast<const dml::ir::operatorFieldTypes::Int32 *>(val()) : nullptr;
+  }
+  const dml::ir::operatorFieldTypes::Float32 *val_as_Float32() const {
+    return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_Float32 ? static_cast<const dml::ir::operatorFieldTypes::Float32 *>(val()) : nullptr;
+  }
+  const dml::ir::operatorFieldTypes::UIntArray *val_as_UIntArray() const {
+    return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_UIntArray ? static_cast<const dml::ir::operatorFieldTypes::UIntArray *>(val()) : nullptr;
+  }
+  const dml::ir::operatorFieldTypes::IntArray *val_as_IntArray() const {
+    return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_IntArray ? static_cast<const dml::ir::operatorFieldTypes::IntArray *>(val()) : nullptr;
+  }
+  const dml::ir::operatorFieldTypes::FloatArray *val_as_FloatArray() const {
+    return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_FloatArray ? static_cast<const dml::ir::operatorFieldTypes::FloatArray *>(val()) : nullptr;
+  }
+  const dml::ir::operatorFieldTypes::ScaleBias *val_as_ScaleBias() const {
+    return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_ScaleBias ? static_cast<const dml::ir::operatorFieldTypes::ScaleBias *>(val()) : nullptr;
+  }
+  const dml::ir::operatorFieldTypes::Size2D *val_as_Size2D() const {
+    return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_Size2D ? static_cast<const dml::ir::operatorFieldTypes::Size2D *>(val()) : nullptr;
+  }
+  const dml::ir::operatorFieldTypes::ScalarUnionData *val_as_ScalarUnionData() const {
+    return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_ScalarUnionData ? static_cast<const dml::ir::operatorFieldTypes::ScalarUnionData *>(val()) : nullptr;
+  }
+  const dml::ir::operatorFieldTypes::Bool *val_as_Bool() const {
+    return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_Bool ? static_cast<const dml::ir::operatorFieldTypes::Bool *>(val()) : nullptr;
+  }
+  void *mutable_val() {
+    return GetPointer<void *>(VT_VAL);
+  }
+  bool Verify(::flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyOffset(verifier, VT_NAME) &&
+           verifier.VerifyString(name()) &&
+           VerifyField<uint8_t>(verifier, VT_VAL_TYPE, 1) &&
+           VerifyOffset(verifier, VT_VAL) &&
+           VerifyAttributeFieldVariant(verifier, val(), val_type()) &&
+           verifier.EndTable();
+  }
+};
+
+template<> inline const dml::ir::operatorFieldTypes::Activation *AttributeDesc::val_as<dml::ir::operatorFieldTypes::Activation>() const {
+  return val_as_Activation();
+}
+
+template<> inline const dml::ir::operatorFieldTypes::ActivationArray *AttributeDesc::val_as<dml::ir::operatorFieldTypes::ActivationArray>() const {
+  return val_as_ActivationArray();
+}
+
+template<> inline const dml::ir::operatorFieldTypes::UInt32 *AttributeDesc::val_as<dml::ir::operatorFieldTypes::UInt32>() const {
+  return val_as_UInt32();
+}
+
+template<> inline const dml::ir::operatorFieldTypes::UInt64 *AttributeDesc::val_as<dml::ir::operatorFieldTypes::UInt64>() const {
+  return val_as_UInt64();
+}
+
+template<> inline const dml::ir::operatorFieldTypes::Int32 *AttributeDesc::val_as<dml::ir::operatorFieldTypes::Int32>() const {
+  return val_as_Int32();
+}
+
+template<> inline const dml::ir::operatorFieldTypes::Float32 *AttributeDesc::val_as<dml::ir::operatorFieldTypes::Float32>() const {
+  return val_as_Float32();
+}
+
+template<> inline const dml::ir::operatorFieldTypes::UIntArray *AttributeDesc::val_as<dml::ir::operatorFieldTypes::UIntArray>() const {
+  return val_as_UIntArray();
+}
+
+template<> inline const dml::ir::operatorFieldTypes::IntArray *AttributeDesc::val_as<dml::ir::operatorFieldTypes::IntArray>() const {
+  return val_as_IntArray();
+}
+
+template<> inline const dml::ir::operatorFieldTypes::FloatArray *AttributeDesc::val_as<dml::ir::operatorFieldTypes::FloatArray>() const {
+  return val_as_FloatArray();
+}
+
+template<> inline const dml::ir::operatorFieldTypes::ScaleBias *AttributeDesc::val_as<dml::ir::operatorFieldTypes::ScaleBias>() const {
+  return val_as_ScaleBias();
+}
+
+template<> inline const dml::ir::operatorFieldTypes::Size2D *AttributeDesc::val_as<dml::ir::operatorFieldTypes::Size2D>() const {
+  return val_as_Size2D();
+}
+
+template<> inline const dml::ir::operatorFieldTypes::ScalarUnionData *AttributeDesc::val_as<dml::ir::operatorFieldTypes::ScalarUnionData>() const {
+  return val_as_ScalarUnionData();
+}
+
+template<> inline const dml::ir::operatorFieldTypes::Bool *AttributeDesc::val_as<dml::ir::operatorFieldTypes::Bool>() const {
+  return val_as_Bool();
+}
+
+struct AttributeDescBuilder {
+  typedef AttributeDesc Table;
+  ::flatbuffers::FlatBufferBuilder &fbb_;
+  ::flatbuffers::uoffset_t start_;
+  void add_name(::flatbuffers::Offset<::flatbuffers::String> name) {
+    fbb_.AddOffset(AttributeDesc::VT_NAME, name);
+  }
+  void add_val_type(dml::ir::operatorFieldTypes::AttributeFieldVariant val_type) {
+    fbb_.AddElement<uint8_t>(AttributeDesc::VT_VAL_TYPE, static_cast<uint8_t>(val_type), 0);
+  }
+  void add_val(::flatbuffers::Offset<void> val) {
+    fbb_.AddOffset(AttributeDesc::VT_VAL, val);
+  }
+  explicit AttributeDescBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ::flatbuffers::Offset<AttributeDesc> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = ::flatbuffers::Offset<AttributeDesc>(end);
+    return o;
+  }
+};
+
+inline ::flatbuffers::Offset<AttributeDesc> CreateAttributeDesc(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    ::flatbuffers::Offset<::flatbuffers::String> name = 0,
+    dml::ir::operatorFieldTypes::AttributeFieldVariant val_type = dml::ir::operatorFieldTypes::AttributeFieldVariant_NONE,
+    ::flatbuffers::Offset<void> val = 0) {
+  AttributeDescBuilder builder_(_fbb);
+  builder_.add_val(val);
+  builder_.add_name(name);
+  builder_.add_val_type(val_type);
+  return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<AttributeDesc> CreateAttributeDescDirect(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    const char *name = nullptr,
+    dml::ir::operatorFieldTypes::AttributeFieldVariant val_type = dml::ir::operatorFieldTypes::AttributeFieldVariant_NONE,
+    ::flatbuffers::Offset<void> val = 0) {
+  auto name__ = name ? _fbb.CreateString(name) : 0;
+  return dml::ir::operatorFieldTypes::CreateAttributeDesc(
+      _fbb,
+      name__,
+      val_type,
+      val);
+}
+
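A small read-side sketch for the AttributeDesc table above (DumpAttribute is an illustrative name, not part of the generated API): the typed val_as_* accessors return nullptr when val_type() does not match, which is the usual way to branch on a FlatBuffers union.

#include <cstdio>

inline void DumpAttribute(const dml::ir::operatorFieldTypes::AttributeDesc* attr)
{
    // name() may be null if the writer omitted it.
    std::printf("attribute %s: ", attr->name() ? attr->name()->c_str() : "<unnamed>");

    if (const auto* b = attr->val_as_Bool())
    {
        std::printf("bool %s\n", b->data() ? "true" : "false");
    }
    else if (const auto* u = attr->val_as_UInt32())
    {
        std::printf("uint32 %u\n", u->data());
    }
    else
    {
        std::printf("variant tag %u\n", static_cast<unsigned>(attr->val_type()));
    }
}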
+struct Activation FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+  typedef ActivationBuilder Builder;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_TYPE = 4,
+    VT_ATTRIBUTES = 6
+  };
+  const ::flatbuffers::String *type() const {
+    return GetPointer<const ::flatbuffers::String *>(VT_TYPE);
+  }
+  ::flatbuffers::String *mutable_type() {
+    return GetPointer<::flatbuffers::String *>(VT_TYPE);
+  }
+  const ::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::operatorFieldTypes::AttributeDesc>> *attributes() const {
+    return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::operatorFieldTypes::AttributeDesc>> *>(VT_ATTRIBUTES);
+  }
+  ::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::operatorFieldTypes::AttributeDesc>> *mutable_attributes() {
+    return GetPointer<::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::operatorFieldTypes::AttributeDesc>> *>(VT_ATTRIBUTES);
+  }
+  bool Verify(::flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyOffset(verifier, VT_TYPE) &&
+           verifier.VerifyString(type()) &&
+           VerifyOffset(verifier, VT_ATTRIBUTES) &&
+           verifier.VerifyVector(attributes()) &&
+           verifier.VerifyVectorOfTables(attributes()) &&
+           verifier.EndTable();
+  }
+};
+
+struct ActivationBuilder {
+  typedef Activation Table;
+  ::flatbuffers::FlatBufferBuilder &fbb_;
+  ::flatbuffers::uoffset_t start_;
+  void add_type(::flatbuffers::Offset<::flatbuffers::String> type) {
+    fbb_.AddOffset(Activation::VT_TYPE, type);
+  }
+  void add_attributes(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::operatorFieldTypes::AttributeDesc>>> attributes) {
+    fbb_.AddOffset(Activation::VT_ATTRIBUTES, attributes);
+  }
+  explicit ActivationBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ::flatbuffers::Offset<Activation> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = ::flatbuffers::Offset<Activation>(end);
+    return o;
+  }
+};
+
+inline ::flatbuffers::Offset<Activation> CreateActivation(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    ::flatbuffers::Offset<::flatbuffers::String> type = 0,
+    ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::operatorFieldTypes::AttributeDesc>>> attributes = 0) {
+  ActivationBuilder builder_(_fbb);
+  builder_.add_attributes(attributes);
+  builder_.add_type(type);
+  return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<Activation> CreateActivationDirect(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    const char *type = nullptr,
+    const std::vector<::flatbuffers::Offset<dml::ir::operatorFieldTypes::AttributeDesc>> *attributes = nullptr) {
+  auto type__ = type ? _fbb.CreateString(type) : 0;
+  auto attributes__ = attributes ? _fbb.CreateVector<::flatbuffers::Offset<dml::ir::operatorFieldTypes::AttributeDesc>>(*attributes) : 0;
+  return dml::ir::operatorFieldTypes::CreateActivation(
+      _fbb,
+      type__,
+      attributes__);
+}
+
+struct ActivationArray FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+  typedef ActivationArrayBuilder Builder;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_DATA = 4
+  };
+  const ::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::operatorFieldTypes::Activation>> *data() const {
+    return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::operatorFieldTypes::Activation>> *>(VT_DATA);
+  }
+  ::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::operatorFieldTypes::Activation>> *mutable_data() {
+    return GetPointer<::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::operatorFieldTypes::Activation>> *>(VT_DATA);
+  }
+  bool Verify(::flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyOffset(verifier, VT_DATA) &&
+           verifier.VerifyVector(data()) &&
+           verifier.VerifyVectorOfTables(data()) &&
+           verifier.EndTable();
+  }
+};
+
+struct ActivationArrayBuilder {
+  typedef ActivationArray Table;
+  ::flatbuffers::FlatBufferBuilder &fbb_;
+  ::flatbuffers::uoffset_t start_;
+  void add_data(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::operatorFieldTypes::Activation>>> data) {
+    fbb_.AddOffset(ActivationArray::VT_DATA, data);
+  }
+  explicit ActivationArrayBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ::flatbuffers::Offset<ActivationArray> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = ::flatbuffers::Offset<ActivationArray>(end);
+    return o;
+  }
+};
+
+inline ::flatbuffers::Offset<ActivationArray> CreateActivationArray(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<dml::ir::operatorFieldTypes::Activation>>> data = 0) {
+  ActivationArrayBuilder builder_(_fbb);
+  builder_.add_data(data);
+  return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<ActivationArray> CreateActivationArrayDirect(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    const std::vector<::flatbuffers::Offset<dml::ir::operatorFieldTypes::Activation>> *data = nullptr) {
+  auto data__ = data ? _fbb.CreateVector<::flatbuffers::Offset<dml::ir::operatorFieldTypes::Activation>>(*data) : 0;
+  return dml::ir::operatorFieldTypes::CreateActivationArray(
+      _fbb,
+      data__);
+}
+
+struct UIntArray FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+  typedef UIntArrayBuilder Builder;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_DATA = 4
+  };
+  const ::flatbuffers::Vector<uint32_t> *data() const {
+    return GetPointer<const ::flatbuffers::Vector<uint32_t> *>(VT_DATA);
+  }
+  ::flatbuffers::Vector<uint32_t> *mutable_data() {
+    return GetPointer<::flatbuffers::Vector<uint32_t> *>(VT_DATA);
+  }
+  bool Verify(::flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyOffset(verifier, VT_DATA) &&
+           verifier.VerifyVector(data()) &&
+           verifier.EndTable();
+  }
+};
+
+struct UIntArrayBuilder {
+  typedef UIntArray Table;
+  ::flatbuffers::FlatBufferBuilder &fbb_;
+  ::flatbuffers::uoffset_t start_;
+  void add_data(::flatbuffers::Offset<::flatbuffers::Vector<uint32_t>> data) {
+    fbb_.AddOffset(UIntArray::VT_DATA, data);
+  }
+  explicit UIntArrayBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ::flatbuffers::Offset<UIntArray> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = ::flatbuffers::Offset<UIntArray>(end);
+    return o;
+  }
+};
+
+inline ::flatbuffers::Offset<UIntArray> CreateUIntArray(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    ::flatbuffers::Offset<::flatbuffers::Vector<uint32_t>> data = 0) {
+  UIntArrayBuilder builder_(_fbb);
+  builder_.add_data(data);
+  return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<UIntArray> CreateUIntArrayDirect(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    const std::vector<uint32_t> *data = nullptr) {
+  auto data__ = data ? _fbb.CreateVector<uint32_t>(*data) : 0;
+  return dml::ir::operatorFieldTypes::CreateUIntArray(
+      _fbb,
+      data__);
+}
+
+struct IntArray FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+  typedef IntArrayBuilder Builder;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_DATA = 4
+  };
+  const ::flatbuffers::Vector<int32_t> *data() const {
+    return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_DATA);
+  }
+  ::flatbuffers::Vector<int32_t> *mutable_data() {
+    return GetPointer<::flatbuffers::Vector<int32_t> *>(VT_DATA);
+  }
+  bool Verify(::flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyOffset(verifier, VT_DATA) &&
+           verifier.VerifyVector(data()) &&
+           verifier.EndTable();
+  }
+};
+
+struct IntArrayBuilder {
+  typedef IntArray Table;
+  ::flatbuffers::FlatBufferBuilder &fbb_;
+  ::flatbuffers::uoffset_t start_;
+  void add_data(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> data) {
+    fbb_.AddOffset(IntArray::VT_DATA, data);
+  }
+  explicit IntArrayBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ::flatbuffers::Offset<IntArray> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = ::flatbuffers::Offset<IntArray>(end);
+    return o;
+  }
+};
+
+inline ::flatbuffers::Offset<IntArray> CreateIntArray(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> data = 0) {
+  IntArrayBuilder builder_(_fbb);
+  builder_.add_data(data);
+  return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<IntArray> CreateIntArrayDirect(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    const std::vector<int32_t> *data = nullptr) {
+  auto data__ = data ? _fbb.CreateVector<int32_t>(*data) : 0;
+  return dml::ir::operatorFieldTypes::CreateIntArray(
+      _fbb,
+      data__);
+}
+
+struct FloatArray FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+  typedef FloatArrayBuilder Builder;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_DATA = 4
+  };
+  const ::flatbuffers::Vector<float> *data() const {
+    return GetPointer<const ::flatbuffers::Vector<float> *>(VT_DATA);
+  }
+  ::flatbuffers::Vector<float> *mutable_data() {
+    return GetPointer<::flatbuffers::Vector<float> *>(VT_DATA);
+  }
+  bool Verify(::flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyOffset(verifier, VT_DATA) &&
+           verifier.VerifyVector(data()) &&
+           verifier.EndTable();
+  }
+};
+
+struct FloatArrayBuilder {
+  typedef FloatArray Table;
+  ::flatbuffers::FlatBufferBuilder &fbb_;
+  ::flatbuffers::uoffset_t start_;
+  void add_data(::flatbuffers::Offset<::flatbuffers::Vector<float>> data) {
+    fbb_.AddOffset(FloatArray::VT_DATA, data);
+  }
+  explicit FloatArrayBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ::flatbuffers::Offset<FloatArray> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = ::flatbuffers::Offset<FloatArray>(end);
+    return o;
+  }
+};
+
+inline ::flatbuffers::Offset<FloatArray> CreateFloatArray(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    ::flatbuffers::Offset<::flatbuffers::Vector<float>> data = 0) {
+  FloatArrayBuilder builder_(_fbb);
+  builder_.add_data(data);
+  return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<FloatArray> CreateFloatArrayDirect(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    const std::vector<float> *data = nullptr) {
+  auto data__ = data ? _fbb.CreateVector<float>(*data) : 0;
+  return dml::ir::operatorFieldTypes::CreateFloatArray(
+      _fbb,
+      data__);
+}
+
+struct ScalarUnionData FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+  typedef ScalarUnionDataBuilder Builder;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_DATA_TYPE = 4,
+    VT_DATA = 6
+  };
+  dml::ir::operatorFieldTypes::ScalarVariant data_type() const {
+    return static_cast<dml::ir::operatorFieldTypes::ScalarVariant>(GetField<uint8_t>(VT_DATA_TYPE, 0));
+  }
+  const void *data() const {
+    return GetPointer<const void *>(VT_DATA);
+  }
+  template<typename T> const T *data_as() const;
+  const dml::ir::operatorFieldTypes::ByteArray *data_as_ByteArray() const {
+    return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_ByteArray ? static_cast<const dml::ir::operatorFieldTypes::ByteArray *>(data()) : nullptr;
+  }
+  const dml::ir::operatorFieldTypes::Int8 *data_as_Int8() const {
+    return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Int8 ? static_cast<const dml::ir::operatorFieldTypes::Int8 *>(data()) : nullptr;
+  }
+  const dml::ir::operatorFieldTypes::UInt8 *data_as_UInt8() const {
+    return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_UInt8 ? static_cast<const dml::ir::operatorFieldTypes::UInt8 *>(data()) : nullptr;
+  }
+  const dml::ir::operatorFieldTypes::Int16 *data_as_Int16() const {
+    return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Int16 ? static_cast<const dml::ir::operatorFieldTypes::Int16 *>(data()) : nullptr;
+  }
+  const dml::ir::operatorFieldTypes::UInt16 *data_as_UInt16() const {
+    return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_UInt16 ? static_cast<const dml::ir::operatorFieldTypes::UInt16 *>(data()) : nullptr;
+  }
+  const dml::ir::operatorFieldTypes::Int32 *data_as_Int32() const {
+    return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Int32 ? static_cast<const dml::ir::operatorFieldTypes::Int32 *>(data()) : nullptr;
+  }
+  const dml::ir::operatorFieldTypes::UInt32 *data_as_UInt32() const {
+    return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_UInt32 ? static_cast<const dml::ir::operatorFieldTypes::UInt32 *>(data()) : nullptr;
+  }
+  const dml::ir::operatorFieldTypes::Int64 *data_as_Int64() const {
+    return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Int64 ? static_cast<const dml::ir::operatorFieldTypes::Int64 *>(data()) : nullptr;
+  }
+  const dml::ir::operatorFieldTypes::UInt64 *data_as_UInt64() const {
+    return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_UInt64 ? static_cast<const dml::ir::operatorFieldTypes::UInt64 *>(data()) : nullptr;
+  }
+  const dml::ir::operatorFieldTypes::Float32 *data_as_Float32() const {
+    return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Float32 ? static_cast<const dml::ir::operatorFieldTypes::Float32 *>(data()) : nullptr;
+  }
+  const dml::ir::operatorFieldTypes::Float64 *data_as_Float64() const {
+    return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Float64 ? static_cast<const dml::ir::operatorFieldTypes::Float64 *>(data()) : nullptr;
+  }
+  void *mutable_data() {
+    return GetPointer<void *>(VT_DATA);
+  }
+  bool Verify(::flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyField<uint8_t>(verifier, VT_DATA_TYPE, 1) &&
+           VerifyOffset(verifier, VT_DATA) &&
+           VerifyScalarVariant(verifier, data(), data_type()) &&
+           verifier.EndTable();
+  }
+};
+
+template<> inline const dml::ir::operatorFieldTypes::ByteArray *ScalarUnionData::data_as<dml::ir::operatorFieldTypes::ByteArray>() const {
+  return data_as_ByteArray();
+}
+
+template<> inline const dml::ir::operatorFieldTypes::Int8 *ScalarUnionData::data_as<dml::ir::operatorFieldTypes::Int8>() const {
+  return data_as_Int8();
+}
+
+template<> inline const dml::ir::operatorFieldTypes::UInt8 *ScalarUnionData::data_as<dml::ir::operatorFieldTypes::UInt8>() const {
+  return data_as_UInt8();
+}
+
+template<> inline const dml::ir::operatorFieldTypes::Int16 *ScalarUnionData::data_as<dml::ir::operatorFieldTypes::Int16>() const {
+  return data_as_Int16();
+}
+
+template<> inline const dml::ir::operatorFieldTypes::UInt16 *ScalarUnionData::data_as<dml::ir::operatorFieldTypes::UInt16>() const {
+  return data_as_UInt16();
+}
+
+template<> inline const dml::ir::operatorFieldTypes::Int32 *ScalarUnionData::data_as<dml::ir::operatorFieldTypes::Int32>() const {
+  return data_as_Int32();
+}
+
+template<> inline const dml::ir::operatorFieldTypes::UInt32 *ScalarUnionData::data_as<dml::ir::operatorFieldTypes::UInt32>() const {
+  return data_as_UInt32();
+}
+
+template<> inline const dml::ir::operatorFieldTypes::Int64 *ScalarUnionData::data_as<dml::ir::operatorFieldTypes::Int64>() const {
+  return data_as_Int64();
+}
+
+template<> inline const dml::ir::operatorFieldTypes::UInt64 *ScalarUnionData::data_as<dml::ir::operatorFieldTypes::UInt64>() const {
+  return data_as_UInt64();
+}
+
+template<> inline const dml::ir::operatorFieldTypes::Float32 *ScalarUnionData::data_as<dml::ir::operatorFieldTypes::Float32>() const {
+  return data_as_Float32();
+}
+
+template<> inline const dml::ir::operatorFieldTypes::Float64 *ScalarUnionData::data_as<dml::ir::operatorFieldTypes::Float64>() const {
+  return data_as_Float64();
+}
+
+struct ScalarUnionDataBuilder {
+  typedef ScalarUnionData Table;
+  ::flatbuffers::FlatBufferBuilder &fbb_;
+  ::flatbuffers::uoffset_t start_;
+  void add_data_type(dml::ir::operatorFieldTypes::ScalarVariant data_type) {
+    fbb_.AddElement<uint8_t>(ScalarUnionData::VT_DATA_TYPE, static_cast<uint8_t>(data_type), 0);
+  }
+  void add_data(::flatbuffers::Offset<void> data) {
+    fbb_.AddOffset(ScalarUnionData::VT_DATA, data);
+  }
+  explicit ScalarUnionDataBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ::flatbuffers::Offset<ScalarUnionData> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = ::flatbuffers::Offset<ScalarUnionData>(end);
+    return o;
+  }
+};
+
+inline ::flatbuffers::Offset<ScalarUnionData> CreateScalarUnionData(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    dml::ir::operatorFieldTypes::ScalarVariant data_type = dml::ir::operatorFieldTypes::ScalarVariant_NONE,
+    ::flatbuffers::Offset<void> data = 0) {
+  ScalarUnionDataBuilder builder_(_fbb);
+  builder_.add_data(data);
+  builder_.add_data_type(data_type);
+  return builder_.Finish();
+}
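A write-side sketch for ScalarUnionData, assuming the standard FlatBuffers pattern for struct-typed union members (CreateStruct(...).Union()); BuildFloat32Scalar is an illustrative helper name, not part of the generated API.

inline ::flatbuffers::Offset<dml::ir::operatorFieldTypes::ScalarUnionData>
BuildFloat32Scalar(::flatbuffers::FlatBufferBuilder& fbb, float value)
{
    // Float32 is a fixed-size struct, so it is written inline into the buffer
    // and referenced through the ScalarVariant union tag.
    dml::ir::operatorFieldTypes::Float32 scalar(value);
    return dml::ir::operatorFieldTypes::CreateScalarUnionData(
        fbb,
        dml::ir::operatorFieldTypes::ScalarVariant_Float32,
        fbb.CreateStruct(scalar).Union());
}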
+
+inline bool VerifyAttributeFieldVariant(::flatbuffers::Verifier &verifier, const void *obj, AttributeFieldVariant type) {
+  switch (type) {
+    case AttributeFieldVariant_NONE: {
+      return true;
+    }
+    case AttributeFieldVariant_Activation: {
+      auto ptr = reinterpret_cast<const dml::ir::operatorFieldTypes::Activation *>(obj);
+      return verifier.VerifyTable(ptr);
+    }
+    case AttributeFieldVariant_ActivationArray: {
+      auto ptr = reinterpret_cast<const dml::ir::operatorFieldTypes::ActivationArray *>(obj);
+      return verifier.VerifyTable(ptr);
+    }
+    case AttributeFieldVariant_UInt32: {
+      return verifier.VerifyField<dml::ir::operatorFieldTypes::UInt32>(static_cast<const uint8_t *>(obj), 0, 4);
+    }
+    case AttributeFieldVariant_UInt64: {
+      return verifier.VerifyField<dml::ir::operatorFieldTypes::UInt64>(static_cast<const uint8_t *>(obj), 0, 8);
+    }
+    case AttributeFieldVariant_Int32: {
+      return verifier.VerifyField<dml::ir::operatorFieldTypes::Int32>(static_cast<const uint8_t *>(obj), 0, 4);
+    }
+    case AttributeFieldVariant_Float32: {
+      return verifier.VerifyField<dml::ir::operatorFieldTypes::Float32>(static_cast<const uint8_t *>(obj), 0, 4);
+    }
+    case AttributeFieldVariant_UIntArray: {
+      auto ptr = reinterpret_cast<const dml::ir::operatorFieldTypes::UIntArray *>(obj);
+      return verifier.VerifyTable(ptr);
+    }
+    case AttributeFieldVariant_IntArray: {
+      auto ptr = reinterpret_cast<const dml::ir::operatorFieldTypes::IntArray *>(obj);
+      return verifier.VerifyTable(ptr);
+    }
+    case AttributeFieldVariant_FloatArray: {
+      auto ptr = reinterpret_cast<const dml::ir::operatorFieldTypes::FloatArray *>(obj);
+      return verifier.VerifyTable(ptr);
+    }
+    case AttributeFieldVariant_ScaleBias: {
+      return verifier.VerifyField<dml::ir::operatorFieldTypes::ScaleBias>(static_cast<const uint8_t *>(obj), 0, 4);
+    }
+    case AttributeFieldVariant_Size2D: {
+      return verifier.VerifyField<dml::ir::operatorFieldTypes::Size2D>(static_cast<const uint8_t *>(obj), 0, 4);
+    }
+    case AttributeFieldVariant_ScalarUnionData: {
+      auto ptr = reinterpret_cast<const dml::ir::operatorFieldTypes::ScalarUnionData *>(obj);
+      return verifier.VerifyTable(ptr);
+    }
+    case AttributeFieldVariant_Bool: {
+      return verifier.VerifyField<dml::ir::operatorFieldTypes::Bool>(static_cast<const uint8_t *>(obj), 0, 1);
+    }
+    default: return true;
+  }
+}
+
+inline bool VerifyAttributeFieldVariantVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset<void>> *values, const ::flatbuffers::Vector<uint8_t> *types) {
+  if (!values || !types) return !values && !types;
+  if (values->size() != types->size()) return false;
+  for (::flatbuffers::uoffset_t i = 0; i < values->size(); ++i) {
+    if (!VerifyAttributeFieldVariant(
+        verifier,  values->Get(i), types->GetEnum<AttributeFieldVariant>(i))) {
+      return false;
+    }
+  }
+  return true;
+}
+
+inline bool VerifyScalarVariant(::flatbuffers::Verifier &verifier, const void *obj, ScalarVariant type) {
+  switch (type) {
+    case ScalarVariant_NONE: {
+      return true;
+    }
+    case ScalarVariant_ByteArray: {
+      return verifier.VerifyField<dml::ir::operatorFieldTypes::ByteArray>(static_cast<const uint8_t *>(obj), 0, 1);
+    }
+    case ScalarVariant_Int8: {
+      return verifier.VerifyField<dml::ir::operatorFieldTypes::Int8>(static_cast<const uint8_t *>(obj), 0, 1);
+    }
+    case ScalarVariant_UInt8: {
+      return verifier.VerifyField<dml::ir::operatorFieldTypes::UInt8>(static_cast<const uint8_t *>(obj), 0, 1);
+    }
+    case ScalarVariant_Int16: {
+      return verifier.VerifyField<dml::ir::operatorFieldTypes::Int16>(static_cast<const uint8_t *>(obj), 0, 2);
+    }
+    case ScalarVariant_UInt16: {
+      return verifier.VerifyField<dml::ir::operatorFieldTypes::UInt16>(static_cast<const uint8_t *>(obj), 0, 2);
+    }
+    case ScalarVariant_Int32: {
+      return verifier.VerifyField<dml::ir::operatorFieldTypes::Int32>(static_cast<const uint8_t *>(obj), 0, 4);
+    }
+    case ScalarVariant_UInt32: {
+      return verifier.VerifyField<dml::ir::operatorFieldTypes::UInt32>(static_cast<const uint8_t *>(obj), 0, 4);
+    }
+    case ScalarVariant_Int64: {
+      return verifier.VerifyField<dml::ir::operatorFieldTypes::Int64>(static_cast<const uint8_t *>(obj), 0, 8);
+    }
+    case ScalarVariant_UInt64: {
+      return verifier.VerifyField<dml::ir::operatorFieldTypes::UInt64>(static_cast<const uint8_t *>(obj), 0, 8);
+    }
+    case ScalarVariant_Float32: {
+      return verifier.VerifyField<dml::ir::operatorFieldTypes::Float32>(static_cast<const uint8_t *>(obj), 0, 4);
+    }
+    case ScalarVariant_Float64: {
+      return verifier.VerifyField<dml::ir::operatorFieldTypes::Float64>(static_cast<const uint8_t *>(obj), 0, 8);
+    }
+    default: return true;
+  }
+}
+
+inline bool VerifyScalarVariantVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset<void>> *values, const ::flatbuffers::Vector<uint8_t> *types) {
+  if (!values || !types) return !values && !types;
+  if (values->size() != types->size()) return false;
+  for (::flatbuffers::uoffset_t i = 0; i < values->size(); ++i) {
+    if (!VerifyScalarVariant(
+        verifier,  values->Get(i), types->GetEnum<ScalarVariant>(i))) {
+      return false;
+    }
+  }
+  return true;
+}
+
+}  // namespace operatorFieldTypes
+}  // namespace ir
+}  // namespace dml
+
+#endif  // FLATBUFFERS_GENERATED_OPERATORFIELDTYPES_DML_IR_OPERATORFIELDTYPES_H_
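An end-to-end serialization sketch using the generated builders above; SerializeExampleActivation, the attribute name, and the operator-type string are placeholders chosen for illustration, not values mandated by the schema.

#include <vector>

inline void SerializeExampleActivation()
{
    ::flatbuffers::FlatBufferBuilder fbb;

    // Struct-typed union value for the attribute (same CreateStruct(...).Union()
    // pattern used for any of the fixed-size wrappers in this header).
    dml::ir::operatorFieldTypes::Bool flag(true);
    std::vector<::flatbuffers::Offset<dml::ir::operatorFieldTypes::AttributeDesc>> attrs;
    attrs.push_back(dml::ir::operatorFieldTypes::CreateAttributeDescDirect(
        fbb,
        "ExampleFlag",
        dml::ir::operatorFieldTypes::AttributeFieldVariant_Bool,
        fbb.CreateStruct(flag).Union()));

    // Activation table referencing the attribute vector.
    auto activation = dml::ir::operatorFieldTypes::CreateActivationDirect(
        fbb, "DML_OPERATOR_ACTIVATION_RELU", &attrs);

    fbb.Finish(activation);
    // fbb.GetBufferPointer() / fbb.GetSize() now hold the serialized bytes.
}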
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h
index 5285481485184..1bc694dfe90c2 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h
@@ -26,14 +26,14 @@ namespace SchemaHelpers
         return field;
     }
 
-    inline OperatorFieldTypes::OperatorDesc ToOperatorFieldType(const DML_OPERATOR_DESC* value)
+    inline OperatorFieldTypes::FusedActivationOperatorDesc ToOperatorFieldType(const DML_OPERATOR_DESC* value)
     {
-        return value ? OperatorFieldTypes::OperatorDesc(ConvertOperatorDesc(*value)) : std::nullopt;
+        return value ? OperatorFieldTypes::FusedActivationOperatorDesc(ConvertOperatorDesc(*value)) : std::nullopt;
     }
 
-    inline OperatorFieldTypes::OperatorDescArray ToOperatorFieldType(const DML_OPERATOR_DESC* values, uint32_t count)
+    inline OperatorFieldTypes::FusedActivationOperatorDescArray ToOperatorFieldType(const DML_OPERATOR_DESC* values, uint32_t count)
     {
-        OperatorFieldTypes::OperatorDescArray field;
+        OperatorFieldTypes::FusedActivationOperatorDescArray field;
         if (values && count != 0)
         {
             field.emplace(count);
@@ -65,13 +65,17 @@ namespace SchemaHelpers
         return value;
     }
 
+    inline OperatorFieldTypes::Bool ToOperatorFieldType(bool value)
+    {
+        return value;
+    }
+
     inline OperatorFieldTypes::UIntArray ToOperatorFieldType(const uint32_t* values, uint32_t count)
     {
         OperatorFieldTypes::UIntArray field;
         if (values && count != 0)
         {
-            field.emplace(count);
-            std::copy_n(values, count, field->begin());
+            field.assign(values, values + count);
         }
         return field;
     }
@@ -81,8 +85,7 @@ namespace SchemaHelpers
         OperatorFieldTypes::IntArray field;
         if (values && count != 0)
         {
-            field.emplace(count);
-            std::copy_n(values, count, field->begin());
+            field.assign(values, values + count);
         }
         return field;
     }
@@ -92,8 +95,7 @@ namespace SchemaHelpers
         OperatorFieldTypes::FloatArray field;
         if (values && count != 0)
         {
-            field.emplace(count);
-            std::copy_n(values, count, field->begin());
+            field.assign(values, values + count);
         }
         return field;
     }
@@ -237,7 +239,7 @@ namespace SchemaHelpers
         {
             DML_OPERATOR_DESC* desc = nullptr;
 
-            const auto& value = field.AsOperatorDesc();
+            const auto& value = field.AsFusedActivationOperatorDesc();
             if (value)
             {
                 desc = allocator->template Allocate<DML_OPERATOR_DESC>();
@@ -251,7 +253,7 @@ namespace SchemaHelpers
         {
             DML_OPERATOR_DESC* descs = nullptr;
 
-            const auto& values = field.AsOperatorDescArray();
+            const auto& values = field.AsFusedActivationOperatorDescArray();
             if (values)
             {
                 descs = allocator->template Allocate<DML_OPERATOR_DESC>(values->size());
@@ -288,16 +290,20 @@ namespace SchemaHelpers
             dst->Write(value);
         } break;
 
+        case DML_SCHEMA_FIELD_TYPE_BOOL:
+        {
+            // OperatorFieldTypes::Bool is a 'bool' (1 byte) but written as 'BOOL' in op descs (4 bytes).
+            BOOL value = static_cast<BOOL>(field.AsBool());
+            dst->Write(value);
+        } break;
+
         case DML_SCHEMA_FIELD_TYPE_UINT_ARRAY:
         {
             uint32_t* arrayPtr = nullptr;
 
             const auto& values = field.AsUIntArray();
-            if (values)
-            {
-                arrayPtr = allocator->template Allocate<uint32_t>(values->size());
-                std::copy(values->begin(), values->end(), arrayPtr);
-            }
+            arrayPtr = allocator->template Allocate<uint32_t>(values.size());
+            std::copy(values.begin(), values.end(), arrayPtr);
 
             dst->Write(arrayPtr);
         } break;
@@ -307,11 +313,8 @@ namespace SchemaHelpers
             int32_t* arrayPtr = nullptr;
 
             const auto& values = field.AsIntArray();
-            if (values)
-            {
-                arrayPtr = allocator->template Allocate<int32_t>(values->size());
-                std::copy(values->begin(), values->end(), arrayPtr);
-            }
+            arrayPtr = allocator->template Allocate<int32_t>(values.size());
+            std::copy(values.begin(), values.end(), arrayPtr);
 
             dst->Write(arrayPtr);
         } break;
@@ -321,11 +324,8 @@ namespace SchemaHelpers
             float* arrayPtr = nullptr;
 
             const auto& values = field.AsFloatArray();
-            if (values)
-            {
-                arrayPtr = allocator->template Allocate<float>(values->size());
-                std::copy(values->begin(), values->end(), arrayPtr);
-            }
+            arrayPtr = allocator->template Allocate<float>(values.size());
+            std::copy(values.begin(), values.end(), arrayPtr);
 
             dst->Write(arrayPtr);
         } break;
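The new DML_SCHEMA_FIELD_TYPE_BOOL case above widens the 1-byte C++ bool held in OperatorFieldTypes::Bool to the 4-byte BOOL layout that DML operator descs expect. A minimal sketch of that widening, using int32_t as a stand-in for the Win32 BOOL typedef (an assumption made only to keep the example portable):

#include <cstdint>

using Win32Bool = int32_t;  // stand-in for BOOL, which is a typedef for int on Windows

static_assert(sizeof(bool) == 1, "OperatorFieldTypes::Bool stores a 1-byte bool");
static_assert(sizeof(Win32Bool) == 4, "DML op descs store the value as a 4-byte BOOL");

// Same widening the writer above performs: false -> 0, true -> 1, padded to 4 bytes.
inline Win32Bool WidenBool(bool value)
{
    return static_cast<Win32Bool>(value);
}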
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp
index 2456b396de3f6..e6f008af5c23f 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp
@@ -33,10 +33,10 @@ namespace Dml::GraphDescBuilder
     #pragma warning(pop)
 
     static void RemoveUnconnectedNodes(
-        std::vector<NodeInfo>& graphNodes,
-        std::vector<DML_INPUT_GRAPH_EDGE_DESC>& graphInputEdges,
-        std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC>& graphIntermediateEdges,
-        std::vector<DML_OUTPUT_GRAPH_EDGE_DESC>& graphOutputEdges)
+        std::vector<DmlSerializedGraphNode>& graphNodes,
+        std::vector<DmlInputSerializedGraphEdge>& graphInputEdges,
+        std::vector<DmlIntermediateSerializedGraphEdge>& graphIntermediateEdges,
+        std::vector<DmlOutputSerializedGraphEdge>& graphOutputEdges)
     {
         enum class NodeState
         {
@@ -52,7 +52,7 @@ namespace Dml::GraphDescBuilder
         };
 
         std::vector<NodeData> nodesData(graphNodes.size());
-        for (const DML_INTERMEDIATE_GRAPH_EDGE_DESC& intermediateEdge : graphIntermediateEdges)
+        for (const DmlIntermediateSerializedGraphEdge& intermediateEdge : graphIntermediateEdges)
         {
             nodesData[intermediateEdge.ToNodeIndex].predecessorIndices.push_back(intermediateEdge.FromNodeIndex);
         }
@@ -60,7 +60,7 @@ namespace Dml::GraphDescBuilder
         std::stack<uint32_t> nodeIndicesToVisit;
 
         // Start from the outputs of the graph and traverse upwards
-        for (const DML_OUTPUT_GRAPH_EDGE_DESC& outputEdge : graphOutputEdges)
+        for (const DmlOutputSerializedGraphEdge& outputEdge : graphOutputEdges)
         {
             nodeIndicesToVisit.push(outputEdge.FromNodeIndex);
         }
@@ -143,17 +143,44 @@ namespace Dml::GraphDescBuilder
         }
     }
 
+
+    uint32_t SetAndGetDmlGraphNodeIndex(
+        const uint32_t operatorDmlGraphNodeIndex,
+        const std::string& nodeNamePrefix,
+        AbstractOperatorDesc& operatorDesc,
+        /*in_out*/std::unordered_map<uint32_t, uint32_t>& operatorDmlGraphToDmlGraphNodeIndexMap,
+        /*in_out*/std::vector<DmlSerializedGraphNode>& dmlGraphNodes)
+    {
+        auto iter = operatorDmlGraphToDmlGraphNodeIndexMap.find(operatorDmlGraphNodeIndex);
+        if (iter != operatorDmlGraphToDmlGraphNodeIndexMap.end())
+        {
+            return iter->second;
+        }
+        operatorDmlGraphToDmlGraphNodeIndexMap[operatorDmlGraphNodeIndex] = static_cast<uint32_t>(dmlGraphNodes.size());
+        dmlGraphNodes.push_back({operatorDesc, nodeNamePrefix + std::to_string(operatorDmlGraphNodeIndex)});
+        return operatorDmlGraphToDmlGraphNodeIndexMap[operatorDmlGraphNodeIndex];
+    }
+
+    // Terminology:
+    //   Subgraph: a partition of the original (main) ONNX graph
+    //   DmlGraph: a graph in DML terms, converted from a subgraph
+    //   operatorDmlGraph: a graph in DML terms for a single ONNX node (operator)
+    // Main points to note:
+    //   - GraphDesc always has sequential indices for input and intermediate edges.
+    //   - One ONNX node can be converted into one or more DML nodes.
     GraphDesc BuildGraphDesc(
         const uint8_t* isConstGpuGraphInput,
         const size_t isConstGpuGraphInputCount,
         const std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>>& isInitializerTransferable,
         const std::unordered_map<std::string, GraphNodeProperties>& graphNodePropertyMap,
-        IDMLDevice* device,
         const ExecutionProviderImpl* executionHandle,
         const onnxruntime::Path& modelPath,
         gsl::span<const onnxruntime::Node* const> subgraphNodes,
         gsl::span<const onnxruntime::NodeArg* const> subgraphInputs,
-        gsl::span<const onnxruntime::NodeArg* const> subgraphOutputs)
+        gsl::span<const onnxruntime::NodeArg* const> subgraphOutputs,
+        /*out*/ std::unordered_map<uint32_t, uint32_t>& serializedGraphInputIndexToSubgraphInputIndex,
+        /*out*/ std::unordered_map<std::string_view, uint32_t>& serializedGraphLargeConstantNameToSubgraphInputIndex,
+        /*out*/ std::vector<std::unique_ptr<std::byte[]>>& smallConstantData)
     {
         struct NodeAndIndex
         {
@@ -161,19 +188,34 @@ namespace Dml::GraphDescBuilder
             uint32_t targetIndex; // The index of the input/output on the node (e.g. 1 for the second input on a node)
         };
 
-        // Map from Lotus node argument names to the new node and index where it will be produced
-        std::unordered_map<std::string, NodeAndIndex> nameToNodeAndIndexMap;
-
         std::unordered_map<std::string, EdgeShapes> nodeOutputShapes;
 
-        // Map from Lotus node argument names to input indices of the fused kernel node.
-        std::unordered_map<std::string, uint32_t> nameToDmlFusedNodeInputIndex;
+        // Map from ORT subgraph input names to indices
+        std::unordered_map<std::string_view, uint32_t> subgraphInputNameToIndexMap;
+
+        // - Map from an ORT node's output names to the DmlGraph <NodeAndIndex>.
+        // - Once a given ORT node (operator) has been transformed into an operatorDmlGraph,
+        //   that node's output names become output edges of the operatorDmlGraph.
+        // - This map is populated for those output edges.
+        std::unordered_map<std::string, NodeAndIndex> dmlGraphNodeOutputNameToNodeAndIndexMap;
+
+        // This map is used to re-index a subgraph input index to a sequential DmlGraph input index.
+        std::unordered_map<uint32_t, uint32_t> subGraphInputIndexToDmlGraphInputIndex;
+
+        // Iterate through each node and create a corresponding node in the new graph.
+        // We can iterate the nodes in any order because the edge connectivity will take care of the topological order.
+        std::unordered_map<std::string, std::vector<uint32_t>> inferredOutputShapes;
+
+        std::vector<DmlSerializedGraphNode> dmlGraphNodes;
+        std::vector<DmlInputSerializedGraphEdge> dmlGraphInputEdges;
+        std::vector<DmlIntermediateSerializedGraphEdge> dmlGraphIntermediateEdges;
+        std::vector<DmlOutputSerializedGraphEdge> dmlGraphOutputEdges;
 
         for (size_t inputIndex = 0; inputIndex < subgraphInputs.size(); ++inputIndex)
         {
-            const onnxruntime::NodeArg* graphInput = subgraphInputs[inputIndex];
-
-            if (!graphInput)
+            const onnxruntime::NodeArg* subgraphInput = subgraphInputs[inputIndex];
+            if (!subgraphInput)
             {
                 // This is a workaround for when node inputs get manipulated by transformers outside of our control,
                 // which then causes them to have a different name. If that happens we can't figure out how to
@@ -181,45 +223,21 @@ namespace Dml::GraphDescBuilder
                 // just bail early.
                 ORT_THROW_HR(E_UNEXPECTED);
             }
-
-            nameToDmlFusedNodeInputIndex.emplace(graphInput->Name(), gsl::narrow_cast<uint32_t>(inputIndex));
-        }
-
-        StackAllocator<1024> allocator; // Used for converting abstract operator descs into DML_OPERATOR_DESC
-
-        std::vector<NodeInfo> graphNodes;
-        std::vector<DML_INPUT_GRAPH_EDGE_DESC> graphInputEdges;
-        std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> graphIntermediateEdges;
-        std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> graphOutputEdges;
-
-        // Avoid using separate command lists for small graphs. This value can be reduced by tuning the
-        // flushing behavior of DmlCommandRecorder.  Its current behavior is to assume that graphs contain
-        // enough GPU work to be worth flushing immediately.
-        const uint32_t minNodeCountToReuseCommandList = 5;
-        bool reuseCommandList = false;
-
-        if (subgraphNodes.size() >= minNodeCountToReuseCommandList || executionHandle->IsMcdmDevice())
-        {
-            reuseCommandList = true;
+            subgraphInputNameToIndexMap.emplace(subgraphInput->Name(), gsl::narrow_cast<uint32_t>(inputIndex));
         }
 
         auto constantCpuGraphInputGetter = [&isInitializerTransferable, &modelPath](const std::string& argName)
         {
             ComPtr<OnnxTensorWrapper> tensorWrapper;
-
             auto iter = isInitializerTransferable.find(argName);
             if (iter != isInitializerTransferable.end())
             {
                 // Using const_cast here is simpler than making surrounding code const correct.
                 tensorWrapper = wil::MakeOrThrow<OnnxTensorWrapper>(const_cast<ONNX_NAMESPACE::TensorProto*>(iter->second.first), modelPath);
             }
-
             return tensorWrapper;
         };
 
-        // Iterate through each node and create a corresponding node in the new graph
-        // We can iterate the nodes in any order because the edge connectivity will take care of the topological order
-        std::unordered_map<std::string, std::vector<uint32_t>> inferredOutputShapes;
 
         for (const onnxruntime::Node* subgraphNode : subgraphNodes)
         {
@@ -277,195 +295,206 @@ namespace Dml::GraphDescBuilder
             }
 
             EdgeShapes outputShapes;
-            DmlGraphNodeCreateInfo graphNodeCreateInfo;
+            DmlGraphNodeCreateInfo operatorDmlGraphCreateInfo;
             graphNodeProps.internalRegInfo->graphNodeFactoryRegistration->factory(
                 node,
                 constantCpuNodeInputGetter,
                 executionHandle,
                 &inputShapesOverrides,
                 /*out*/ &outputShapes,
-                /*out*/ &graphNodeCreateInfo
+                /*out*/ &operatorDmlGraphCreateInfo
             );
 
             ORT_THROW_HR_IF(E_UNEXPECTED, outputShapes.EdgeCount() != node.OutputDefs().size());
             for (int i = 0; i < node.OutputDefs().size(); ++i)
             {
                 inferredOutputShapes[node.OutputDefs()[i]->Name()] = outputShapes.GetShape(i);
-            }
-
-            // Create a map between operatorGraphNodeIndex to mainGraphNodeIndex.
-            std::unordered_map<uint32_t, uint32_t> operatorGraphNodeIndexToMainGraphNodeIndexMap;
-            uint32_t graphNodeCount = gsl::narrow_cast<uint32_t>(graphNodes.size());
-            const bool isNodeAsOpDesc = graphNodeCreateInfo.nodesAsOperatorDesc.size() > 0;
-            size_t firstOpDescGraphNodeIndex = graphNodes.size();
-
-            if (isNodeAsOpDesc)
+            }
+
+            // Algorithm:
+            //  1. Create constant nodes by iterating through operatorDmlGraph's input edges, keeping a map of them,
+            //     because there will be an intermediate edge from each constant node and the source of an
+            //     intermediate edge must come before its destination.
+            //  2. Iterate through operatorDmlGraph's input edges again to create mainGraph's input and intermediate edges.
+            //  3. Iterate through operatorDmlGraph's intermediate edges to create mainGraph's intermediate edges.
+            //  4. Iterate through operatorDmlGraph's output edges to populate dmlGraphNodeOutputNameToNodeAndIndexMap.
+            //  5. While performing steps 2, 3, and 4, insert the operatorDmlGraph nodes into the dmlGraphNodes list.
+
+            for (auto& operatorDmlGraphInputEdge : operatorDmlGraphCreateInfo.inputEdges)
             {
-                // Can't populate graphNodes vector at this point, because operatorDesc may get modified later.
-                for (uint32_t nodeIndex = 0; nodeIndex < graphNodeCreateInfo.nodeCount; nodeIndex++)
+                const onnxruntime::NodeArg* arg = node.InputDefs()[operatorDmlGraphInputEdge.GraphInputIndex];
+                if (arg->Exists())
                 {
-                    ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsOperatorDesc[nodeIndex]);
-                    operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++);
-                }
+                    auto iter = subgraphInputNameToIndexMap.find(arg->Name());
+                    if (iter != subgraphInputNameToIndexMap.end() &&
+                        iter->second < isConstGpuGraphInputCount &&
+                        isConstGpuGraphInput[iter->second])
+                    {
+                        DmlSerializedGraphNode constantNode = {};
+                        constantNode.Name = arg->Name();
+
+                        // This is a highly inefficient approach to generating constant nodes.  It duplicates constant data
+                        // across the graph input as well as every consumer's unique constant node.  However, it is currently
+                        // only used for small inputs.
+                        auto& operatorDmlGraphInputNode = operatorDmlGraphCreateInfo.nodes[operatorDmlGraphInputEdge.ToNodeIndex];
+                        std::vector<DmlBufferTensorDesc*> toNodeInputTensorDescs = operatorDmlGraphInputNode->GetInputTensors();
+                        DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorDmlGraphInputEdge.ToNodeInputIndex];
+                        ComPtr<OnnxTensorWrapper> constantInput;
+
+                        if (tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize)
+                        {
+                            constantInput = constantCpuGraphInputGetter(arg->Name());
+                        }
 
-                graphNodes.resize(graphNodes.size() + graphNodeCreateInfo.nodeCount);
-            }
-            else
-            {
-                for (uint32_t nodeIndex = 0; nodeIndex < graphNodeCreateInfo.nodeCount; nodeIndex++)
-                {
-                    ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex].Get());
-                    operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++);
-                    NodeInfo nodeInfo = {};
-                    nodeInfo.nodeDef = std::move(graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex]);
-                    graphNodes.push_back(std::move(nodeInfo));
+                        if (constantInput)
+                        {
+                            // The tensor description's size should be no larger than the constant input unless it was rounded to
+                            // the required alignment.
+                            assert(((constantInput->GetTensorByteSize() + 3) & ~3) >= tensorDesc->totalTensorSizeInBytes);
+                            size_t minimumConstantSize = std::min(constantInput->GetTensorByteSize(), gsl::narrow_cast<size_t>(tensorDesc->totalTensorSizeInBytes));
+                            auto data = static_cast<const uint8_t*>(constantInput->GetData());
+                            std::vector<uint8_t> tensorData(data, data + minimumConstantSize);
+
+                            smallConstantData.push_back(std::make_unique<std::byte[]>(tensorData.size()));
+                            std::transform(tensorData.begin(), tensorData.end(), smallConstantData.back().get(), [](uint8_t b) {return static_cast<std::byte>(b);});
+
+                            ConstantData constantData = {smallConstantData.back().get(), tensorData.size()};
+                            constantNode.Desc = constantData;
+                        }
+                        else
+                        {
+                            ConstantName constantFileName = {GetSanitizedFileName(arg->Name())};
+                            constantNode.Desc = constantFileName;
+                        }
+                        dmlGraphNodeOutputNameToNodeAndIndexMap[arg->Name()] = {static_cast<uint32_t>(dmlGraphNodes.size()), 0};
+                        dmlGraphNodes.push_back(constantNode);
+                    }
                 }
             }
 
-            // map operatorGraphInputEdge as either mainGraphInputEdge or mainGraphIntermediateEdge
-            for (auto& operatorGraphInputEdge : graphNodeCreateInfo.inputEdges)
-            {
-                // operatorGraphInputEdge.GraphInputIndex will be the ONNX input index.
-                const onnxruntime::NodeArg* arg = node.InputDefs()[operatorGraphInputEdge.GraphInputIndex];
+            // Create a map from operatorDmlGraph node indices to dmlGraph node indices.
+            std::unordered_map<uint32_t, uint32_t> operatorDmlGraphToDmlGraphNodeIndexMap;
 
+            // Map each operatorDmlGraphInputEdge to either a mainDmlGraphInputEdge or a mainDmlGraphIntermediateEdge.
+            for (auto& operatorDmlGraphInputEdge : operatorDmlGraphCreateInfo.inputEdges)
+            {
+                // operatorDmlGraphInputEdge.GraphInputIndex will be the ONNX input index.
+                const onnxruntime::NodeArg* arg = node.InputDefs()[operatorDmlGraphInputEdge.GraphInputIndex];
                 if (arg->Exists())
                 {
-                    auto iter = nameToDmlFusedNodeInputIndex.find(arg->Name());
-                    uint32_t mainGraphNodeIndex = operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphInputEdge.ToNodeIndex];
-
-                    if (iter != nameToDmlFusedNodeInputIndex.end())
+                    uint32_t dmlGraphNodeIndex = SetAndGetDmlGraphNodeIndex(
+                        operatorDmlGraphInputEdge.ToNodeIndex,
+                        node.Name(),
+                        *operatorDmlGraphCreateInfo.nodes[operatorDmlGraphInputEdge.ToNodeIndex],
+                        operatorDmlGraphToDmlGraphNodeIndexMap,
+                        dmlGraphNodes);
+
+                    auto iter = subgraphInputNameToIndexMap.find(arg->Name());
+                    if (iter != subgraphInputNameToIndexMap.end())
                     {
-                        // This is a graph input
-
-                        const uint32_t dmlFusedNodeInputIndex = iter->second;
-
-                        // If this is a constant input, set the appropriate flags on the desc
-                        if (isNodeAsOpDesc &&
-                            dmlFusedNodeInputIndex < isConstGpuGraphInputCount &&
-                            isConstGpuGraphInput[dmlFusedNodeInputIndex])
+                        const uint32_t subgraphInputIndex = iter->second;
+
+                        // This edge will be either:
+                        //  - a constant input, in which case it becomes an intermediate edge
+                        //    (and the OWNED_BY_DML flag is set if it is a large constant), or
+                        //  - a non-constant input, in which case it becomes a mainDmlGraphInputEdge.
+                        if (subgraphInputIndex < isConstGpuGraphInputCount &&
+                            isConstGpuGraphInput[subgraphInputIndex])
                         {
-                            // This is a highly inefficient approach to generating constant nodes.  It duplicates constant data
-                            // across the graph input as well as every consumer's unique constant node.  However it is currently
-                            // only used for small inputs.
-                            uint32_t c_maxConstNodeDataSize = 8;
-
-
-                            auto& operatorGraphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex];
-                            std::vector<DmlBufferTensorDesc*> toNodeInputTensorDescs = operatorGraphInputNode->GetInputTensors();
-                            DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex];
-                            ComPtr<OnnxTensorWrapper> constantInput;
-
-                            if (tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize)
-                            {
-                                constantInput = constantCpuGraphInputGetter(arg->Name());
-                            }
-
-                            if (constantInput)
-                            {
-                                // The tensor description's size should be no larger than the constant input unless it was rounded to
-                                // the required alignment.
-                                assert(((constantInput->GetTensorByteSize() + 3) & ~3) >= tensorDesc->totalTensorSizeInBytes);
-                                size_t minimumConstantSize = std::min(constantInput->GetTensorByteSize(), gsl::narrow_cast<size_t>(tensorDesc->totalTensorSizeInBytes));
-                                auto data = static_cast<const uint8_t*>(constantInput->GetData());
-                                std::vector<uint8_t> tensorData(data, data + minimumConstantSize);
-
-                                NodeInfo nodeInfo = {};
-                                nodeInfo.nodeDef = std::move(tensorData);
-                                graphNodes.push_back(std::move(nodeInfo));
-
-                                DML_INTERMEDIATE_GRAPH_EDGE_DESC edge = {};
-                                edge.FromNodeIndex = static_cast<UINT>(graphNodes.size() - 1);
-                                edge.FromNodeOutputIndex = 0;
-                                edge.ToNodeIndex = mainGraphNodeIndex;
-                                edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex;
-                                graphIntermediateEdges.push_back(edge);
-                            }
-                            else
+                            const auto& constantNodeAndIndex = dmlGraphNodeOutputNameToNodeAndIndexMap.at(arg->Name());
+                            auto& constantNodeVariant = std::get<DmlSerializedGraphNodeConstantVariant>(dmlGraphNodes[constantNodeAndIndex.nodeIndex].Desc);
+                            if (std::holds_alternative<ConstantName>(constantNodeVariant))
                             {
-                                DML_INPUT_GRAPH_EDGE_DESC edge = {};
-                                edge.GraphInputIndex = dmlFusedNodeInputIndex;
-                                edge.ToNodeIndex = mainGraphNodeIndex;
-                                edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex;
-                                graphInputEdges.push_back(edge);
-
+                                auto& mainDmlGraphNode = dmlGraphNodes[dmlGraphNodeIndex];
+                                AbstractOperatorDesc& abstractOperatorDesc = std::get<AbstractOperatorDesc>(mainDmlGraphNode.Desc);
+                                std::vector<DmlBufferTensorDesc*> toNodeInputTensorDescs = abstractOperatorDesc.GetInputTensors();
+                                DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorDmlGraphInputEdge.ToNodeInputIndex];
                                 tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML;
+                                serializedGraphLargeConstantNameToSubgraphInputIndex[arg->Name()] = subgraphInputIndex;
                             }
+
+                            DmlIntermediateSerializedGraphEdge edge = {};
+                            edge.FromNodeIndex = constantNodeAndIndex.nodeIndex;
+                            edge.FromNodeOutputIndex = constantNodeAndIndex.targetIndex;
+                            edge.ToNodeIndex = dmlGraphNodeIndex;
+                            edge.ToNodeInputIndex = operatorDmlGraphInputEdge.ToNodeInputIndex;
+                            edge.Name = arg->Name() + "-nodeIdx:" + std::to_string(edge.FromNodeIndex) + "-outputIdx:" + std::to_string(edge.FromNodeOutputIndex);
+                            dmlGraphIntermediateEdges.push_back(edge);
                         }
                         else
                         {
-                            DML_INPUT_GRAPH_EDGE_DESC edge = {};
-                            edge.GraphInputIndex = dmlFusedNodeInputIndex;
-                            edge.ToNodeIndex = mainGraphNodeIndex;
-                            edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex;
-                            graphInputEdges.push_back(edge);
+                            DmlInputSerializedGraphEdge edge = {};
+                            if (subGraphInputIndexToDmlGraphInputIndex.find(subgraphInputIndex) == subGraphInputIndexToDmlGraphInputIndex.end())
+                            {
+                                subGraphInputIndexToDmlGraphInputIndex[subgraphInputIndex] = static_cast<uint32_t>(subGraphInputIndexToDmlGraphInputIndex.size());
+                            }
+
+                            edge.GraphInputIndex = subGraphInputIndexToDmlGraphInputIndex[subgraphInputIndex];
+                            edge.ToNodeIndex = dmlGraphNodeIndex;
+                            edge.ToNodeInputIndex = operatorDmlGraphInputEdge.ToNodeInputIndex;  // TODO: confirm whether this should point to the ONNX inputIndex instead
+                            edge.Name = arg->Name();
+
+                            serializedGraphInputIndexToSubgraphInputIndex[edge.GraphInputIndex] = subgraphInputIndex;
+                            dmlGraphInputEdges.push_back(edge);
                         }
                     }
                     else
                     {
-                        const auto& inputNodeAndIndex = nameToNodeAndIndexMap.at(arg->Name());
+                        const auto& inputNodeAndIndex = dmlGraphNodeOutputNameToNodeAndIndexMap.at(arg->Name());
 
-                        DML_INTERMEDIATE_GRAPH_EDGE_DESC edge = {};
+                        DmlIntermediateSerializedGraphEdge edge = {};
                         edge.FromNodeIndex = inputNodeAndIndex.nodeIndex;
                         edge.FromNodeOutputIndex = inputNodeAndIndex.targetIndex;
-                        edge.ToNodeIndex = mainGraphNodeIndex;
-                        edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex;
-                        graphIntermediateEdges.push_back(edge);
+                        edge.ToNodeIndex = dmlGraphNodeIndex;
+                        edge.ToNodeInputIndex = operatorDmlGraphInputEdge.ToNodeInputIndex;
+                        edge.Name = arg->Name();
+                        dmlGraphIntermediateEdges.push_back(edge);
                     }
                 }
             }
 
             // map operatorGraphIntermediateEdges as mainGraphIntermediateEdge
-            for (auto& operatorGraphIntermediateEdge : graphNodeCreateInfo.intermediateEdges)
+            for (auto& operatorGraphIntermediateEdge : operatorDmlGraphCreateInfo.intermediateEdges)
             {
-                DML_INTERMEDIATE_GRAPH_EDGE_DESC edge = {};
-                edge.FromNodeIndex = operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphIntermediateEdge.FromNodeIndex];
+                DmlIntermediateSerializedGraphEdge edge = {};
+                uint32_t shiftedFromNodeIndex = SetAndGetDmlGraphNodeIndex(
+                        operatorGraphIntermediateEdge.FromNodeIndex,
+                        node.Name(),
+                        *operatorDmlGraphCreateInfo.nodes[operatorGraphIntermediateEdge.FromNodeIndex],
+                        operatorDmlGraphToDmlGraphNodeIndexMap,
+                        dmlGraphNodes);
+                uint32_t shiftedToNodeIndex = SetAndGetDmlGraphNodeIndex(
+                        operatorGraphIntermediateEdge.ToNodeIndex,
+                        node.Name(),
+                        *operatorDmlGraphCreateInfo.nodes[operatorGraphIntermediateEdge.ToNodeIndex],
+                        operatorDmlGraphToDmlGraphNodeIndexMap,
+                        dmlGraphNodes);
+
+                edge.FromNodeIndex = shiftedFromNodeIndex;
                 edge.FromNodeOutputIndex = operatorGraphIntermediateEdge.FromNodeOutputIndex;
-                edge.ToNodeIndex = operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphIntermediateEdge.ToNodeIndex];
+                edge.ToNodeIndex = shiftedToNodeIndex;
                 edge.ToNodeInputIndex = operatorGraphIntermediateEdge.ToNodeInputIndex;
-                graphIntermediateEdges.push_back(edge);
+                edge.Name = "nodeIdx:" + std::to_string(shiftedFromNodeIndex) + "-outputIdx:" + std::to_string(operatorGraphIntermediateEdge.FromNodeOutputIndex);
+                dmlGraphIntermediateEdges.push_back(edge);
             }
-
+
-            // populate nameToNodeAndIndexMap (which will be used by above loop) for operatorGraphOutputEdges
+            // populate dmlGraphNodeOutputNameToNodeAndIndexMap (which is used by the loop above) for operatorGraphOutputEdges
-            for (auto& operatorGraphOutputEdge : graphNodeCreateInfo.outputEdges)
+            for (auto& operatorGraphOutputEdge : operatorDmlGraphCreateInfo.outputEdges)
             {
                 const onnxruntime::NodeArg* arg = node.OutputDefs()[operatorGraphOutputEdge.GraphOutputIndex];
                 if (arg->Exists())
                 {
-                    nameToNodeAndIndexMap[arg->Name()] = NodeAndIndex {
-                        operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphOutputEdge.FromNodeIndex],
-                        operatorGraphOutputEdge.FromNodeOutputIndex
-                    };
-
+                    uint32_t shiftedNodeIndex = SetAndGetDmlGraphNodeIndex(
+                            operatorGraphOutputEdge.FromNodeIndex,
+                            node.Name(),
+                            *operatorDmlGraphCreateInfo.nodes[operatorGraphOutputEdge.FromNodeIndex],
+                            operatorDmlGraphToDmlGraphNodeIndexMap,
+                            dmlGraphNodes);
+                    dmlGraphNodeOutputNameToNodeAndIndexMap[arg->Name()] = {shiftedNodeIndex, operatorGraphOutputEdge.FromNodeOutputIndex};
                     nodeOutputShapes[arg->Name()] = outputShapes;
                 }
             }
-
-            if (isNodeAsOpDesc)
-            {
-                for (size_t i = 0; i < graphNodeCreateInfo.nodesAsOperatorDesc.size(); ++i)
-                {
-                    auto& opDesc = graphNodeCreateInfo.nodesAsOperatorDesc[i];
-
-                    DML_OPERATOR_DESC dmlDesc = SchemaHelpers::ConvertOperatorDesc(*opDesc, &allocator);
-
-                    // TODO: Change as new header is ingested
-                    if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING)
-                        dmlDesc.Type = (DML_OPERATOR_TYPE) 169;
-
-                    // TODO: Change as new header is ingested
-                    if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT)
-                        dmlDesc.Type = (DML_OPERATOR_TYPE) 170;
-
-                    ComPtr<IDMLOperator> op;
-                    ORT_THROW_IF_FAILED(device->CreateOperator(&dmlDesc, IID_PPV_ARGS(&op)));
-                    allocator.Reset();
-
-                    NodeInfo nodeInfo = {};
-                    nodeInfo.nodeDef = std::move(op);
-                    nodeInfo.name = node.Name();
-                    graphNodes[firstOpDescGraphNodeIndex + i] = std::move(nodeInfo);
-                }
-            }
         }
 
         EdgeShapes graphOutputShapes(subgraphOutputs.size());
@@ -476,24 +505,27 @@ namespace Dml::GraphDescBuilder
             const onnxruntime::NodeArg* graphOutput = subgraphOutputs[outputIndex];
 
             ORT_THROW_HR_IF_NULL_MSG(E_POINTER, graphOutput, "FusedNode's nodeArgList does not contain one of the nodeArg");
-            const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name());
+            const auto& outputNodeAndIndex = dmlGraphNodeOutputNameToNodeAndIndexMap.at(graphOutput->Name());
 
-            DML_OUTPUT_GRAPH_EDGE_DESC edge = {};
+            DmlOutputSerializedGraphEdge edge = {};
             edge.FromNodeIndex = outputNodeAndIndex.nodeIndex;
             edge.FromNodeOutputIndex = outputNodeAndIndex.targetIndex;
             edge.GraphOutputIndex = gsl::narrow_cast<uint32_t>(outputIndex);
-            graphOutputEdges.push_back(edge);
+            edge.Name = graphOutput->Name();
+            dmlGraphOutputEdges.push_back(edge);
             graphOutputShapes.GetMutableShape(outputIndex) = nodeOutputShapes[graphOutput->Name()].GetShape(outputNodeAndIndex.targetIndex);
         }
 
-        RemoveUnconnectedNodes(graphNodes, graphInputEdges, graphIntermediateEdges, graphOutputEdges);
+        RemoveUnconnectedNodes(dmlGraphNodes, dmlGraphInputEdges, dmlGraphIntermediateEdges, dmlGraphOutputEdges);
 
         GraphDesc graphDesc{};
-        graphDesc.nodes = std::move(graphNodes);
-        graphDesc.inputEdges = std::move(graphInputEdges);
-        graphDesc.outputEdges = std::move(graphOutputEdges);
-        graphDesc.intermediateEdges = std::move(graphIntermediateEdges);
-        graphDesc.reuseCommandList = reuseCommandList;
+        graphDesc.InputCount = static_cast<uint32_t>(dmlGraphInputEdges.size());
+        graphDesc.OutputCount = static_cast<uint32_t>(subgraphOutputs.size());
+        graphDesc.Nodes = std::move(dmlGraphNodes);
+        graphDesc.InputEdges = std::move(dmlGraphInputEdges);
+        graphDesc.OutputEdges = std::move(dmlGraphOutputEdges);
+        graphDesc.IntermediateEdges = std::move(dmlGraphIntermediateEdges);
+        graphDesc.reuseCommandList = (subgraphNodes.size() >= minNodeCountToReuseCommandList || executionHandle->IsMcdmDevice());
         graphDesc.outputShapes = std::move(graphOutputShapes);
         return graphDesc;
     }
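
The node re-indexing above funnels through SetAndGetDmlGraphNodeIndex, which lazily assigns a main-graph index the first time any edge references an operatorDmlGraph node, so a node is appended to dmlGraphNodes exactly once no matter how many edges touch it. Below is a minimal, self-contained sketch of that get-or-assign idiom with simplified stand-in types; GraphNode is hypothetical and only illustrates the pattern, not the real DmlSerializedGraphNode.

#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

struct GraphNode { std::string name; };  // hypothetical stand-in for DmlSerializedGraphNode

// Returns the main-graph index for a local (per-operator) node index,
// appending a new node on first use so indices stay dense and stable.
uint32_t GetOrAssignIndex(
    uint32_t localNodeIndex,
    const std::string& namePrefix,
    std::unordered_map<uint32_t, uint32_t>& localToMainIndex,
    std::vector<GraphNode>& mainGraphNodes)
{
    auto iter = localToMainIndex.find(localNodeIndex);
    if (iter != localToMainIndex.end())
    {
        return iter->second;  // already assigned by an earlier edge
    }
    const uint32_t mainIndex = static_cast<uint32_t>(mainGraphNodes.size());
    localToMainIndex.emplace(localNodeIndex, mainIndex);
    mainGraphNodes.push_back({namePrefix + std::to_string(localNodeIndex)});
    return mainIndex;
}

Because input, intermediate, and output edges all pass through this helper, BuildGraphDesc no longer needs to pre-size the node list per operator the way the removed code did.
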
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h
index c95e89b45541b..4055984b40405 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h
@@ -22,22 +22,15 @@ namespace Dml
 
     namespace GraphDescBuilder
     {
+        constexpr uint32_t minNodeCountToReuseCommandList = 5;
+        constexpr uint32_t c_maxConstNodeDataSize = 8;
+
         // Gets a unique name for the node which survives recreation and graph manipulations between the point
         // that graph partitioning occurs and kernel creation happens
         const std::string& GetUniqueNodeName(const onnxruntime::Node& node);
 
-        struct NodeInfo
-        {
-            std::variant<Microsoft::WRL::ComPtr<IDMLOperator>, std::vector<uint8_t>> nodeDef;
-            std::string name;
-        };
-
-        struct GraphDesc
+        struct GraphDesc : DmlSerializedGraphDesc
         {
-            std::vector<NodeInfo> nodes;
-            std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
-            std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
-            std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
             bool reuseCommandList;
             Windows::AI::MachineLearning::Adapter::EdgeShapes outputShapes;
         };
@@ -47,11 +40,13 @@ namespace Dml
             const size_t isConstGpuGraphInputCount,
             const std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>>& isInitializerTransferable,
             const std::unordered_map<std::string, GraphNodeProperties>& graphNodePropertyMap,
-            IDMLDevice* device,
             const ExecutionProviderImpl* executionHandle,
             const onnxruntime::Path& modelPath,
             gsl::span<const onnxruntime::Node* const> subgraphNodes,
             gsl::span<const onnxruntime::NodeArg* const> subgraphInputs,
-            gsl::span<const onnxruntime::NodeArg* const> subgraphOutputs);
+            gsl::span<const onnxruntime::NodeArg* const> subgraphOutputs,
+            /*out*/ std::unordered_map<uint32_t, uint32_t>& serializedGraphInputIndexToSubgraphInputIndex,
+            /*out*/ std::unordered_map<std::string_view, uint32_t>& serializedGraphLargeConstantNameToSubgraphInputIndex,
+            /*out*/ std::vector<std::unique_ptr<std::byte[]>>& smallConstantData);
     }
 }
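
The new /*out*/ smallConstantData parameter exists because the ConstantData descriptors that BuildGraphDesc places into the serialized nodes hold raw, non-owning pointers into those byte buffers (see the std::byte copy above), so the caller-owned vector must outlive any use of the returned GraphDesc. The following is a minimal, self-contained sketch of that owner-plus-view arrangement; ConstantView is a hypothetical stand-in for ConstantData.

#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

// Non-owning view, mirroring the role played by ConstantData in the serialized graph.
struct ConstantView
{
    const std::byte* data;
    size_t sizeInBytes;
};

int main()
{
    std::vector<std::unique_ptr<std::byte[]>> smallConstantData;  // owner, kept alive by the caller
    std::vector<ConstantView> views;                              // what the graph nodes effectively hold

    const uint8_t raw[] = {1, 2, 3, 4};
    smallConstantData.push_back(std::make_unique<std::byte[]>(sizeof(raw)));
    for (size_t i = 0; i < sizeof(raw); ++i)
    {
        smallConstantData.back()[i] = static_cast<std::byte>(raw[i]);
    }
    views.push_back({smallConstantData.back().get(), sizeof(raw)});

    // The views stay valid only while smallConstantData is alive, which is why
    // BuildGraphDesc hands ownership back to its caller instead of keeping it local.
    return views.front().sizeInBytes == sizeof(raw) ? 0 : 1;
}
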
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp
index d524780de71b8..f29fbc7a1a65b 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp
@@ -1508,31 +1508,17 @@ namespace Windows::AI::MachineLearning::Adapter
         ORT_TRY
         {
             assert(operatorGraphDesc != nullptr);
-            // Either nodesAsOpDesc or nodesIDMLOperator can be present.
-            assert(operatorGraphDesc->nodeCount == 0 || (!operatorGraphDesc->nodesAsOpDesc ^ !operatorGraphDesc->nodesAsIDMLOperator));
+            assert(operatorGraphDesc->nodeCount == 0 || operatorGraphDesc->nodes);
 
-            if (operatorGraphDesc->nodesAsOpDesc)
+            m_graphNodeCreateInfo->nodes = std::vector<std::unique_ptr<AbstractOperatorDesc>>();
+            for (uint32_t nodeIndex = 0; nodeIndex < operatorGraphDesc->nodeCount; nodeIndex++)
             {
-                m_graphNodeCreateInfo->nodesAsOperatorDesc = std::vector<std::unique_ptr<AbstractOperatorDesc>>();
-                for (uint32_t nodeIndex = 0; nodeIndex < operatorGraphDesc->nodeCount; nodeIndex++)
-                {
-                    auto* node = operatorGraphDesc->nodesAsOpDesc[nodeIndex];
-                    assert(node != nullptr);
-                    AbstractOperatorDesc abstractDesc = SchemaHelpers::ConvertOperatorDesc(*node);
-                    m_graphNodeCreateInfo->nodesAsOperatorDesc.push_back(std::make_unique<AbstractOperatorDesc>(std::move(abstractDesc)));
-                }
-            }
-            else
-            {
-                m_graphNodeCreateInfo->nodesAsIDMLOperator = std::vector<Microsoft::WRL::ComPtr<IDMLOperator>>();
-                for (uint32_t nodeIndex = 0; nodeIndex < operatorGraphDesc->nodeCount; nodeIndex++)
-                {
-                    auto* node = operatorGraphDesc->nodesAsIDMLOperator[nodeIndex];
-                    assert(node != nullptr);
-                    m_graphNodeCreateInfo->nodesAsIDMLOperator.push_back(node);
-                }
+                auto* node = operatorGraphDesc->nodes[nodeIndex];
+                assert(node != nullptr);
+                AbstractOperatorDesc abstractDesc = SchemaHelpers::ConvertOperatorDesc(*node);
+                m_graphNodeCreateInfo->nodes.push_back(std::make_unique<AbstractOperatorDesc>(std::move(abstractDesc)));
             }
-
+
             // There can be operators (or kernels) which don't require any input.
             assert(operatorGraphDesc->inputEdgeCount == 0 || operatorGraphDesc->inputEdges != nullptr);
             m_graphNodeCreateInfo->inputEdges.insert(
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp
index c3bb1a52210f5..287f1e5b6dfe7 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp
@@ -53,7 +53,7 @@ namespace Dml
             MLOperatorGraphDesc operatorGraphDesc = {};
             operatorGraphDesc.nodeCount = 1;
             const DML_OPERATOR_DESC* opDescs{&operatorDesc};
-            operatorGraphDesc.nodesAsOpDesc = &opDescs;
+            operatorGraphDesc.nodes = &opDescs;
 
             std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
             for (uint32_t inputIndex = 0; inputIndex < m_kernelInputIndices.size(); inputIndex++)
@@ -796,7 +796,7 @@ namespace Dml
         for (size_t i = 0; i < graphDesc.NodeCount; ++i)
         {
             // Create the operator.
-            ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(operatorGraphDesc.nodesAsOpDesc[i], IID_PPV_ARGS(&dmlOperators[i])));
+            ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(operatorGraphDesc.nodes[i], IID_PPV_ARGS(&dmlOperators[i])));
             dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{dmlOperators[i].Get()};
             dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]};
         }
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp
index c8ca6806e75f7..73c2d57e984af 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp
@@ -531,7 +531,7 @@ class DmlOperatorAttention : public DmlOperator
         operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
         operatorGraphDesc.outputEdges = outputEdges.data();
         operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
-        operatorGraphDesc.nodesAsOpDesc = opDescs.data();
+        operatorGraphDesc.nodes = opDescs.data();
 
         SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
     }
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasAdd.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasAdd.cpp
index 1c851c94c4ddc..5aceebbdabfe3 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasAdd.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasAdd.cpp
@@ -103,7 +103,7 @@ class DmlOperatorBiasAdd : public DmlOperator
         operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
         operatorGraphDesc.outputEdges = outputEdges.data();
         operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
-        operatorGraphDesc.nodesAsOpDesc = opDescs.data();
+        operatorGraphDesc.nodes = opDescs.data();
         SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
     }
 };
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasSplitGelu.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasSplitGelu.cpp
index 501ce14f1fc08..1e10214ffd463 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasSplitGelu.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasSplitGelu.cpp
@@ -137,7 +137,7 @@ class DmlOperatorBiasSplitGelu : public DmlOperator
         operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
         operatorGraphDesc.outputEdges = outputEdges.data();
         operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
-        operatorGraphDesc.nodesAsOpDesc = opDescs.data();
+        operatorGraphDesc.nodes = opDescs.data();
         SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
     }
 };
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp
new file mode 100644
index 0000000000000..c6a87da705a99
--- /dev/null
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp
@@ -0,0 +1,173 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "precomp.h"
+
+namespace Dml
+{
+// DynamicQuantizeMatMul = MatrixMultiplyIntegerToFloat(DynamicQuantizeLinear(A), B)
+class DmlOperatorDynamicQuantizeMatMul : public DmlOperator
+{
+    // This order matches the ONNX schema.
+    enum OnnxInputIndex
+    {
+        A, // Input
+        B,
+        B_scale,
+        B_zero_point,
+        Bias,
+        Count,
+    };
+
+public:
+    DmlOperatorDynamicQuantizeMatMul(const MLOperatorKernelCreationContext& kernelCreationContext)
+    :   DmlOperator(kernelCreationContext)
+    {
+        DmlOperator::Initialize(kernelCreationContext);
+
+        const bool hasBias = kernelCreationContext.IsInputValid(OnnxInputIndex::Bias);
+        const bool hasBZP = kernelCreationContext.IsInputValid(OnnxInputIndex::B_zero_point);
+
+        // Broadcast Bias tensor to the shape of the output tensor.
+        if (hasBias)
+        {
+            m_inputTensorDescs[OnnxInputIndex::Bias] = CreateTensorDescFromInput(
+                kernelCreationContext,
+                OnnxInputIndex::Bias,
+                TensorAxis::DoNotCoerce,
+                TensorAxis::W,
+                TensorAxis::RightAligned,
+                kernelCreationContext.GetTensorShapeDescription().GetOutputTensorShape(0)
+            );
+        }
+        MLOperatorTensorDataType BDatatype = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::B).tensorDataType;
+
+        std::vector<uint32_t> ATensorShape = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::A);
+        std::vector<uint32_t> ExpectedAScaleTensorShape = {1, 1, 1, 1};
+        std::vector<uint32_t> ExpectedAZeroPointTensorShape = {1, 1, 1, 1};
+
+        // Tensor descs for the intermediate edges between the DynamicQuantizeLinear and MatrixMultiplyIntegerToFloat nodes
+        TensorDesc intermediateQuantizedATensorDesc = TensorDesc(
+                BDatatype,
+                gsl::make_span(ATensorShape),
+                gsl::make_span(ATensorShape),
+                TensorAxis::DoNotCoerce,
+                TensorAxis::W,
+                TensorAxis::RightAligned,
+                NchwDimensionCount,  // minDimensionCount
+                0  // guaranteedBaseOffsetAlignment
+            );
+
+        TensorDesc intermediateQuantizedAScaleTensorDesc = TensorDesc(
+                MLOperatorTensorDataType::Float,
+                gsl::make_span(ExpectedAScaleTensorShape),
+                gsl::make_span(ExpectedAScaleTensorShape),
+                TensorAxis::DoNotCoerce,
+                TensorAxis::W,
+                TensorAxis::RightAligned,
+                NchwDimensionCount,  // minDimensionCount
+                0  // guaranteedBaseOffsetAlignment
+            );
+
+        TensorDesc intermediateQuantizedAZeroPointTensorDesc = TensorDesc(
+                BDatatype,
+                gsl::make_span(ExpectedAZeroPointTensorShape),
+                gsl::make_span(ExpectedAZeroPointTensorShape),
+                TensorAxis::DoNotCoerce,
+                TensorAxis::W,
+                TensorAxis::RightAligned,
+                NchwDimensionCount,  // minDimensionCount
+                0  // guaranteedBaseOffsetAlignment
+            );
+
+        DML_TENSOR_DESC namedIntermediateQuantizedATensorDesc = intermediateQuantizedATensorDesc.GetDmlDesc();
+        DML_TENSOR_DESC namedIntermediateQuantizedAScaleTensorDesc = intermediateQuantizedAScaleTensorDesc.GetDmlDesc();
+        DML_TENSOR_DESC namedIntermediateQuantizedAZeroPointTensorDesc = intermediateQuantizedAZeroPointTensorDesc.GetDmlDesc();
+
+        std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
+        std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
+
+        DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC dynamicQuantizeLinearOperatorDesc = {};
+        dynamicQuantizeLinearOperatorDesc.InputTensor = &inputDescs[OnnxInputIndex::A];
+        dynamicQuantizeLinearOperatorDesc.OutputTensor = &namedIntermediateQuantizedATensorDesc;
+        dynamicQuantizeLinearOperatorDesc.OutputScaleTensor = &namedIntermediateQuantizedAScaleTensorDesc;
+        dynamicQuantizeLinearOperatorDesc.OutputZeroPointTensor = &namedIntermediateQuantizedAZeroPointTensorDesc;
+
+        const DML_OPERATOR_DESC opDesc1{DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR, &dynamicQuantizeLinearOperatorDesc};
+
+        DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC matrixMultiplyIntegerToFloatOperatorDesc = {};
+        matrixMultiplyIntegerToFloatOperatorDesc.ATensor = dynamicQuantizeLinearOperatorDesc.OutputTensor;
+        matrixMultiplyIntegerToFloatOperatorDesc.AScaleTensor = dynamicQuantizeLinearOperatorDesc.OutputScaleTensor;
+        matrixMultiplyIntegerToFloatOperatorDesc.AZeroPointTensor = dynamicQuantizeLinearOperatorDesc.OutputZeroPointTensor;
+        matrixMultiplyIntegerToFloatOperatorDesc.BTensor = &inputDescs[OnnxInputIndex::B];
+        matrixMultiplyIntegerToFloatOperatorDesc.BScaleTensor = &inputDescs[OnnxInputIndex::B_scale];
+        matrixMultiplyIntegerToFloatOperatorDesc.BZeroPointTensor = hasBZP ? &inputDescs[OnnxInputIndex::B_zero_point] : nullptr;
+        matrixMultiplyIntegerToFloatOperatorDesc.BiasTensor = hasBias ? &inputDescs[OnnxInputIndex::Bias] : nullptr;
+        matrixMultiplyIntegerToFloatOperatorDesc.OutputTensor = &outputDescs[0];
+
+        const DML_OPERATOR_DESC opDesc2{DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, &matrixMultiplyIntegerToFloatOperatorDesc};
+
+        MLOperatorGraphDesc operatorGraphDesc = {};
+        std::vector<const DML_OPERATOR_DESC*> opDescs{&opDesc1, &opDesc2};
+        operatorGraphDesc.nodeCount = static_cast<uint32_t>(opDescs.size());
+        operatorGraphDesc.nodes = opDescs.data();
+
+        // set input edges
+        std::pair<uint32_t, uint32_t> nodeToNodeInputIndex[OnnxInputIndex::Count] {{0, 0}, {1, 3}, {1, 4}, {1, 5}, {1, 6}};
+        std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
+        for (uint32_t inputIndex = 0; inputIndex < OnnxInputIndex::Count; inputIndex++)
+        {
+            if (inputIndex == OnnxInputIndex::B_zero_point && !hasBZP) continue;
+            if (inputIndex == OnnxInputIndex::Bias && !hasBias) continue;
+            DML_INPUT_GRAPH_EDGE_DESC inputEdge = {};
+            inputEdge.GraphInputIndex = inputIndex;  // OnnxInputIndex maps 1:1 to GraphInputIndex for DynamicQuantizeMatMul
+            inputEdge.ToNodeIndex = nodeToNodeInputIndex[inputIndex].first;
+            inputEdge.ToNodeInputIndex = nodeToNodeInputIndex[inputIndex].second;
+            inputEdges.push_back(inputEdge);
+        }
+        operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());
+        operatorGraphDesc.inputEdges = inputEdges.data();
+
+        // set intermediate edges
+        std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
+
+        DML_INTERMEDIATE_GRAPH_EDGE_DESC dynQLToMMItofloatEdge1 = {};
+        dynQLToMMItofloatEdge1.FromNodeIndex = 0;
+        dynQLToMMItofloatEdge1.FromNodeOutputIndex = 0;
+        dynQLToMMItofloatEdge1.ToNodeIndex = 1;
+        dynQLToMMItofloatEdge1.ToNodeInputIndex = 0;
+        intermediateEdges.push_back(dynQLToMMItofloatEdge1);
+
+        DML_INTERMEDIATE_GRAPH_EDGE_DESC dynQLToMMItofloatEdge2 = {};
+        dynQLToMMItofloatEdge2.FromNodeIndex = 0;
+        dynQLToMMItofloatEdge2.FromNodeOutputIndex = 1;
+        dynQLToMMItofloatEdge2.ToNodeIndex = 1;
+        dynQLToMMItofloatEdge2.ToNodeInputIndex = 1;
+        intermediateEdges.push_back(dynQLToMMItofloatEdge2);
+
+        DML_INTERMEDIATE_GRAPH_EDGE_DESC dynQLToMMItofloatEdge3 = {};
+        dynQLToMMItofloatEdge3.FromNodeIndex = 0;
+        dynQLToMMItofloatEdge3.FromNodeOutputIndex = 2;
+        dynQLToMMItofloatEdge3.ToNodeIndex = 1;
+        dynQLToMMItofloatEdge3.ToNodeInputIndex = 2;
+        intermediateEdges.push_back(dynQLToMMItofloatEdge3);
+
+        operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast<uint32_t>(intermediateEdges.size());
+        operatorGraphDesc.intermediateEdges = intermediateEdges.data();
+
+        // set the output edges
+        std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
+        DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {};
+        outputEdge.FromNodeIndex = 1;
+        outputEdge.FromNodeOutputIndex = 0;
+        outputEdge.GraphOutputIndex = 0;
+        outputEdges.push_back(outputEdge);
+        operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
+        operatorGraphDesc.outputEdges = outputEdges.data();
+
+        SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
+    }
+};
+
+DML_OP_DEFINE_CREATION_FUNCTION(DynamicQuantizeMatMul, DmlOperatorDynamicQuantizeMatMul);
+}  // namespace Dml
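
The fused graph above feeds the scale and zero point produced by DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR straight into DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, so A is quantized on the fly and the product comes back in float. Below is a deliberately simplified, self-contained numeric sketch of that idea; it skips saturation and the exact ONNX/DML rounding rules and uses a 1-D dot product, purely to show why the first node's three outputs are wired to the second node's A, AScale, and AZeroPoint inputs.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Simplified dynamic uint8 quantization (range expanded to include zero).
struct Quantized
{
    std::vector<uint8_t> values;
    float scale;
    uint8_t zeroPoint;
};

Quantized DynamicQuantize(const std::vector<float>& x)
{
    float lo = std::min(0.0f, *std::min_element(x.begin(), x.end()));
    float hi = std::max(0.0f, *std::max_element(x.begin(), x.end()));
    float scale = (hi - lo) / 255.0f;
    uint8_t zeroPoint = static_cast<uint8_t>(std::lround(-lo / scale));

    Quantized q{{}, scale, zeroPoint};
    for (float v : x)
    {
        q.values.push_back(static_cast<uint8_t>(std::lround(v / scale) + zeroPoint));
    }
    return q;
}

int main()
{
    std::vector<float> a = {0.5f, -1.25f, 2.0f};   // row of A, quantized dynamically at runtime
    Quantized aq = DynamicQuantize(a);

    std::vector<uint8_t> bq = {10, 20, 30};        // column of B, already quantized offline
    float bScale = 0.1f;                           // so dequantized B is {1.0, 2.0, 3.0}
    int32_t bZeroPoint = 0;

    // MatMulIntegerToFloat: accumulate in int32, then rescale back to float.
    int32_t acc = 0;
    for (size_t i = 0; i < a.size(); ++i)
    {
        acc += (static_cast<int32_t>(aq.values[i]) - aq.zeroPoint) *
               (static_cast<int32_t>(bq[i]) - bZeroPoint);
    }
    float result = static_cast<float>(acc) * aq.scale * bScale;

    float reference = 0.5f * 1.0f + (-1.25f) * 2.0f + 2.0f * 3.0f;
    std::printf("quantized %.3f vs float reference %.3f\n", result, reference);
    return 0;
}
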
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEmbedLayerNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEmbedLayerNormalization.cpp
index 6a8333cd72561..3c9458658c4d0 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEmbedLayerNormalization.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEmbedLayerNormalization.cpp
@@ -484,7 +484,7 @@ class DmlOperatorEmbedLayerNormalization : public DmlOperator
         operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
         operatorGraphDesc.outputEdges = outputEdges.data();
         operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
-        operatorGraphDesc.nodesAsOpDesc = opDescs.data();
+        operatorGraphDesc.nodes = opDescs.data();
 
         SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
     }
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupNorm.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupNorm.cpp
index fed0e4645ffd8..8b275fc550f3e 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupNorm.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupNorm.cpp
@@ -287,7 +287,7 @@ class DmlOperatorGroupNorm : public DmlOperator
         operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
         operatorGraphDesc.outputEdges = outputEdges.data();
         operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
-        operatorGraphDesc.nodesAsOpDesc = opDescs.data();
+        operatorGraphDesc.nodes = opDescs.data();
         SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
     }
 };
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp
index 5c64059f7caa9..80e6fefc2fb80 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp
@@ -247,7 +247,7 @@ class DmlOperatorLayerNormalization : public DmlOperator
         operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
         operatorGraphDesc.outputEdges = outputEdges.data();
         operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
-        operatorGraphDesc.nodesAsOpDesc = opDescs.data();
+        operatorGraphDesc.nodes = opDescs.data();
 
         SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
     }
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMatMulIntegerToFloat.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMatMulIntegerToFloat.cpp
new file mode 100644
index 0000000000000..b5a3dd0960b86
--- /dev/null
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMatMulIntegerToFloat.cpp
@@ -0,0 +1,111 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "precomp.h"
+
+namespace Dml
+{
+
+class DmlOperatorMatMulIntegerToFloat : public DmlOperator
+{
+    enum OrtInputTensors : uint32_t
+    {
+        ortA,
+        ortB,
+        ortAScale,
+        ortBScale,
+        ortAZeroPoint,
+        ortBZeroPoint,
+        ortBias,
+        ortInputCount
+    };
+
+    enum DmlInputIndex : uint32_t
+    {
+        dmlA,
+        dmlAScale,
+        dmlAZeroPoint,
+        dmlB,
+        dmlBScale,
+        dmlBZeroPoint,
+        dmlBias,
+        dmlInputCount,
+    };
+
+public:
+    DmlOperatorMatMulIntegerToFloat(const MLOperatorKernelCreationContext& kernelInfo)
+        :   DmlOperator(kernelInfo)
+    {
+        std::vector<std::optional<uint32_t>> inputIndices = { OrtInputTensors::ortA, OrtInputTensors::ortAScale, OrtInputTensors::ortAZeroPoint, OrtInputTensors::ortB, OrtInputTensors::ortBScale, OrtInputTensors::ortBZeroPoint, OrtInputTensors::ortBias };
+        DmlOperator::Initialize(kernelInfo, inputIndices);
+
+        std::vector<DimensionType> inputShape0 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(OrtInputTensors::ortA);
+        std::vector<DimensionType> inputShape1 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(OrtInputTensors::ortB);
+        std::vector<DimensionType> outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
+
+        OperatorHelper::MatMulShapeMapping(inputShape0, inputShape1, outputShape);
+
+        // Initialize the input descriptions with broadcasting
+        m_inputTensorDescs[DmlInputIndex::dmlA] = CreateTensorDescFromInput(kernelInfo, OrtInputTensors::ortA, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, inputShape0);
+        m_inputTensorDescs[DmlInputIndex::dmlB] = CreateTensorDescFromInput(kernelInfo, OrtInputTensors::ortB, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, inputShape1);
+
+        // Broadcast Bias tensor to the shape of the output tensor.
+        if (kernelInfo.IsInputValid(OrtInputTensors::ortBias))
+        {
+            m_inputTensorDescs[DmlInputIndex::dmlBias] = CreateTensorDescFromInput(kernelInfo, OrtInputTensors::ortBias, TensorAxis::DoNotCoerce,
+                TensorAxis::W, TensorAxis::RightAligned, outputShape);
+        }
+
+        uint32_t dmlDimSize = m_inputTensorDescs[DmlInputIndex::dmlA].GetDimensionCount();
+        // Resize the A Scale to be the same dimension as the input tensor.
+        // The 1D tensor needs to be moved to the H channel.
+        m_inputTensorDescs[DmlInputIndex::dmlAScale] = CreateTensorDescFromInput(
+            kernelInfo,
+            OrtInputTensors::ortAScale,
+            TensorAxis::DoNotCoerce,
+            TensorAxis::H,
+            TensorAxis::LeftAligned,
+            std::nullopt,
+            dmlDimSize
+            );
+
+        // Resize the A ZeroPoint to be the same dimension as the input tensor.
+        // The 1D tensor needs to be moved to the H channel.
+        if (kernelInfo.IsInputValid(OrtInputTensors::ortAZeroPoint))
+        {
+            m_inputTensorDescs[DmlInputIndex::dmlAZeroPoint] = CreateTensorDescFromInput(
+                kernelInfo,
+                OrtInputTensors::ortAZeroPoint,
+                TensorAxis::DoNotCoerce,
+                TensorAxis::H,
+                TensorAxis::LeftAligned,
+                std::nullopt,
+                dmlDimSize
+                );
+        }
+
+        // BZeroPoint and BScale are already aligned in the W dimension, so no realignment is needed.
+
+        // Initialize the output description while overriding the shape
+        m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape);
+
+        std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
+        std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
+
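+        // Conceptually, the fused operator dequantizes both inputs and then multiplies them, roughly
+        //     Output = ((A - AZeroPoint) * AScale) x ((B - BZeroPoint) * BScale) + Bias,
+        // with the zero-point and bias terms optional, as wired up below.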
+        DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC matMulDesc = {};
+        matMulDesc.ATensor = &inputDescs[DmlInputIndex::dmlA];
+        matMulDesc.AScaleTensor = &inputDescs[DmlInputIndex::dmlAScale];
+        matMulDesc.AZeroPointTensor = inputDescs[DmlInputIndex::dmlAZeroPoint].Desc != nullptr ? &inputDescs[DmlInputIndex::dmlAZeroPoint] : nullptr;
+        matMulDesc.BTensor = &inputDescs[DmlInputIndex::dmlB];
+        matMulDesc.BScaleTensor = &inputDescs[DmlInputIndex::dmlBScale];
+        matMulDesc.BZeroPointTensor = inputDescs[DmlInputIndex::dmlBZeroPoint].Desc != nullptr ? &inputDescs[DmlInputIndex::dmlBZeroPoint] : nullptr;
+        matMulDesc.BiasTensor = inputDescs[DmlInputIndex::dmlBias].Desc != nullptr ? &inputDescs[DmlInputIndex::dmlBias] : nullptr;
+        matMulDesc.OutputTensor = &outputDescs[0];
+
+        DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, &matMulDesc };
+        SetDmlOperatorDesc(opDesc, kernelInfo);
+    }
+};
+
+DML_OP_DEFINE_CREATION_FUNCTION(MatMulIntegerToFloat, DmlOperatorMatMulIntegerToFloat);
+
+}  // namespace Dml
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp
new file mode 100644
index 0000000000000..f9519b26bb4e3
--- /dev/null
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp
@@ -0,0 +1,704 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "precomp.h"
+
+/*
+Abbreviations: B is batch_size, S is sequence_length, W is hidden_size
+               N is number of attention heads, H is head size, and W=N*H
+
+Input, Weight, Bias, Mask Index and Past are the kernel inputs.
+
+Mask Index/Causal  Input   Weight   Bias
+         |             \    |       /
+         |              \   |      /
+         |               \  |     /
+         |             MatMulIntToFloat
+         |                / |   \
+         |               /  |    \
+         |              /   |     \
+         |          Slice  Slice  Slice
+         |            |     |       |
+         |            |     |       |
+         |      Identity Identity Identity // The identities are used to transpose NCHW -> NHCW while
+         |            |     |       |      // keeping the GEMM strides as NCHW to better target metacommands
+         |            |     |       |
+         |            |     |       |       Past
+         |            |     |       |       / \
+         |            |     |       |      /   \
+         |            |     |       |  Slice   Slice
+         |            |     |       |     |      |
+         |            |     |       |     |      |
+         |            |     |       |     |      |
+         --------------------------MHA -----------
+                                  / | \
+                                 /  |   \
+                                /   |     \
+                               /    |       \
+                              /     |         \
+                             /      |           \
+                            /  presentKey   presentValue
+                           /         \       /
+                          /           \     /
+                         /             \   /
+                        /             Concat
+                       /                 |
+                   Output1            Output2 (present)
+
+ This kernel builds the DML_GRAPH shown above.
+ For reference, see:
+ https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftqattention
+ */
+
+namespace Dml
+{
+class DmlOperatorQAttention : public DmlOperator
+{
+public:
+    DmlOperatorQAttention(const MLOperatorKernelCreationContext& kernelCreationContext)
+    :   DmlOperator(kernelCreationContext)
+    {
+        enum InputIndex : uint32_t
+        {
+            inputIndex,
+            weightsIndex,
+            biasIndex,
+            inputScaleIndex,
+            weightScaleIndex,
+            maskIndex,
+            inputZeroPointIndex,
+            weightZeroPointIndex,
+            pastIndex,
+            inputCount,
+        };
+
+        enum OutputIndex : uint32_t
+        {
+            outputIndex,
+            presentIndex,
+            outputCount,
+        };
+
+        enum MhaInputIndex : uint32_t
+        {
+            mhaQueryIndex,
+            mhaKeyIndex,
+            mhaValueIndex,
+            mhaStackedQueryKeyIndex,
+            mhaStackedKeyValueIndex,
+            mhaStackedQueryKeyValueIndex,
+            mhaBiasIndex,
+            mhaMaskIndex,
+            mhaRelativePositionBiasIndex,
+            mhaPastKeyIndex,
+            mhaPastValueIndex,
+            mhaInputCount,
+        };
+
+        enum MhaOutputIndex : uint32_t
+        {
+            mhaOutputIndex,
+            mhaPresentKeyIndex,
+            mhaPresentValueIndex,
+            mhaOutputCount,
+        };
+
+        ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() >= 5);
+        ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() >= 1);
+
+        const bool hasBias = kernelCreationContext.IsInputValid(biasIndex);
+        const bool hasMask = kernelCreationContext.IsInputValid(maskIndex);
+        const bool hasUnpaddedBounds = hasMask && kernelCreationContext.GetInputTensorDimensionCount(maskIndex) == 1;
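+        // A 1-D mask tensor holds per-batch key sequence lengths (and, when it has 2 * batchSize entries,
+        // end/start positions) rather than a full boolean attention mask; both layouts are handled below.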
+        const bool hasPast = kernelCreationContext.IsInputValid(pastIndex);
+
+        DmlOperator::Initialize(kernelCreationContext, std::nullopt, std::nullopt, std::nullopt, std::nullopt, 1);
+
+        const bool unidirectional = gsl::narrow_cast<uint32_t>(kernelCreationContext.GetAttribute<int64_t>(AttrName::Unidirectional));
+        const uint32_t numHeads = gsl::narrow_cast<uint32_t>(kernelCreationContext.GetAttribute<int64_t>(AttrName::NumHeads));
+        ML_CHECK_VALID_ARGUMENT(numHeads > 0); // Avoid a process crash from the division by zero below.
+
+        auto inputTensorShape = m_inputTensorDescs[inputIndex].GetSizes();
+        ML_CHECK_VALID_ARGUMENT(inputTensorShape.size() == 3);
+
+        auto weightTensorShape = m_inputTensorDescs[weightsIndex].GetSizes();
+        ML_CHECK_VALID_ARGUMENT(weightTensorShape.size() == 2);
+        ML_CHECK_VALID_ARGUMENT(weightTensorShape[0] == inputTensorShape[2]);
+        ML_CHECK_VALID_ARGUMENT(weightTensorShape[1] % 3 == 0);
+
+        if (hasBias)
+        {
+            auto biasTensorShape = m_inputTensorDescs[biasIndex].GetSizes();
+            ML_CHECK_VALID_ARGUMENT(biasTensorShape.size() == 1);
+            ML_CHECK_VALID_ARGUMENT(biasTensorShape[0] % 3 == 0);
+            ML_CHECK_VALID_ARGUMENT(weightTensorShape[1] == biasTensorShape[0]);
+        }
+
+        if (hasPast)
+        {
+            ML_CHECK_VALID_ARGUMENT(kernelCreationContext.IsOutputValid(presentIndex));
+        }
+
+        const uint32_t hiddenSize = weightTensorShape[1] / 3;
+        const uint32_t headSize = hiddenSize / numHeads;
+        const uint32_t batchSize = inputTensorShape[0];
+        const uint32_t sequenceLength = inputTensorShape[1];
+        const uint32_t pastSequenceLength = hasPast ? m_inputTensorDescs[pastIndex].GetSizes()[3] : 0;
+
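+        // Hypothetical example for orientation: with hidden_size = 768 and num_heads = 12, the packed
+        // weight is [768, 2304] (Q, K and V stacked along the last dimension), so hiddenSize = 768 and headSize = 64.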
+        uint32_t desiredWeightTensorShape[3] = {batchSize, weightTensorShape[0], 3 * hiddenSize};
+        MLOperatorTensorDataType dataType = kernelCreationContext.GetOutputEdgeDescription(outputIndex).tensorDataType;
+
+        m_inputTensorDescs[weightsIndex] = TensorDesc::ConstructBroadcastedTensorDesc(
+            kernelCreationContext.GetInputEdgeDescription(weightsIndex).tensorDataType,
+            desiredWeightTensorShape,
+            weightTensorShape);
+
+        uint32_t desiredBiasTensorShape[3] = {batchSize, sequenceLength, 3 * hiddenSize};
+
+        if (hasBias)
+        {
+            auto biasTensorShape = m_inputTensorDescs[biasIndex].GetSizes();
+            m_inputTensorDescs[biasIndex] = TensorDesc::ConstructBroadcastedTensorDesc(kernelCreationContext.GetInputEdgeDescription(biasIndex).tensorDataType, desiredBiasTensorShape, biasTensorShape);
+        }
+
+        MLOperatorTensorDataType maskTensorDataType = MLOperatorTensorDataType::Undefined;
+        bool hasMaxSequenceMask = false;
+        DML_MULTIHEAD_ATTENTION_MASK_TYPE maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE;
+        if (hasMask)
+        {
+            if (hasUnpaddedBounds)
+            {
+                auto unpaddedKeyBoundsShape = m_inputTensorDescs[maskIndex].GetSizes();
+                ML_CHECK_VALID_ARGUMENT(unpaddedKeyBoundsShape.size() == 1);
+
+                const uint32_t batchGroupCount = unpaddedKeyBoundsShape[0] / batchSize;
+                ML_CHECK_VALID_ARGUMENT(batchGroupCount == 1 || batchGroupCount == 2);
+
+                uint32_t desiredShape[2] = {batchGroupCount, batchSize};
+                m_inputTensorDescs[maskIndex] = TensorDesc(
+                    m_inputTensorDescs[maskIndex].GetDmlDataType(),
+                    desiredShape);
+
+                maskType = batchGroupCount == 1
+                    ? DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH
+                    : DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START;
+            }
+            else
+            {
+                auto maskIndexTensorShape = m_inputTensorDescs[maskIndex].GetSizes();
+                ML_CHECK_VALID_ARGUMENT(maskIndexTensorShape.size() > 1 && maskIndexTensorShape.size() <= 4);
+
+                maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN;
+                std::vector<uint32_t> reshapedMaskIndexTensorShape(maskIndexTensorShape.begin(), maskIndexTensorShape.end());
+                if (maskIndexTensorShape.size() == 4 && maskIndexTensorShape[2] != sequenceLength)
+                {
+                    hasMaxSequenceMask = true;
+                    ML_CHECK_VALID_ARGUMENT(maskIndexTensorShape[2] == maskIndexTensorShape[3]);
+                    const uint32_t maxSequenceLength = maskIndexTensorShape[2];
+                    uint32_t desiredMaskIndexShape[4] = {batchSize, numHeads, maxSequenceLength, maxSequenceLength};
+                    maskTensorDataType = kernelCreationContext.GetInputEdgeDescription(maskIndex).tensorDataType;
+                    m_inputTensorDescs[maskIndex] = TensorDesc::ConstructBroadcastedTensorDesc(maskTensorDataType, desiredMaskIndexShape, reshapedMaskIndexTensorShape);
+                }
+                else
+                {
+                    uint32_t maskIndexDimensionCount = gsl::narrow_cast<uint32_t>(maskIndexTensorShape.size());
+                    reshapedMaskIndexTensorShape.insert(reshapedMaskIndexTensorShape.begin() + 1, 4 - maskIndexDimensionCount, 1);
+                    uint32_t desiredMaskIndexShape[4] = {batchSize, numHeads, sequenceLength, sequenceLength};
+                    maskTensorDataType = kernelCreationContext.GetInputEdgeDescription(maskIndex).tensorDataType;
+                    m_inputTensorDescs[maskIndex] = TensorDesc::ConstructBroadcastedTensorDesc(maskTensorDataType, desiredMaskIndexShape, reshapedMaskIndexTensorShape);
+                }
+            }
+        }
+
+        MLOperatorTensorDataType pastTensorDataType = MLOperatorTensorDataType::Undefined;
+        MLOperatorTensorDataType presentTensorDataType = MLOperatorTensorDataType::Undefined;
+        if (hasPast)
+        {
+            pastTensorDataType = kernelCreationContext.GetInputEdgeDescription(pastIndex).tensorDataType;
+            presentTensorDataType = kernelCreationContext.GetOutputEdgeDescription(presentIndex).tensorDataType;
+        }
+
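+        // The MatMulIntegerToFloat output is the dequantized, packed QKV tensor of shape
+        // [batchSize, sequenceLength, 3 * hiddenSize] (the same shape the broadcast bias uses).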
+        TensorDesc matMulIntToFloatOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, desiredBiasTensorShape);
+        DML_TENSOR_DESC namedMatMulIntToFloatOutputTensorDesc = matMulIntToFloatOutputTensorDesc.GetDmlDesc();
+
+        std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
+        std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
+
+        DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC matMulIntToFloatOperatorDesc = {};
+        matMulIntToFloatOperatorDesc.ATensor = &inputDescs[InputIndex::inputIndex];
+        matMulIntToFloatOperatorDesc.AScaleTensor = &inputDescs[InputIndex::inputScaleIndex];
+        matMulIntToFloatOperatorDesc.AZeroPointTensor = &inputDescs[InputIndex::inputZeroPointIndex];
+        matMulIntToFloatOperatorDesc.BTensor = &inputDescs[InputIndex::weightsIndex];
+        matMulIntToFloatOperatorDesc.BScaleTensor = &inputDescs[InputIndex::weightScaleIndex];
+        matMulIntToFloatOperatorDesc.BZeroPointTensor = &inputDescs[InputIndex::weightZeroPointIndex];
+        matMulIntToFloatOperatorDesc.BiasTensor = hasBias ? &inputDescs[InputIndex::biasIndex] : nullptr;
+        matMulIntToFloatOperatorDesc.OutputTensor = &namedMatMulIntToFloatOutputTensorDesc;
+
+        const DML_OPERATOR_DESC matMulIntToFloatDesc = { DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, &matMulIntToFloatOperatorDesc};
+
+        std::array<uint32_t, 3> queryKeySlicedTensorShape = {batchSize, sequenceLength, hiddenSize + hiddenSize};
+        TensorDesc queryKeySlicedInputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, queryKeySlicedTensorShape);
+        DML_TENSOR_DESC namedQueryKeySlicedInputTensorDesc = queryKeySlicedInputTensorDesc.GetDmlDesc();
+
+        std::array<uint32_t, 3> valueSlicedTensorShape = {batchSize, sequenceLength, hiddenSize};
+        TensorDesc valueSlicedInputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, valueSlicedTensorShape);
+        DML_TENSOR_DESC namedValueSlicedInputTensorDesc = valueSlicedInputTensorDesc.GetDmlDesc();
+
+        // Transpose slice QK from [batchSize, sequenceLength, 2, numHeads, headSize] to [batchSize, sequenceLength, numHeads, 2, headSize]
+        std::array<uint32_t, 5> queryKeyTransposedTensorShape = {batchSize, sequenceLength, numHeads, 2, headSize};
+        std::array<uint32_t, 5> queryKeyTransposedStrides = {
+            sequenceLength * numHeads * 2 * headSize,
+            numHeads * 2 * headSize,
+            headSize,
+            numHeads * headSize,
+            1,
+        };
+
+        TensorDesc queryKeyTransposedInputTensorDesc = TensorDesc(
+            GetDmlDataTypeFromMlDataType(dataType),
+            queryKeyTransposedTensorShape,
+            queryKeyTransposedStrides);
+        DML_TENSOR_DESC namedQueryKeyTransposedInputTensorDesc = queryKeyTransposedInputTensorDesc.GetDmlDesc();
+
+        TensorDesc queryKeyTransposedOutputTensorDesc = TensorDesc(
+            GetDmlDataTypeFromMlDataType(dataType),
+            queryKeyTransposedTensorShape);
+        DML_TENSOR_DESC namedQueryKeyTransposedOutputTensorDesc = queryKeyTransposedOutputTensorDesc.GetDmlDesc();
+
+        // Transpose QKV from [batchSize, sequenceLength, 3, numHeads, headSize] to [batchSize, sequenceLength, numHeads, 3, headSize]
+        std::array<uint32_t, 5> queryKeyValueTransposedTensorShape = {batchSize, sequenceLength, numHeads, 3, headSize};
+        std::array<uint32_t, 5> queryKeyValueTransposedStrides = {
+            sequenceLength * numHeads * 3 * headSize,
+            numHeads * 3 * headSize,
+            headSize,
+            numHeads * headSize,
+            1,
+        };
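+        // With these strides, element (b, s, n, x, h) of the [B, S, N, 3, H] view reads offset
+        // b*S*N*3*H + s*N*3*H + n*H + x*N*H + h, which is element (b, s, x, n, h) of the contiguous
+        // [B, S, 3, N, H] MatMulIntegerToFloat output, so the transpose is expressed as a strided view
+        // that the identity node below materializes.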
+
+        TensorDesc queryKeyValueTransposedInputTensorDesc = TensorDesc(
+            GetDmlDataTypeFromMlDataType(dataType),
+            queryKeyValueTransposedTensorShape,
+            queryKeyValueTransposedStrides);
+        DML_TENSOR_DESC namedQueryKeyValueTransposedInputTensorDesc = queryKeyValueTransposedInputTensorDesc.GetDmlDesc();
+
+        TensorDesc queryKeyValueTransposedOutputTensorDesc = TensorDesc(
+            GetDmlDataTypeFromMlDataType(dataType),
+            queryKeyValueTransposedTensorShape);
+        DML_TENSOR_DESC namedQueryKeyValueTransposedOutputTensorDesc = queryKeyValueTransposedOutputTensorDesc.GetDmlDesc();
+
+        std::array<uint32_t, 3> queryKeySliceOffset = {0, 0, 0};
+        std::array<uint32_t, 3> queryKeySliceSize = {batchSize, sequenceLength, hiddenSize + hiddenSize};
+        std::array<int32_t, 3> queryKeySliceStrides = {1, 1, 1};
+
+        std::array<uint32_t, 3> valueSliceOffset = {0, 0, 2 * hiddenSize};
+        std::array<uint32_t, 3> valueSliceSize = {batchSize, sequenceLength, hiddenSize};
+        std::array<int32_t, 3> valueSliceStrides = {1, 1, 1};
+
+        // When Q/K/V all share the same hidden size, the stacked QKV tensor only needs a transpose before it is sent to MHA.
+        DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC transposeOperatorDesc = {};
+
+        transposeOperatorDesc.InputTensor = &namedQueryKeyValueTransposedInputTensorDesc;
+        transposeOperatorDesc.OutputTensor = &namedQueryKeyValueTransposedOutputTensorDesc;
+
+        const DML_OPERATOR_DESC transposedDesc = { DML_OPERATOR_ELEMENT_WISE_IDENTITY, &transposeOperatorDesc};
+
+        std::array<uint32_t, 4> maskSliceOutputShape = {batchSize, numHeads, sequenceLength, sequenceLength};
+        std::array<int32_t, 4> maskSliceStrides = {1, 1, 1, 1};
+        std::array<uint32_t, 4> maskSliceOffsets = {0, 0, 0, 0};
+        TensorDesc maskSliceOutputTensorDesc;
+        DML_TENSOR_DESC namedMaskSliceOutputTensorDesc;
+
+        DML_SLICE1_OPERATOR_DESC maskSlicedOperatorDesc = {};
+        if (hasMaxSequenceMask)
+        {
+            maskSliceOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(maskTensorDataType, maskSliceOutputShape);
+            namedMaskSliceOutputTensorDesc = maskSliceOutputTensorDesc.GetDmlDesc();
+            maskSlicedOperatorDesc.InputTensor = &inputDescs[maskIndex];
+            maskSlicedOperatorDesc.OutputTensor = &namedMaskSliceOutputTensorDesc;
+            maskSlicedOperatorDesc.DimensionCount = gsl::narrow_cast<uint32_t>(maskSliceOutputShape.size());
+            maskSlicedOperatorDesc.InputWindowOffsets = maskSliceOffsets.data();
+            maskSlicedOperatorDesc.InputWindowSizes = maskSliceOutputShape.data();
+            maskSlicedOperatorDesc.InputWindowStrides = maskSliceStrides.data();
+        }
+        const DML_OPERATOR_DESC maskSlicedDesc = { DML_OPERATOR_SLICE1, &maskSlicedOperatorDesc};
+
+        // We need to slice Past to get PastValue and PastKey tensors for MHA
+        std::array<uint32_t, 5> pastKeyOutputShape = {1, batchSize, numHeads, pastSequenceLength, headSize};
+        std::array<int32_t, 5> pastKeyStrides = {1, 1, 1, 1, 1};
+        std::array<uint32_t, 5> pastKeyOffsets = {0, 0, 0, 0, 0};
+        TensorDesc pastKeyOutputTensorDesc;
+        DML_TENSOR_DESC namedPastKeyOutputTensorDesc;
+
+        std::array<uint32_t, 5> pastValueOutputShape = {1, batchSize, numHeads, pastSequenceLength, headSize};
+        std::array<int32_t, 5> pastValueStrides = {1, 1, 1, 1, 1};
+        std::array<uint32_t, 5> pastValueOffsets = {1, 0, 0, 0, 0};
+        TensorDesc pastValueOutputTensorDesc;
+        DML_TENSOR_DESC namedPastValueOutputTensorDesc;
+
+        DML_SLICE1_OPERATOR_DESC pastKeySlicedOperatorDesc = {};
+        DML_SLICE1_OPERATOR_DESC pastValueSlicedOperatorDesc = {};
+
+        if (hasPast)
+        {
+            pastKeyOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(pastTensorDataType, pastKeyOutputShape);
+            namedPastKeyOutputTensorDesc = pastKeyOutputTensorDesc.GetDmlDesc();
+            pastKeySlicedOperatorDesc.InputTensor = &inputDescs[pastIndex];
+            pastKeySlicedOperatorDesc.OutputTensor = &namedPastKeyOutputTensorDesc;
+            pastKeySlicedOperatorDesc.DimensionCount = gsl::narrow_cast<uint32_t>(pastKeyOutputShape.size());
+            pastKeySlicedOperatorDesc.InputWindowOffsets = pastKeyOffsets.data();
+            pastKeySlicedOperatorDesc.InputWindowSizes = pastKeyOutputShape.data();
+            pastKeySlicedOperatorDesc.InputWindowStrides = pastKeyStrides.data();
+
+            pastValueOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(pastTensorDataType, pastValueOutputShape);
+            namedPastValueOutputTensorDesc = pastValueOutputTensorDesc.GetDmlDesc();
+            pastValueSlicedOperatorDesc.InputTensor = &inputDescs[pastIndex];
+            pastValueSlicedOperatorDesc.OutputTensor = &namedPastValueOutputTensorDesc;
+            pastValueSlicedOperatorDesc.DimensionCount = gsl::narrow_cast<uint32_t>(pastValueOutputShape.size());
+            pastValueSlicedOperatorDesc.InputWindowOffsets = pastValueOffsets.data();
+            pastValueSlicedOperatorDesc.InputWindowSizes = pastValueOutputShape.data();
+            pastValueSlicedOperatorDesc.InputWindowStrides = pastValueStrides.data();
+        }
+
+        const DML_OPERATOR_DESC pastKeySlicedDesc = { DML_OPERATOR_SLICE1, &pastKeySlicedOperatorDesc};
+        const DML_OPERATOR_DESC pastValueSlicedDesc = { DML_OPERATOR_SLICE1, &pastValueSlicedOperatorDesc};
+
+        // Causal mask: lower-triangular Boolean matrix (each query position may attend only to itself and earlier positions).
+        // Example: [[1, 0, 0, 0, 0],
+        //           [1, 1, 0, 0, 0],
+        //           [1, 1, 1, 0, 0],
+        //           [1, 1, 1, 1, 0]]
+        // DML adds maskFilterValue to the "off" bits in the mask and adds 0 to the "on" bits.
+        // The mask is passed to MHA as the mask tensor when unidirectional == 1 and no explicit mask is given.
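+        // Assuming the usual DML_DIAGONAL_MATRIX1 semantics, DiagonalFillBegin = INT32_MIN and
+        // DiagonalFillEnd = pastSequenceLength + 1 (below) set every element with column - row <= pastSequenceLength
+        // to 1, producing the banded lower-triangular pattern once past tokens are accounted for.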
+        std::array<uint32_t, 4> causalMaskOutputShape = {1, 1,  sequenceLength, pastSequenceLength + sequenceLength};
+        TensorDesc causalMaskTensorDesc;
+        DML_DIAGONAL_MATRIX1_OPERATOR_DESC causalMaskOperatorDesc = {};
+        DML_TENSOR_DESC namedcausalMaskTensorDesc;
+
+        if (unidirectional && !hasMask)
+        {
+            causalMaskTensorDesc = TensorDesc::ConstructDefaultTensorDesc(MLOperatorTensorDataType::Int32, causalMaskOutputShape);
+            namedcausalMaskTensorDesc = causalMaskTensorDesc.GetDmlDesc();
+            causalMaskOperatorDesc.ValueDataType = DML_TENSOR_DATA_TYPE_INT32;
+            causalMaskOperatorDesc.DiagonalFillBegin = INT32_MIN;
+            causalMaskOperatorDesc.DiagonalFillEnd = pastSequenceLength + 1;
+            causalMaskOperatorDesc.Value.Int32 = 1;
+            causalMaskOperatorDesc.OutputTensor = &namedcausalMaskTensorDesc;
+            maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN;
+        }
+        DML_OPERATOR_DESC causalMaskDesc = { DML_OPERATOR_DIAGONAL_MATRIX1, &causalMaskOperatorDesc };
+
+        DML_MULTIHEAD_ATTENTION_OPERATOR_DESC mhaOperatorDesc = {};
+        std::array<uint32_t, 5> presentKeyOutputShape = {1, batchSize, numHeads, pastSequenceLength + sequenceLength, headSize};
+        std::array<uint32_t, 5> presentValueOutputShape = {1, batchSize, numHeads, pastSequenceLength + sequenceLength, headSize};
+        TensorDesc presentKeyTensorDesc;
+        TensorDesc presentValueTensorDesc;
+        DML_TENSOR_DESC namedPresentKeyOutputTensorDesc;
+        DML_TENSOR_DESC namedPresentValueOutputTensorDesc;
+
+        mhaOperatorDesc.StackedQueryKeyValueTensor = &namedQueryKeyValueTransposedOutputTensorDesc;
+
+        // Broadcast to MHA MaskTensor Shape
+        std::array<uint32_t, 4> mhaMaskTensorShape = {batchSize, numHeads, sequenceLength, pastSequenceLength + sequenceLength};
+        TensorDesc broadcastedcausalMaskTensorDesc;
+        DML_TENSOR_DESC namedbroadcastedcausalMaskTensorDesc;
+        if (unidirectional && !hasMask)
+        {
+            broadcastedcausalMaskTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(MLOperatorTensorDataType::Int32, mhaMaskTensorShape, causalMaskOutputShape);
+            namedbroadcastedcausalMaskTensorDesc = broadcastedcausalMaskTensorDesc.GetDmlDesc();
+            mhaOperatorDesc.MaskTensor = &namedbroadcastedcausalMaskTensorDesc;
+        }
+        else if (hasMaxSequenceMask)
+        {
+            mhaOperatorDesc.MaskTensor = &namedMaskSliceOutputTensorDesc;
+        }
+        else
+        {
+            mhaOperatorDesc.MaskTensor = hasMask ? &inputDescs[maskIndex] : nullptr;
+        }
+
+        mhaOperatorDesc.RelativePositionBiasTensor = nullptr;
+        mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex];
+        mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Scale, gsl::narrow_cast<float>(1.0f / std::sqrt(headSize)));
+        // For the causal mask, set MaskFilterValue to the lowest representable float.
+        mhaOperatorDesc.MaskFilterValue = unidirectional ? std::numeric_limits<float>::lowest() :
+            kernelCreationContext.GetOptionalAttribute<float>(AttrName::MaskFilterValue, -10'000.0f);
+        mhaOperatorDesc.HeadCount = numHeads;
+        mhaOperatorDesc.MaskType = maskType;
+        if (hasPast)
+        {
+            presentKeyTensorDesc = TensorDesc::ConstructDefaultTensorDesc(presentTensorDataType, presentKeyOutputShape);
+            namedPresentKeyOutputTensorDesc = presentKeyTensorDesc.GetDmlDesc();
+            presentValueTensorDesc = TensorDesc::ConstructDefaultTensorDesc(presentTensorDataType, presentValueOutputShape);
+            namedPresentValueOutputTensorDesc = presentValueTensorDesc.GetDmlDesc();
+            mhaOperatorDesc.PastKeyTensor = &namedPastKeyOutputTensorDesc;
+            mhaOperatorDesc.PastValueTensor = &namedPastValueOutputTensorDesc;
+            mhaOperatorDesc.OutputPresentKeyTensor = &namedPresentKeyOutputTensorDesc;
+            mhaOperatorDesc.OutputPresentValueTensor = &namedPresentValueOutputTensorDesc;
+        }
+
+        const DML_OPERATOR_DESC mhaDesc = { DML_OPERATOR_MULTIHEAD_ATTENTION, &mhaOperatorDesc };
+
+        DML_JOIN_OPERATOR_DESC presentKeyValueJoinOperatorDesc = {};
+        std::vector<DML_TENSOR_DESC> joinInputDesc;
+
+        if (hasPast)
+        {
+            joinInputDesc.push_back(namedPresentKeyOutputTensorDesc);
+            joinInputDesc.push_back(namedPresentValueOutputTensorDesc);
+            presentKeyValueJoinOperatorDesc.InputCount = gsl::narrow_cast<uint32_t>(joinInputDesc.size());
+            presentKeyValueJoinOperatorDesc.InputTensors = joinInputDesc.data();
+            presentKeyValueJoinOperatorDesc.OutputTensor = &outputDescs[presentIndex];
+            presentKeyValueJoinOperatorDesc.Axis = gsl::narrow_cast<uint32_t>(0);
+        }
+
+        DML_OPERATOR_DESC presentKeyValueJoinDesc = { DML_OPERATOR_JOIN, &presentKeyValueJoinOperatorDesc };
+
+        // Construct the graph
+        std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
+        std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
+        std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
+
+        std::vector<const DML_OPERATOR_DESC*> opDescs = {
+            &matMulIntToFloatDesc,
+            &mhaDesc,
+        };
+
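+        // Graph edges below reference nodes by index, so these indices must match the order in which
+        // descriptors are pushed into opDescs.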
+        uint32_t currentNodeIndex = 0;
+        const uint32_t matMulIntToFloatNodeIndex = currentNodeIndex++;
+        const uint32_t mhaNodeIndex = currentNodeIndex++;
+
+        uint32_t queryKeyValueTransposedNodeIndex = 0;
+
+        opDescs.push_back(&transposedDesc);
+        queryKeyValueTransposedNodeIndex = currentNodeIndex++;
+
+        uint32_t maskSliceNodeIndex = 0;
+        if (hasMaxSequenceMask)
+        {
+            opDescs.push_back(&maskSlicedDesc);
+            maskSliceNodeIndex = currentNodeIndex++;
+        }
+
+        uint32_t pastKeySliceNodeIndex = 0;
+        uint32_t pastValueSliceNodeIndex = 0;
+        uint32_t concatNodeIndex = 0;
+        if (hasPast)
+        {
+            opDescs.push_back(&pastKeySlicedDesc);
+            pastKeySliceNodeIndex = currentNodeIndex++;
+            opDescs.push_back(&pastValueSlicedDesc);
+            pastValueSliceNodeIndex = currentNodeIndex++;
+            opDescs.push_back(&presentKeyValueJoinDesc);
+            concatNodeIndex = currentNodeIndex++;
+        }
+
+        uint32_t causalMaskNodeIndex = 0;
+        if (unidirectional && !hasMask)
+        {
+            opDescs.push_back(&causalMaskDesc);
+            causalMaskNodeIndex = currentNodeIndex++;
+        }
+
+        DML_INPUT_GRAPH_EDGE_DESC inputToMatMulIntToFloatEdge = {};
+        inputToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputIndex;
+        inputToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
+        inputToMatMulIntToFloatEdge.ToNodeInputIndex = 0;
+        inputEdges.push_back(inputToMatMulIntToFloatEdge);
+
+        DML_INPUT_GRAPH_EDGE_DESC inputScaleToMatMulIntToFloatEdge = {};
+        inputScaleToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputScaleIndex;
+        inputScaleToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
+        inputScaleToMatMulIntToFloatEdge.ToNodeInputIndex = 1;
+        inputEdges.push_back(inputScaleToMatMulIntToFloatEdge);
+
+        DML_INPUT_GRAPH_EDGE_DESC inputZeroPointToMatMulIntToFloatEdge = {};
+        inputZeroPointToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputZeroPointIndex;
+        inputZeroPointToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
+        inputZeroPointToMatMulIntToFloatEdge.ToNodeInputIndex = 2;
+        inputEdges.push_back(inputZeroPointToMatMulIntToFloatEdge);
+
+        DML_INPUT_GRAPH_EDGE_DESC weightToMatMulIntToFloatEdge = {};
+        weightToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::weightsIndex;
+        weightToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
+        weightToMatMulIntToFloatEdge.ToNodeInputIndex = 3;
+        inputEdges.push_back(weightToMatMulIntToFloatEdge);
+
+        DML_INPUT_GRAPH_EDGE_DESC weightScaleToMatMulIntToFloatEdge = {};
+        weightScaleToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::weightScaleIndex;
+        weightScaleToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
+        weightScaleToMatMulIntToFloatEdge.ToNodeInputIndex = 4;
+        inputEdges.push_back(weightScaleToMatMulIntToFloatEdge);
+
+        DML_INPUT_GRAPH_EDGE_DESC weightZeroPointToMatMulIntToFloatEdge = {};
+        weightZeroPointToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::weightZeroPointIndex;
+        weightZeroPointToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
+        weightZeroPointToMatMulIntToFloatEdge.ToNodeInputIndex = 5;
+        inputEdges.push_back(weightZeroPointToMatMulIntToFloatEdge);
+
+        if (hasBias)
+        {
+            DML_INPUT_GRAPH_EDGE_DESC biasToMatMulIntToFloatEdge = {};
+            biasToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::biasIndex;
+            biasToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
+            biasToMatMulIntToFloatEdge.ToNodeInputIndex = 6;
+            inputEdges.push_back(biasToMatMulIntToFloatEdge);
+        }
+
+        if (hasMask)
+        {
+            if (hasUnpaddedBounds)
+            {
+                DML_INPUT_GRAPH_EDGE_DESC maskToMhaEdge = {};
+                maskToMhaEdge.GraphInputIndex = InputIndex::maskIndex;
+                maskToMhaEdge.ToNodeIndex = mhaNodeIndex;
+                maskToMhaEdge.ToNodeInputIndex = mhaMaskIndex;
+                inputEdges.push_back(maskToMhaEdge);
+            }
+            else if (hasMaxSequenceMask)
+            {
+                DML_INPUT_GRAPH_EDGE_DESC maskToMaskSliceEdge = {};
+                maskToMaskSliceEdge.GraphInputIndex = InputIndex::maskIndex;
+                maskToMaskSliceEdge.ToNodeIndex = maskSliceNodeIndex;
+                maskToMaskSliceEdge.ToNodeInputIndex = 0;
+                inputEdges.push_back(maskToMaskSliceEdge);
+
+                DML_INTERMEDIATE_GRAPH_EDGE_DESC maskSliceToMhaEdge = {};
+                maskSliceToMhaEdge.FromNodeIndex = maskSliceNodeIndex;
+                maskSliceToMhaEdge.FromNodeOutputIndex = 0;
+                maskSliceToMhaEdge.ToNodeIndex = mhaNodeIndex;
+                maskSliceToMhaEdge.ToNodeInputIndex = mhaMaskIndex;
+                intermediateEdges.push_back(maskSliceToMhaEdge);
+            }
+            else
+            {
+                DML_INPUT_GRAPH_EDGE_DESC maskToMhaEdge = {};
+                maskToMhaEdge.GraphInputIndex = InputIndex::maskIndex;
+                maskToMhaEdge.ToNodeIndex = mhaNodeIndex;
+                maskToMhaEdge.ToNodeInputIndex = mhaMaskIndex;
+                inputEdges.push_back(maskToMhaEdge);
+            }
+        }
+        else if (unidirectional)
+        {
+            DML_INTERMEDIATE_GRAPH_EDGE_DESC causalMaskToMhaEdge = {};
+            causalMaskToMhaEdge.FromNodeIndex = causalMaskNodeIndex;
+            causalMaskToMhaEdge.FromNodeOutputIndex = 0;
+            causalMaskToMhaEdge.ToNodeIndex = mhaNodeIndex;
+            causalMaskToMhaEdge.ToNodeInputIndex = mhaMaskIndex;
+            intermediateEdges.push_back(causalMaskToMhaEdge);
+        }
+
+        if (hasPast)
+        {
+            DML_INPUT_GRAPH_EDGE_DESC pastToPastKeySliceEdge = {};
+            pastToPastKeySliceEdge.GraphInputIndex = InputIndex::pastIndex;
+            pastToPastKeySliceEdge.ToNodeIndex = pastKeySliceNodeIndex;
+            pastToPastKeySliceEdge.ToNodeInputIndex = 0;
+            inputEdges.push_back(pastToPastKeySliceEdge);
+
+            DML_INPUT_GRAPH_EDGE_DESC pastToPastValueSliceEdge = {};
+            pastToPastValueSliceEdge.GraphInputIndex = InputIndex::pastIndex;
+            pastToPastValueSliceEdge.ToNodeIndex = pastValueSliceNodeIndex;
+            pastToPastValueSliceEdge.ToNodeInputIndex = 0;
+            inputEdges.push_back(pastToPastValueSliceEdge);
+
+            DML_INTERMEDIATE_GRAPH_EDGE_DESC pastKeyToMhaEdge = {};
+            pastKeyToMhaEdge.FromNodeIndex = pastKeySliceNodeIndex;
+            pastKeyToMhaEdge.FromNodeOutputIndex = 0;
+            pastKeyToMhaEdge.ToNodeIndex = mhaNodeIndex;
+            pastKeyToMhaEdge.ToNodeInputIndex = mhaPastKeyIndex;
+            intermediateEdges.push_back(pastKeyToMhaEdge);
+
+            DML_INTERMEDIATE_GRAPH_EDGE_DESC pastValueToMhaEdge = {};
+            pastValueToMhaEdge.FromNodeIndex = pastValueSliceNodeIndex;
+            pastValueToMhaEdge.FromNodeOutputIndex = 0;
+            pastValueToMhaEdge.ToNodeIndex = mhaNodeIndex;
+            pastValueToMhaEdge.ToNodeInputIndex = mhaPastValueIndex;
+            intermediateEdges.push_back(pastValueToMhaEdge);
+
+            DML_INTERMEDIATE_GRAPH_EDGE_DESC presentKeyToConcatEdge = {};
+            presentKeyToConcatEdge.FromNodeIndex = mhaNodeIndex;
+            presentKeyToConcatEdge.FromNodeOutputIndex = mhaPresentKeyIndex;
+            presentKeyToConcatEdge.ToNodeIndex = concatNodeIndex;
+            presentKeyToConcatEdge.ToNodeInputIndex = 0;
+            intermediateEdges.push_back(presentKeyToConcatEdge);
+
+            DML_INTERMEDIATE_GRAPH_EDGE_DESC presentValueToConcatEdge = {};
+            presentValueToConcatEdge.FromNodeIndex = mhaNodeIndex;
+            presentValueToConcatEdge.FromNodeOutputIndex = mhaPresentValueIndex;
+            presentValueToConcatEdge.ToNodeIndex = concatNodeIndex;
+            presentValueToConcatEdge.ToNodeInputIndex = 1;
+            intermediateEdges.push_back(presentValueToConcatEdge);
+        }
+
+        DML_INTERMEDIATE_GRAPH_EDGE_DESC matMulIntToFloatToQueryKeyValueTransposeEdge = {};
+        matMulIntToFloatToQueryKeyValueTransposeEdge.FromNodeIndex = matMulIntToFloatNodeIndex;
+        matMulIntToFloatToQueryKeyValueTransposeEdge.FromNodeOutputIndex = 0;
+        matMulIntToFloatToQueryKeyValueTransposeEdge.ToNodeIndex = queryKeyValueTransposedNodeIndex;
+        matMulIntToFloatToQueryKeyValueTransposeEdge.ToNodeInputIndex = 0;
+        intermediateEdges.push_back(matMulIntToFloatToQueryKeyValueTransposeEdge);
+
+        // All we need to do here is transpose the stacked QKV tensor into something DML supports
+        DML_INTERMEDIATE_GRAPH_EDGE_DESC queryKeyValueTransposedToMhaEdge = {};
+        queryKeyValueTransposedToMhaEdge.FromNodeIndex = queryKeyValueTransposedNodeIndex;
+        queryKeyValueTransposedToMhaEdge.FromNodeOutputIndex = 0;
+        queryKeyValueTransposedToMhaEdge.ToNodeIndex = mhaNodeIndex;
+        queryKeyValueTransposedToMhaEdge.ToNodeInputIndex = mhaStackedQueryKeyValueIndex;
+        intermediateEdges.push_back(queryKeyValueTransposedToMhaEdge);
+
+        DML_OUTPUT_GRAPH_EDGE_DESC mhaToOutputEdge = {};
+        mhaToOutputEdge.FromNodeIndex = mhaNodeIndex;
+        mhaToOutputEdge.FromNodeOutputIndex = mhaOutputIndex;
+        mhaToOutputEdge.GraphOutputIndex = OutputIndex::outputIndex;
+        outputEdges.push_back(mhaToOutputEdge);
+
+        if (hasPast)
+        {
+            DML_OUTPUT_GRAPH_EDGE_DESC concatToOutputEdge = {};
+            concatToOutputEdge.FromNodeIndex = concatNodeIndex;
+            concatToOutputEdge.FromNodeOutputIndex = 0;
+            concatToOutputEdge.GraphOutputIndex = OutputIndex::presentIndex;
+            outputEdges.push_back(concatToOutputEdge);
+        }
+
+        MLOperatorGraphDesc operatorGraphDesc = {};
+        operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());
+        operatorGraphDesc.inputEdges = inputEdges.data();
+        operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast<uint32_t>(intermediateEdges.size());
+        operatorGraphDesc.intermediateEdges = intermediateEdges.data();
+        operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
+        operatorGraphDesc.outputEdges = outputEdges.data();
+        operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
+        operatorGraphDesc.nodes = opDescs.data();
+
+        SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
+    }
+};
+
+void CALLBACK QueryQAttention(IMLOperatorSupportQueryContextPrivate* context, /*out*/ bool* isSupported)
+{
+    *isSupported = false;
+
+    // `unidirectional == 1` with Mask Tensor is not supported yet
+    MLOperatorAttributes attributes(context);
+    if (attributes.GetOptionalAttribute<int32_t>(AttrName::Unidirectional, 0) != 0 && context->IsInputValid(5))
+    {
+        return;
+    }
+
+    // `do_rotary == 1` is not supported yet
+    if (attributes.GetOptionalAttribute<int32_t>(AttrName::DoRotary, 0) != 0)
+    {
+        return;
+    }
+
+    // `past_present_share_buffer == 1` is not supported yet
+    if (attributes.GetOptionalAttribute<int32_t>(AttrName::PastPresentShareBuffer, 0) != 0)
+    {
+        return;
+    }
+
+    *isSupported = true;
+}
+
+DML_OP_DEFINE_CREATION_FUNCTION(QAttention, DmlOperatorQAttention);
+}  // namespace Dml
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp
index c97b03dc36b62..8727610ff3112 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp
@@ -166,7 +166,7 @@ class DmlOperatorQLinearConcat : public DmlOperator, public QLinearConcatHelper
 
         MLOperatorGraphDesc operatorGraphDesc = {};
         operatorGraphDesc.nodeCount = static_cast<uint32_t>(opDescs.size());
-        operatorGraphDesc.nodesAsOpDesc = opDescs.data();
+        operatorGraphDesc.nodes = opDescs.data();
 
         uint32_t joinNodeIndex = operatorGraphDesc.nodeCount - 2;
         uint32_t quantizeNodeIndex = operatorGraphDesc.nodeCount - 1;
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp
index 35f926d62c92a..bc0082fef3496 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp
@@ -113,7 +113,7 @@ class DmlOperatorQLinearSigmoid : public DmlOperator
         MLOperatorGraphDesc operatorGraphDesc = {};
         operatorGraphDesc.nodeCount = 3;
         std::vector<const DML_OPERATOR_DESC*> opDescs{&opDesc1, &opDesc2, &opDesc3};
-        operatorGraphDesc.nodesAsOpDesc = opDescs.data();
+        operatorGraphDesc.nodes = opDescs.data();
 
         // set input edges
         std::pair<uint32_t, uint32_t> nodeToNodeInputIndex[5] {{0, 0}, {0, 1}, {0, 2}, {2, 1}, {2, 2}};
@@ -178,4 +178,4 @@ void CALLBACK QueryQLinearSigmoid(IMLOperatorSupportQueryContextPrivate* context
 }
 
 DML_OP_DEFINE_CREATION_FUNCTION(QLinearSigmoid, DmlOperatorQLinearSigmoid);
-} // namespace Dml
+} // namespace Dml
\ No newline at end of file
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp
index 3683ab7b0b0b3..e62b7d707ba78 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp
@@ -123,7 +123,7 @@ class DmlOperatorQuickGelu : public DmlOperator
         operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
         operatorGraphDesc.outputEdges = outputEdges.data();
         operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
-        operatorGraphDesc.nodesAsOpDesc = opDescs.data();
+        operatorGraphDesc.nodes = opDescs.data();
         SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
     }
 };
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp
index 30c339b845b36..0f15ebf342b3a 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp
@@ -43,6 +43,10 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
         ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 4);
         ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
 
+        // When the input is 4D, it has the shape [batchSize, numHeads, sequenceLength, headSize]. Otherwise,
+        // it has the shape [batchSize, sequenceLength, hiddenSize]
+        const bool inputIs4D = kernelInfo.GetInputTensorDimensionCount(inputDataIndex) == 4;
+
         // When positionIds is a scalar, it represents the start offset for each sequence
         const bool positionIdsIsOffset = kernelInfo.GetInputTensorDimensionCount(positionIdsIndex) == 1;
 
@@ -63,9 +67,9 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
 
         // We resize the data to be of shape [batchSize, sequenceLength, numHeads, headSize]
         const auto inputDataSizes = m_inputTensorDescs[inputDataIndex].GetSizes();
-        const uint32_t batchSize = inputDataSizes[1];
+        const uint32_t batchSize = inputIs4D ? inputDataSizes[0] : inputDataSizes[1];
         const uint32_t sequenceLength = inputDataSizes[2];
-        const uint32_t numHeads = inputDataSizes[3] / headSize;
+        const uint32_t numHeads = inputIs4D ? inputDataSizes[1] : inputDataSizes[3] / headSize;
 
         const auto cosCacheSizes = m_inputTensorDescs[cosCacheIndex].GetSizes();
         const uint32_t maxSequenceLength = cosCacheSizes[cosCacheSizes.size() - 2];
@@ -80,16 +84,24 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
         std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
         const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType;
 
-        // Splitting the hiddenSize into numHeads and headSize dimensions makes it easier for DML to handle
         const std::array<uint32_t, 4> inputOutputShape = {batchSize, sequenceLength, numHeads, headSize};
         TensorDesc inputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape);
+        TensorDesc stridedInputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape);
+
+        if (inputIs4D)
+        {
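+            // These strides view the original [batchSize, numHeads, sequenceLength, headSize] memory as
+            // [batchSize, sequenceLength, numHeads, headSize] without copying; the identity copy below materializes it.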
+            const std::array<uint32_t, 4> inputOutputStrides = {headSize * numHeads * sequenceLength, headSize, sequenceLength * headSize, 1};
+            stridedInputOutputTensorDesc.SetStrides(inputOutputStrides);
+        }
+
         const DML_TENSOR_DESC inputOutputDmlTensorDesc = inputOutputTensorDesc.GetDmlDesc();
+        const DML_TENSOR_DESC stridedInputOutputDmlTensorDesc = stridedInputOutputTensorDesc.GetDmlDesc();
 
         // Copy the input to preserve its real input shape in the graph without reshaping it. This will disappear during DML's graph compilation phase.
         DML_SCALE_BIAS scaleBias = {1.0f, 0.0f};
 
         DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC copyInputDesc{};
-        copyInputDesc.InputTensor = &inputOutputDmlTensorDesc;
+        copyInputDesc.InputTensor = &stridedInputOutputDmlTensorDesc;
         copyInputDesc.OutputTensor = &inputOutputDmlTensorDesc;
         copyInputDesc.ScaleBias = &scaleBias;
         const DML_OPERATOR_DESC copyInputDmlDesc = {DML_OPERATOR_ELEMENT_WISE_IDENTITY, &copyInputDesc};
@@ -104,8 +116,12 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
             : std::vector<uint32_t>({batchSize, sequenceLength, numHeads, 1, headSize / 2});
 
         TensorDesc inputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape);
+
         const DML_TENSOR_DESC inputDataDmlTensorDesc = inputDataTensorDesc.GetDmlDesc();
 
+        TensorDesc joinedDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape);
+        const DML_TENSOR_DESC joinedDataDmlTensorDesc = joinedDataTensorDesc.GetDmlDesc();
+
         TensorDesc splitInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputDataTensorShape);
         const std::array<DML_TENSOR_DESC, 2> splitInputDataDmlTensorDescs = {splitInputDataTensorDesc.GetDmlDesc(), splitInputDataTensorDesc.GetDmlDesc()};
 
@@ -122,7 +138,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
         // Swap the 2 halves and join them together
         DML_JOIN_OPERATOR_DESC joinInputDesc{};
         joinInputDesc.InputTensors = splitInputDataDmlTensorDescs.data();
-        joinInputDesc.OutputTensor = &inputDataDmlTensorDesc;
+        joinInputDesc.OutputTensor = &joinedDataDmlTensorDesc;
         joinInputDesc.Axis = splitInputDesc.Axis;
         joinInputDesc.InputCount = gsl::narrow_cast<uint32_t>(splitInputDataDmlTensorDescs.size());
         const DML_OPERATOR_DESC joinInputDmlDesc = {DML_OPERATOR_JOIN, &joinInputDesc};
@@ -212,23 +228,23 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
         const DML_TENSOR_DESC broadcastedSignDmlTensorDesc = broadcastedSignCosSinTensorDesc.GetDmlDesc();
 
         DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulSignDesc{};
-        mulSignDesc.ATensor = &inputDataDmlTensorDesc;
+        mulSignDesc.ATensor = &joinedDataDmlTensorDesc;
         mulSignDesc.BTensor = &broadcastedSignDmlTensorDesc;
-        mulSignDesc.OutputTensor = &inputDataDmlTensorDesc;
+        mulSignDesc.OutputTensor = &joinedDataDmlTensorDesc;
         const DML_OPERATOR_DESC mulSignDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulSignDesc};
 
         // Multiply the non-rotated data with the cos and the rotated data with the sin
         DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulCosSinDesc{};
-        mulCosSinDesc.ATensor = &inputDataDmlTensorDesc;
+        mulCosSinDesc.ATensor = &joinedDataDmlTensorDesc;
         mulCosSinDesc.BTensor = &broadcastedCosSinDmlTensorDesc;
-        mulCosSinDesc.OutputTensor = &inputDataDmlTensorDesc;
+        mulCosSinDesc.OutputTensor = &joinedDataDmlTensorDesc;
         const DML_OPERATOR_DESC mulCosSinDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulCosSinDesc};
 
         // Add the multiplied cos and sin values together
         DML_ELEMENT_WISE_ADD_OPERATOR_DESC addDesc{};
         addDesc.ATensor = &inputOutputDmlTensorDesc;
         addDesc.BTensor = &inputOutputDmlTensorDesc;
-        addDesc.OutputTensor = &inputOutputDmlTensorDesc;
+        addDesc.OutputTensor = &stridedInputOutputDmlTensorDesc;
         const DML_OPERATOR_DESC addDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &addDesc};
 
         // Construct the graph
@@ -425,7 +441,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
         operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
         operatorGraphDesc.outputEdges = outputEdges.data();
         operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
-        operatorGraphDesc.nodesAsOpDesc = opDescs.data();
+        operatorGraphDesc.nodes = opDescs.data();
 
         SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo);
     }
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp
index 4dafd78f21ea8..094c45a0e38e5 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp
@@ -198,7 +198,7 @@ class DmlOperatorSkipLayerNormalization : public DmlOperator
         operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
         operatorGraphDesc.outputEdges = outputEdges.data();
         operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
-        operatorGraphDesc.nodesAsOpDesc = opDescs.data();
+        operatorGraphDesc.nodes = opDescs.data();
 
         SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
     }
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/stockham.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/stockham.h
index 9c03b7f6de639..1bfd6e6c6068d 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/stockham.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/stockham.h
@@ -21,7 +21,7 @@ dcl_uav_structured u0, 4
 dcl_uav_structured u1, 4
 dcl_uav_structured u2, 4
 dcl_input vThreadID.x
-dcl_temps 6
+dcl_temps 5
 dcl_thread_group 64, 1, 1
 iadd r0.x, vThreadID.x, cb0[0].x
 ult r0.y, r0.x, cb0[0].y
@@ -40,66 +40,57 @@ if_nz r0.y
   ieq r1.y, cb0[7].x, l(1)
   ult r1.z, r0.w, cb0[5].z
   and r1.z, r1.z, r1.y
-  if_nz r1.z
-    imul null, r1.z, r0.w, cb0[6].z
-    ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r4.x, r1.z, l(0), u2.xxxx
-    imad r1.z, r0.w, cb0[6].z, cb0[6].w
-    ieq r1.w, cb0[5].w, l(2)
-    ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r1.z, r1.z, l(0), u2.xxxx
-    and r4.y, r1.z, r1.w
+  imul null, r1.w, r0.w, cb0[6].z
+  ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r4.x, r1.w, l(0), u2.xxxx
+  ieq r1.w, cb0[5].w, l(2)
+  if_nz r1.w
+    imad r2.y, r0.w, cb0[6].z, cb0[6].w
+    ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r4.y, r2.y, l(0), u2.xxxx
   else 
-    mov r4.xy, l(1.000000,0,0,0)
+    mov r4.y, l(0)
   endif 
+  movc r2.yz, r1.zzzz, r4.yyxy, l(0,0,1.000000,0)
   ult r1.z, r0.w, cb0[1].y
-  if_nz r1.z
-    imul null, r0.w, r0.w, cb0[2].y
-    imad r0.w, r1.x, cb0[2].x, r0.w
-    imad r0.w, r3.x, cb0[2].z, r0.w
-    ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r5.x, r0.w, l(0), u0.xxxx
-    ieq r1.z, cb0[1].w, l(2)
-    if_nz r1.z
-      iadd r0.w, r0.w, cb0[2].w
-      ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r5.y, r0.w, l(0), u0.xxxx
-    else 
-      mov r5.y, l(0)
-    endif 
+  imul null, r1.x, r1.x, cb0[2].x
+  imad r0.w, r0.w, cb0[2].y, r1.x
+  imad r0.w, r3.x, cb0[2].z, r0.w
+  ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r4.x, r0.w, l(0), u0.xxxx
+  ieq r2.w, cb0[1].w, l(2)
+  if_nz r2.w
+    iadd r0.w, r0.w, cb0[2].w
+    ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r4.y, r0.w, l(0), u0.xxxx
   else 
-    mov r5.xy, l(0,0,0,0)
+    mov r4.y, l(0)
   endif 
-  mul r0.w, r4.y, r5.y
-  mad r0.w, r5.x, r4.x, -r0.w
-  dp2 r1.z, r5.yxyy, r4.xyxx
-  ult r1.w, r0.y, cb0[5].z
-  and r1.y, r1.w, r1.y
-  if_nz r1.y
-    imul null, r1.y, r0.y, cb0[6].z
-    ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r4.x, r1.y, l(0), u2.xxxx
-    imad r1.y, r0.y, cb0[6].z, cb0[6].w
-    ieq r1.w, cb0[5].w, l(2)
-    ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r1.y, r1.y, l(0), u2.xxxx
-    and r4.y, r1.y, r1.w
+  and r3.yz, r1.zzzz, r4.xxyx
+  mul r0.w, r2.y, r3.z
+  mad r0.w, r3.y, r2.z, -r0.w
+  dp2 r1.z, r3.yzyy, r2.yzyy
+  ult r2.y, r0.y, cb0[5].z
+  and r1.y, r1.y, r2.y
+  imul null, r2.y, r0.y, cb0[6].z
+  ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r4.x, r2.y, l(0), u2.xxxx
+  if_nz r1.w
+    imad r1.w, r0.y, cb0[6].z, cb0[6].w
+    ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r4.y, r1.w, l(0), u2.xxxx
   else 
-    mov r4.xy, l(1.000000,0,0,0)
+    mov r4.y, l(0)
   endif 
-  ult r1.y, r0.y, cb0[1].y
-  if_nz r1.y
-    imul null, r0.y, r0.y, cb0[2].y
-    imad r0.y, r1.x, cb0[2].x, r0.y
-    imad r0.y, r3.x, cb0[2].z, r0.y
-    ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r1.x, r0.y, l(0), u0.xxxx
-    ieq r1.w, cb0[1].w, l(2)
-    if_nz r1.w
-      iadd r0.y, r0.y, cb0[2].w
-      ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r1.y, r0.y, l(0), u0.xxxx
-    else 
-      mov r1.y, l(0)
-    endif 
+  movc r1.yw, r1.yyyy, r4.yyyx, l(0,0,0,1.000000)
+  ult r2.y, r0.y, cb0[1].y
+  imad r0.y, r0.y, cb0[2].y, r1.x
+  imad r0.y, r3.x, cb0[2].z, r0.y
+  ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r3.x, r0.y, l(0), u0.xxxx
+  if_nz r2.w
+    iadd r0.y, r0.y, cb0[2].w
+    ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r3.y, r0.y, l(0), u0.xxxx
   else 
-    mov r1.xy, l(0,0,0,0)
+    mov r3.y, l(0)
   endif 
-  mul r0.y, r4.y, r1.y
-  mad r0.y, r1.x, r4.x, -r0.y
-  dp2 r1.x, r1.yxyy, r4.xyxx
+  and r2.yz, r2.yyyy, r3.xxyx
+  mul r0.y, r1.y, r2.z
+  mad r0.y, r2.y, r1.w, -r0.y
+  dp2 r1.x, r2.yzyy, r1.ywyy
   udiv null, r1.y, r2.x, r0.z
   ieq r1.w, cb0[0].w, l(1)
   movc r1.w, r1.w, l(6.283185), l(-6.283185)
@@ -117,17 +108,22 @@ if_nz r0.y
   mad r0.y, r3.x, r1.x, r0.y
   add r0.y, r0.y, r1.z
   mul r0.yw, r0.yyyw, cb0[7].zzzz
-  ne r1.x, cb0[7].y, l(0.000000)
-  mul r1.y, r1.y, r1.y
-  mul r1.y, r1.y, l(3.141593)
-  div r1.y, r1.y, cb0[7].y
-  sincos r2.x, r3.x, r1.y
-  mov r2.y, r3.x
-  movc r1.xy, r1.xxxx, r2.xyxx, l(0,1.000000,0,0)
-  mul r1.zw, r0.yyyy, r1.xxxy
-  mad r0.y, r0.w, r1.y, -r1.z
-  store_structured u1.x, r0.z, l(0), r0.y
-  mad r0.y, r0.w, r1.x, r1.w
+  eq r1.x, cb0[7].y, l(0.000000)
+  if_nz r1.x
+    mov r1.x, r0.w
+  else 
+    ne r1.z, cb0[7].y, l(0.000000)
+    mul r1.y, r1.y, r1.y
+    mul r1.y, r1.y, l(3.141593)
+    div r1.y, r1.y, cb0[7].y
+    sincos r2.x, r3.x, r1.y
+    mov r2.y, r3.x
+    movc r1.yz, r1.zzzz, r2.xxyx, l(0,0,1.000000,0)
+    mul r2.xy, r0.yyyy, r1.yzyy
+    mad r1.x, r0.w, r1.z, -r2.x
+    mad r0.y, r0.w, r1.y, r2.y
+  endif 
+  store_structured u1.x, r0.z, l(0), r1.x
   store_structured u1.x, r0.x, l(0), r0.y
 endif 
 ret 
@@ -136,11 +132,11 @@ ret
 
 const BYTE g_DFT[] =
 {
-     68,  88,  66,  67, 222, 156, 
-    188, 133, 179,  57, 118,  25, 
-    122, 216, 102,  13,  91, 242, 
-     99,  27,   1,   0,   0,   0, 
-    172,  12,   0,   0,   3,   0, 
+     68,  88,  66,  67,  63, 188, 
+    200, 227, 206,  73,  64,  21, 
+    140, 126,  47, 226, 169,  81, 
+    175, 134,   1,   0,   0,   0, 
+    112,  12,   0,   0,   3,   0, 
       0,   0,  44,   0,   0,   0, 
      60,   0,   0,   0,  76,   0, 
       0,   0,  73,  83,  71,  78, 
@@ -149,8 +145,8 @@ const BYTE g_DFT[] =
      79,  83,  71,  78,   8,   0, 
       0,   0,   0,   0,   0,   0, 
       8,   0,   0,   0,  83,  72, 
-     69,  88,  88,  12,   0,   0, 
-     80,   0,   5,   0,  22,   3, 
+     69,  88,  28,  12,   0,   0, 
+     80,   0,   5,   0,   7,   3, 
       0,   0, 106,   8,   0,   1, 
      89,   0,   0,   4,  70, 142, 
      32,   0,   0,   0,   0,   0, 
@@ -164,7 +160,7 @@ const BYTE g_DFT[] =
      17,   0,   2,   0,   0,   0, 
       4,   0,   0,   0,  95,   0, 
       0,   2,  18,   0,   2,   0, 
-    104,   0,   0,   2,   6,   0, 
+    104,   0,   0,   2,   5,   0, 
       0,   0, 155,   0,   0,   4, 
      64,   0,   0,   0,   1,   0, 
       0,   0,   1,   0,   0,   0, 
@@ -256,11 +252,9 @@ const BYTE g_DFT[] =
      16,   0,   1,   0,   0,   0, 
      42,   0,  16,   0,   1,   0, 
       0,   0,  26,   0,  16,   0, 
-      1,   0,   0,   0,  31,   0, 
-      4,   3,  42,   0,  16,   0, 
       1,   0,   0,   0,  38,   0, 
       0,   9,   0, 208,   0,   0, 
-     66,   0,  16,   0,   1,   0, 
+    130,   0,  16,   0,   1,   0, 
       0,   0,  58,   0,  16,   0, 
       0,   0,   0,   0,  42, 128, 
      32,   0,   0,   0,   0,   0, 
@@ -268,221 +262,203 @@ const BYTE g_DFT[] =
       0, 139,   2,  35,   0, 128, 
     131, 153,  25,   0,  18,   0, 
      16,   0,   4,   0,   0,   0, 
-     42,   0,  16,   0,   1,   0, 
+     58,   0,  16,   0,   1,   0, 
       0,   0,   1,  64,   0,   0, 
       0,   0,   0,   0,   6, 224, 
      17,   0,   2,   0,   0,   0, 
-     35,   0,   0,  11,  66,   0, 
+     32,   0,   0,   8, 130,   0, 
      16,   0,   1,   0,   0,   0, 
-     58,   0,  16,   0,   0,   0, 
-      0,   0,  42, 128,  32,   0, 
-      0,   0,   0,   0,   6,   0, 
-      0,   0,  58, 128,  32,   0, 
-      0,   0,   0,   0,   6,   0, 
-      0,   0,  32,   0,   0,   8, 
-    130,   0,  16,   0,   1,   0, 
-      0,   0,  58, 128,  32,   0, 
-      0,   0,   0,   0,   5,   0, 
-      0,   0,   1,  64,   0,   0, 
-      2,   0,   0,   0, 167,   0, 
+     58, 128,  32,   0,   0,   0, 
+      0,   0,   5,   0,   0,   0, 
+      1,  64,   0,   0,   2,   0, 
+      0,   0,  31,   0,   4,   3, 
+     58,   0,  16,   0,   1,   0, 
+      0,   0,  35,   0,   0,  11, 
+     34,   0,  16,   0,   2,   0, 
+      0,   0,  58,   0,  16,   0, 
+      0,   0,   0,   0,  42, 128, 
+     32,   0,   0,   0,   0,   0, 
+      6,   0,   0,   0,  58, 128, 
+     32,   0,   0,   0,   0,   0, 
+      6,   0,   0,   0, 167,   0, 
       0, 139,   2,  35,   0, 128, 
-    131, 153,  25,   0,  66,   0, 
-     16,   0,   1,   0,   0,   0, 
-     42,   0,  16,   0,   1,   0, 
+    131, 153,  25,   0,  34,   0, 
+     16,   0,   4,   0,   0,   0, 
+     26,   0,  16,   0,   2,   0, 
       0,   0,   1,  64,   0,   0, 
       0,   0,   0,   0,   6, 224, 
      17,   0,   2,   0,   0,   0, 
-      1,   0,   0,   7,  34,   0, 
-     16,   0,   4,   0,   0,   0, 
-     42,   0,  16,   0,   1,   0, 
-      0,   0,  58,   0,  16,   0, 
-      1,   0,   0,   0,  18,   0, 
-      0,   1,  54,   0,   0,   8, 
-     50,   0,  16,   0,   4,   0, 
+     18,   0,   0,   1,  54,   0, 
+      0,   5,  34,   0,  16,   0, 
+      4,   0,   0,   0,   1,  64, 
+      0,   0,   0,   0,   0,   0, 
+     21,   0,   0,   1,  55,   0, 
+      0,  12,  98,   0,  16,   0, 
+      2,   0,   0,   0, 166,  10, 
+     16,   0,   1,   0,   0,   0, 
+     86,   4,  16,   0,   4,   0, 
       0,   0,   2,  64,   0,   0, 
-      0,   0, 128,  63,   0,   0, 
       0,   0,   0,   0,   0,   0, 
-      0,   0,   0,   0,  21,   0, 
-      0,   1,  79,   0,   0,   8, 
-     66,   0,  16,   0,   1,   0, 
-      0,   0,  58,   0,  16,   0, 
-      0,   0,   0,   0,  26, 128, 
-     32,   0,   0,   0,   0,   0, 
-      1,   0,   0,   0,  31,   0, 
-      4,   3,  42,   0,  16,   0, 
-      1,   0,   0,   0,  38,   0, 
-      0,   9,   0, 208,   0,   0, 
-    130,   0,  16,   0,   0,   0, 
-      0,   0,  58,   0,  16,   0, 
-      0,   0,   0,   0,  26, 128, 
-     32,   0,   0,   0,   0,   0, 
-      2,   0,   0,   0,  35,   0, 
-      0,  10, 130,   0,  16,   0, 
-      0,   0,   0,   0,  10,   0, 
+      0,   0,   0,   0, 128,  63, 
+      0,   0,   0,   0,  79,   0, 
+      0,   8,  66,   0,  16,   0, 
+      1,   0,   0,   0,  58,   0, 
+     16,   0,   0,   0,   0,   0, 
+     26, 128,  32,   0,   0,   0, 
+      0,   0,   1,   0,   0,   0, 
+     38,   0,   0,   9,   0, 208, 
+      0,   0,  18,   0,  16,   0, 
+      1,   0,   0,   0,  10,   0, 
      16,   0,   1,   0,   0,   0, 
      10, 128,  32,   0,   0,   0, 
       0,   0,   2,   0,   0,   0, 
+     35,   0,   0,  10, 130,   0, 
+     16,   0,   0,   0,   0,   0, 
      58,   0,  16,   0,   0,   0, 
-      0,   0,  35,   0,   0,  10, 
-    130,   0,  16,   0,   0,   0, 
+      0,   0,  26, 128,  32,   0, 
+      0,   0,   0,   0,   2,   0, 
       0,   0,  10,   0,  16,   0, 
-      3,   0,   0,   0,  42, 128, 
+      1,   0,   0,   0,  35,   0, 
+      0,  10, 130,   0,  16,   0, 
+      0,   0,   0,   0,  10,   0, 
+     16,   0,   3,   0,   0,   0, 
+     42, 128,  32,   0,   0,   0, 
+      0,   0,   2,   0,   0,   0, 
+     58,   0,  16,   0,   0,   0, 
+      0,   0, 167,   0,   0, 139, 
+      2,  35,   0, 128, 131, 153, 
+     25,   0,  18,   0,  16,   0, 
+      4,   0,   0,   0,  58,   0, 
+     16,   0,   0,   0,   0,   0, 
+      1,  64,   0,   0,   0,   0, 
+      0,   0,   6, 224,  17,   0, 
+      0,   0,   0,   0,  32,   0, 
+      0,   8, 130,   0,  16,   0, 
+      2,   0,   0,   0,  58, 128, 
      32,   0,   0,   0,   0,   0, 
-      2,   0,   0,   0,  58,   0, 
+      1,   0,   0,   0,   1,  64, 
+      0,   0,   2,   0,   0,   0, 
+     31,   0,   4,   3,  58,   0, 
+     16,   0,   2,   0,   0,   0, 
+     30,   0,   0,   8, 130,   0, 
      16,   0,   0,   0,   0,   0, 
-    167,   0,   0, 139,   2,  35, 
-      0, 128, 131, 153,  25,   0, 
-     18,   0,  16,   0,   5,   0, 
-      0,   0,  58,   0,  16,   0, 
-      0,   0,   0,   0,   1,  64, 
-      0,   0,   0,   0,   0,   0, 
-      6, 224,  17,   0,   0,   0, 
-      0,   0,  32,   0,   0,   8, 
-     66,   0,  16,   0,   1,   0, 
+     58,   0,  16,   0,   0,   0, 
       0,   0,  58, 128,  32,   0, 
-      0,   0,   0,   0,   1,   0, 
-      0,   0,   1,  64,   0,   0, 
-      2,   0,   0,   0,  31,   0, 
-      4,   3,  42,   0,  16,   0, 
-      1,   0,   0,   0,  30,   0, 
-      0,   8, 130,   0,  16,   0, 
-      0,   0,   0,   0,  58,   0, 
+      0,   0,   0,   0,   2,   0, 
+      0,   0, 167,   0,   0, 139, 
+      2,  35,   0, 128, 131, 153, 
+     25,   0,  34,   0,  16,   0, 
+      4,   0,   0,   0,  58,   0, 
      16,   0,   0,   0,   0,   0, 
-     58, 128,  32,   0,   0,   0, 
-      0,   0,   2,   0,   0,   0, 
-    167,   0,   0, 139,   2,  35, 
-      0, 128, 131, 153,  25,   0, 
-     34,   0,  16,   0,   5,   0, 
-      0,   0,  58,   0,  16,   0, 
-      0,   0,   0,   0,   1,  64, 
-      0,   0,   0,   0,   0,   0, 
-      6, 224,  17,   0,   0,   0, 
-      0,   0,  18,   0,   0,   1, 
-     54,   0,   0,   5,  34,   0, 
-     16,   0,   5,   0,   0,   0, 
       1,  64,   0,   0,   0,   0, 
-      0,   0,  21,   0,   0,   1, 
-     18,   0,   0,   1,  54,   0, 
-      0,   8,  50,   0,  16,   0, 
-      5,   0,   0,   0,   2,  64, 
-      0,   0,   0,   0,   0,   0, 
-      0,   0,   0,   0,   0,   0, 
-      0,   0,   0,   0,   0,   0, 
-     21,   0,   0,   1,  56,   0, 
-      0,   7, 130,   0,  16,   0, 
-      0,   0,   0,   0,  26,   0, 
-     16,   0,   4,   0,   0,   0, 
-     26,   0,  16,   0,   5,   0, 
-      0,   0,  50,   0,   0,  10, 
-    130,   0,  16,   0,   0,   0, 
-      0,   0,  10,   0,  16,   0, 
-      5,   0,   0,   0,  10,   0, 
+      0,   0,   6, 224,  17,   0, 
+      0,   0,   0,   0,  18,   0, 
+      0,   1,  54,   0,   0,   5, 
+     34,   0,  16,   0,   4,   0, 
+      0,   0,   1,  64,   0,   0, 
+      0,   0,   0,   0,  21,   0, 
+      0,   1,   1,   0,   0,   7, 
+     98,   0,  16,   0,   3,   0, 
+      0,   0, 166,  10,  16,   0, 
+      1,   0,   0,   0,   6,   1, 
      16,   0,   4,   0,   0,   0, 
-     58,   0,  16, 128,  65,   0, 
-      0,   0,   0,   0,   0,   0, 
-     15,   0,   0,   7,  66,   0, 
-     16,   0,   1,   0,   0,   0, 
-     22,   5,  16,   0,   5,   0, 
-      0,   0,  70,   0,  16,   0, 
-      4,   0,   0,   0,  79,   0, 
-      0,   8, 130,   0,  16,   0, 
+     56,   0,   0,   7, 130,   0, 
+     16,   0,   0,   0,   0,   0, 
+     26,   0,  16,   0,   2,   0, 
+      0,   0,  42,   0,  16,   0, 
+      3,   0,   0,   0,  50,   0, 
+      0,  10, 130,   0,  16,   0, 
+      0,   0,   0,   0,  26,   0, 
+     16,   0,   3,   0,   0,   0, 
+     42,   0,  16,   0,   2,   0, 
+      0,   0,  58,   0,  16, 128, 
+     65,   0,   0,   0,   0,   0, 
+      0,   0,  15,   0,   0,   7, 
+     66,   0,  16,   0,   1,   0, 
+      0,   0, 150,   5,  16,   0, 
+      3,   0,   0,   0, 150,   5, 
+     16,   0,   2,   0,   0,   0, 
+     79,   0,   0,   8,  34,   0, 
+     16,   0,   2,   0,   0,   0, 
+     26,   0,  16,   0,   0,   0, 
+      0,   0,  42, 128,  32,   0, 
+      0,   0,   0,   0,   5,   0, 
+      0,   0,   1,   0,   0,   7, 
+     34,   0,  16,   0,   1,   0, 
+      0,   0,  26,   0,  16,   0, 
       1,   0,   0,   0,  26,   0, 
+     16,   0,   2,   0,   0,   0, 
+     38,   0,   0,   9,   0, 208, 
+      0,   0,  34,   0,  16,   0, 
+      2,   0,   0,   0,  26,   0, 
      16,   0,   0,   0,   0,   0, 
      42, 128,  32,   0,   0,   0, 
-      0,   0,   5,   0,   0,   0, 
-      1,   0,   0,   7,  34,   0, 
-     16,   0,   1,   0,   0,   0, 
-     58,   0,  16,   0,   1,   0, 
+      0,   0,   6,   0,   0,   0, 
+    167,   0,   0, 139,   2,  35, 
+      0, 128, 131, 153,  25,   0, 
+     18,   0,  16,   0,   4,   0, 
       0,   0,  26,   0,  16,   0, 
-      1,   0,   0,   0,  31,   0, 
-      4,   3,  26,   0,  16,   0, 
-      1,   0,   0,   0,  38,   0, 
-      0,   9,   0, 208,   0,   0, 
-     34,   0,  16,   0,   1,   0, 
+      2,   0,   0,   0,   1,  64, 
+      0,   0,   0,   0,   0,   0, 
+      6, 224,  17,   0,   2,   0, 
+      0,   0,  31,   0,   4,   3, 
+     58,   0,  16,   0,   1,   0, 
+      0,   0,  35,   0,   0,  11, 
+    130,   0,  16,   0,   1,   0, 
       0,   0,  26,   0,  16,   0, 
       0,   0,   0,   0,  42, 128, 
+     32,   0,   0,   0,   0,   0, 
+      6,   0,   0,   0,  58, 128, 
      32,   0,   0,   0,   0,   0, 
       6,   0,   0,   0, 167,   0, 
       0, 139,   2,  35,   0, 128, 
-    131, 153,  25,   0,  18,   0, 
+    131, 153,  25,   0,  34,   0, 
      16,   0,   4,   0,   0,   0, 
-     26,   0,  16,   0,   1,   0, 
+     58,   0,  16,   0,   1,   0, 
       0,   0,   1,  64,   0,   0, 
       0,   0,   0,   0,   6, 224, 
      17,   0,   2,   0,   0,   0, 
-     35,   0,   0,  11,  34,   0, 
-     16,   0,   1,   0,   0,   0, 
-     26,   0,  16,   0,   0,   0, 
-      0,   0,  42, 128,  32,   0, 
-      0,   0,   0,   0,   6,   0, 
-      0,   0,  58, 128,  32,   0, 
-      0,   0,   0,   0,   6,   0, 
-      0,   0,  32,   0,   0,   8, 
-    130,   0,  16,   0,   1,   0, 
-      0,   0,  58, 128,  32,   0, 
-      0,   0,   0,   0,   5,   0, 
-      0,   0,   1,  64,   0,   0, 
-      2,   0,   0,   0, 167,   0, 
-      0, 139,   2,  35,   0, 128, 
-    131, 153,  25,   0,  34,   0, 
+     18,   0,   0,   1,  54,   0, 
+      0,   5,  34,   0,  16,   0, 
+      4,   0,   0,   0,   1,  64, 
+      0,   0,   0,   0,   0,   0, 
+     21,   0,   0,   1,  55,   0, 
+      0,  12, 162,   0,  16,   0, 
+      1,   0,   0,   0,  86,   5, 
      16,   0,   1,   0,   0,   0, 
-     26,   0,  16,   0,   1,   0, 
-      0,   0,   1,  64,   0,   0, 
-      0,   0,   0,   0,   6, 224, 
-     17,   0,   2,   0,   0,   0, 
-      1,   0,   0,   7,  34,   0, 
-     16,   0,   4,   0,   0,   0, 
-     26,   0,  16,   0,   1,   0, 
-      0,   0,  58,   0,  16,   0, 
-      1,   0,   0,   0,  18,   0, 
-      0,   1,  54,   0,   0,   8, 
-     50,   0,  16,   0,   4,   0, 
+     86,   1,  16,   0,   4,   0, 
       0,   0,   2,  64,   0,   0, 
-      0,   0, 128,  63,   0,   0, 
       0,   0,   0,   0,   0,   0, 
-      0,   0,   0,   0,  21,   0, 
-      0,   1,  79,   0,   0,   8, 
-     34,   0,  16,   0,   1,   0, 
-      0,   0,  26,   0,  16,   0, 
-      0,   0,   0,   0,  26, 128, 
-     32,   0,   0,   0,   0,   0, 
-      1,   0,   0,   0,  31,   0, 
-      4,   3,  26,   0,  16,   0, 
-      1,   0,   0,   0,  38,   0, 
-      0,   9,   0, 208,   0,   0, 
-     34,   0,  16,   0,   0,   0, 
-      0,   0,  26,   0,  16,   0, 
-      0,   0,   0,   0,  26, 128, 
-     32,   0,   0,   0,   0,   0, 
-      2,   0,   0,   0,  35,   0, 
+      0,   0,   0,   0,   0,   0, 
+      0,   0, 128,  63,  79,   0, 
+      0,   8,  34,   0,  16,   0, 
+      2,   0,   0,   0,  26,   0, 
+     16,   0,   0,   0,   0,   0, 
+     26, 128,  32,   0,   0,   0, 
+      0,   0,   1,   0,   0,   0, 
+     35,   0,   0,  10,  34,   0, 
+     16,   0,   0,   0,   0,   0, 
+     26,   0,  16,   0,   0,   0, 
+      0,   0,  26, 128,  32,   0, 
+      0,   0,   0,   0,   2,   0, 
+      0,   0,  10,   0,  16,   0, 
+      1,   0,   0,   0,  35,   0, 
       0,  10,  34,   0,  16,   0, 
       0,   0,   0,   0,  10,   0, 
-     16,   0,   1,   0,   0,   0, 
-     10, 128,  32,   0,   0,   0, 
+     16,   0,   3,   0,   0,   0, 
+     42, 128,  32,   0,   0,   0, 
       0,   0,   2,   0,   0,   0, 
      26,   0,  16,   0,   0,   0, 
-      0,   0,  35,   0,   0,  10, 
-     34,   0,  16,   0,   0,   0, 
-      0,   0,  10,   0,  16,   0, 
-      3,   0,   0,   0,  42, 128, 
-     32,   0,   0,   0,   0,   0, 
-      2,   0,   0,   0,  26,   0, 
+      0,   0, 167,   0,   0, 139, 
+      2,  35,   0, 128, 131, 153, 
+     25,   0,  18,   0,  16,   0, 
+      3,   0,   0,   0,  26,   0, 
      16,   0,   0,   0,   0,   0, 
-    167,   0,   0, 139,   2,  35, 
-      0, 128, 131, 153,  25,   0, 
-     18,   0,  16,   0,   1,   0, 
-      0,   0,  26,   0,  16,   0, 
-      0,   0,   0,   0,   1,  64, 
-      0,   0,   0,   0,   0,   0, 
-      6, 224,  17,   0,   0,   0, 
-      0,   0,  32,   0,   0,   8, 
-    130,   0,  16,   0,   1,   0, 
-      0,   0,  58, 128,  32,   0, 
-      0,   0,   0,   0,   1,   0, 
-      0,   0,   1,  64,   0,   0, 
-      2,   0,   0,   0,  31,   0, 
+      1,  64,   0,   0,   0,   0, 
+      0,   0,   6, 224,  17,   0, 
+      0,   0,   0,   0,  31,   0, 
       4,   3,  58,   0,  16,   0, 
-      1,   0,   0,   0,  30,   0, 
+      2,   0,   0,   0,  30,   0, 
       0,   8,  34,   0,  16,   0, 
       0,   0,   0,   0,  26,   0, 
      16,   0,   0,   0,   0,   0, 
@@ -490,39 +466,37 @@ const BYTE g_DFT[] =
       0,   0,   2,   0,   0,   0, 
     167,   0,   0, 139,   2,  35, 
       0, 128, 131, 153,  25,   0, 
-     34,   0,  16,   0,   1,   0, 
+     34,   0,  16,   0,   3,   0, 
       0,   0,  26,   0,  16,   0, 
       0,   0,   0,   0,   1,  64, 
       0,   0,   0,   0,   0,   0, 
       6, 224,  17,   0,   0,   0, 
       0,   0,  18,   0,   0,   1, 
      54,   0,   0,   5,  34,   0, 
-     16,   0,   1,   0,   0,   0, 
+     16,   0,   3,   0,   0,   0, 
       1,  64,   0,   0,   0,   0, 
       0,   0,  21,   0,   0,   1, 
-     18,   0,   0,   1,  54,   0, 
-      0,   8,  50,   0,  16,   0, 
-      1,   0,   0,   0,   2,  64, 
-      0,   0,   0,   0,   0,   0, 
-      0,   0,   0,   0,   0,   0, 
-      0,   0,   0,   0,   0,   0, 
-     21,   0,   0,   1,  56,   0, 
+      1,   0,   0,   7,  98,   0, 
+     16,   0,   2,   0,   0,   0, 
+     86,   5,  16,   0,   2,   0, 
+      0,   0,   6,   1,  16,   0, 
+      3,   0,   0,   0,  56,   0, 
       0,   7,  34,   0,  16,   0, 
       0,   0,   0,   0,  26,   0, 
-     16,   0,   4,   0,   0,   0, 
-     26,   0,  16,   0,   1,   0, 
+     16,   0,   1,   0,   0,   0, 
+     42,   0,  16,   0,   2,   0, 
       0,   0,  50,   0,   0,  10, 
      34,   0,  16,   0,   0,   0, 
-      0,   0,  10,   0,  16,   0, 
-      1,   0,   0,   0,  10,   0, 
-     16,   0,   4,   0,   0,   0, 
+      0,   0,  26,   0,  16,   0, 
+      2,   0,   0,   0,  58,   0, 
+     16,   0,   1,   0,   0,   0, 
      26,   0,  16, 128,  65,   0, 
       0,   0,   0,   0,   0,   0, 
      15,   0,   0,   7,  18,   0, 
      16,   0,   1,   0,   0,   0, 
-     22,   5,  16,   0,   1,   0, 
-      0,   0,  70,   0,  16,   0, 
-      4,   0,   0,   0,  78,   0, 
+    150,   5,  16,   0,   2,   0, 
+      0,   0, 214,   5,  16,   0, 
+      1,   0,   0,   0,  78,   0, 
       0,   8,   0, 208,   0,   0, 
      34,   0,  16,   0,   1,   0, 
       0,   0,  10,   0,  16,   0, 
@@ -610,65 +584,77 @@ const BYTE g_DFT[] =
      16,   0,   0,   0,   0,   0, 
     166, 138,  32,   0,   0,   0, 
       0,   0,   7,   0,   0,   0, 
-     57,   0,   0,   8,  18,   0, 
+     24,   0,   0,   8,  18,   0, 
      16,   0,   1,   0,   0,   0, 
      26, 128,  32,   0,   0,   0, 
       0,   0,   7,   0,   0,   0, 
       1,  64,   0,   0,   0,   0, 
+      0,   0,  31,   0,   4,   3, 
+     10,   0,  16,   0,   1,   0, 
+      0,   0,  54,   0,   0,   5, 
+     18,   0,  16,   0,   1,   0, 
+      0,   0,  58,   0,  16,   0, 
+      0,   0,   0,   0,  18,   0, 
+      0,   1,  57,   0,   0,   8, 
+     66,   0,  16,   0,   1,   0, 
+      0,   0,  26, 128,  32,   0, 
+      0,   0,   0,   0,   7,   0, 
+      0,   0,   1,  64,   0,   0, 
+      0,   0,   0,   0,  56,   0, 
+      0,   7,  34,   0,  16,   0, 
+      1,   0,   0,   0,  26,   0, 
+     16,   0,   1,   0,   0,   0, 
+     26,   0,  16,   0,   1,   0, 
       0,   0,  56,   0,   0,   7, 
      34,   0,  16,   0,   1,   0, 
       0,   0,  26,   0,  16,   0, 
-      1,   0,   0,   0,  26,   0, 
-     16,   0,   1,   0,   0,   0, 
-     56,   0,   0,   7,  34,   0, 
+      1,   0,   0,   0,   1,  64, 
+      0,   0, 219,  15,  73,  64, 
+     14,   0,   0,   8,  34,   0, 
      16,   0,   1,   0,   0,   0, 
      26,   0,  16,   0,   1,   0, 
-      0,   0,   1,  64,   0,   0, 
-    219,  15,  73,  64,  14,   0, 
-      0,   8,  34,   0,  16,   0, 
-      1,   0,   0,   0,  26,   0, 
+      0,   0,  26, 128,  32,   0, 
+      0,   0,   0,   0,   7,   0, 
+      0,   0,  77,   0,   0,   7, 
+     18,   0,  16,   0,   2,   0, 
+      0,   0,  18,   0,  16,   0, 
+      3,   0,   0,   0,  26,   0, 
      16,   0,   1,   0,   0,   0, 
-     26, 128,  32,   0,   0,   0, 
-      0,   0,   7,   0,   0,   0, 
-     77,   0,   0,   7,  18,   0, 
+     54,   0,   0,   5,  34,   0, 
      16,   0,   2,   0,   0,   0, 
-     18,   0,  16,   0,   3,   0, 
-      0,   0,  26,   0,  16,   0, 
-      1,   0,   0,   0,  54,   0, 
-      0,   5,  34,   0,  16,   0, 
-      2,   0,   0,   0,  10,   0, 
-     16,   0,   3,   0,   0,   0, 
-     55,   0,   0,  12,  50,   0, 
-     16,   0,   1,   0,   0,   0, 
-      6,   0,  16,   0,   1,   0, 
-      0,   0,  70,   0,  16,   0, 
-      2,   0,   0,   0,   2,  64, 
+     10,   0,  16,   0,   3,   0, 
+      0,   0,  55,   0,   0,  12, 
+     98,   0,  16,   0,   1,   0, 
+      0,   0, 166,  10,  16,   0, 
+      1,   0,   0,   0,   6,   1, 
+     16,   0,   2,   0,   0,   0, 
+      2,  64,   0,   0,   0,   0, 
       0,   0,   0,   0,   0,   0, 
       0,   0, 128,  63,   0,   0, 
-      0,   0,   0,   0,   0,   0, 
-     56,   0,   0,   7, 194,   0, 
+      0,   0,  56,   0,   0,   7, 
+     50,   0,  16,   0,   2,   0, 
+      0,   0,  86,   5,  16,   0, 
+      0,   0,   0,   0, 150,   5, 
      16,   0,   1,   0,   0,   0, 
-     86,   5,  16,   0,   0,   0, 
-      0,   0,   6,   4,  16,   0, 
-      1,   0,   0,   0,  50,   0, 
-      0,  10,  34,   0,  16,   0, 
+     50,   0,   0,  10,  18,   0, 
+     16,   0,   1,   0,   0,   0, 
+     58,   0,  16,   0,   0,   0, 
+      0,   0,  42,   0,  16,   0, 
+      1,   0,   0,   0,  10,   0, 
+     16, 128,  65,   0,   0,   0, 
+      2,   0,   0,   0,  50,   0, 
+      0,   9,  34,   0,  16,   0, 
       0,   0,   0,   0,  58,   0, 
      16,   0,   0,   0,   0,   0, 
      26,   0,  16,   0,   1,   0, 
-      0,   0,  42,   0,  16, 128, 
-     65,   0,   0,   0,   1,   0, 
-      0,   0, 168,   0,   0,   9, 
+      0,   0,  26,   0,  16,   0, 
+      2,   0,   0,   0,  21,   0, 
+      0,   1, 168,   0,   0,   9, 
      18, 224,  17,   0,   1,   0, 
       0,   0,  42,   0,  16,   0, 
       0,   0,   0,   0,   1,  64, 
       0,   0,   0,   0,   0,   0, 
-     26,   0,  16,   0,   0,   0, 
-      0,   0,  50,   0,   0,   9, 
-     34,   0,  16,   0,   0,   0, 
-      0,   0,  58,   0,  16,   0, 
-      0,   0,   0,   0,  10,   0, 
-     16,   0,   1,   0,   0,   0, 
-     58,   0,  16,   0,   1,   0, 
+     10,   0,  16,   0,   1,   0, 
       0,   0, 168,   0,   0,   9, 
      18, 224,  17,   0,   1,   0, 
       0,   0,  10,   0,  16,   0, 
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/stockham_fp16.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/stockham_fp16.h
index 988c0aa66ade2..56ce759875687 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/stockham_fp16.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/stockham_fp16.h
@@ -15,7 +15,7 @@
 ; Name                 Index   Mask Register SysValue  Format   Used
 ; -------------------- ----- ------ -------- -------- ------- ------
 ; no parameters
-; shader hash: e08f21199c48b0db30bf21bd8c5b80dc
+; shader hash: 6a1d88feb14177832f5ee49ca330c549
 ;
 ; Pipeline Runtime Information: 
 ;
@@ -125,7 +125,7 @@ define void @DFT() {
   %47 = fpext half %46 to float
   %48 = extractvalue %dx.types.CBufRet.i32 %37, 3
   %49 = icmp eq i32 %48, 2
-  br i1 %49, label %50, label %56
+  br i1 %49, label %50, label %56, !dx.controlflow.hints !15
 
 ; <label>:50                                      ; preds = %41
   %51 = extractvalue %dx.types.CBufRet.i32 %42, 3
@@ -141,7 +141,7 @@ define void @DFT() {
   %59 = call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %4, i32 1)  ; CBufferLoadLegacy(handle,regIndex)
   %60 = extractvalue %dx.types.CBufRet.i32 %59, 1
   %61 = icmp ult i32 %33, %60
-  br i1 %61, label %62, label %83, !dx.controlflow.hints !15
+  br i1 %61, label %62, label %83, !dx.controlflow.hints !16
 
 ; <label>:62                                      ; preds = %56
   %63 = call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %4, i32 2)  ; CBufferLoadLegacy(handle,regIndex)
@@ -158,7 +158,7 @@ define void @DFT() {
   %74 = fpext half %73 to float
   %75 = extractvalue %dx.types.CBufRet.i32 %59, 3
   %76 = icmp eq i32 %75, 2
-  br i1 %76, label %77, label %83, !dx.controlflow.hints !16
+  br i1 %76, label %77, label %83, !dx.controlflow.hints !17
 
 ; <label>:77                                      ; preds = %62
   %78 = extractvalue %dx.types.CBufRet.i32 %63, 3
@@ -188,7 +188,7 @@ define void @DFT() {
   %98 = fpext half %97 to float
   %99 = extractvalue %dx.types.CBufRet.i32 %37, 3
   %100 = icmp eq i32 %99, 2
-  br i1 %100, label %101, label %107
+  br i1 %100, label %101, label %107, !dx.controlflow.hints !15
 
 ; <label>:101                                     ; preds = %92
   %102 = extractvalue %dx.types.CBufRet.i32 %93, 3
@@ -202,7 +202,7 @@ define void @DFT() {
   %108 = phi float [ %98, %101 ], [ %98, %92 ], [ 1.000000e+00, %83 ]
   %109 = phi float [ %106, %101 ], [ 0.000000e+00, %92 ], [ 0.000000e+00, %83 ]
   %110 = icmp ult i32 %34, %60
-  br i1 %110, label %111, label %132, !dx.controlflow.hints !15
+  br i1 %110, label %111, label %132, !dx.controlflow.hints !16
 
 ; <label>:111                                     ; preds = %107
   %112 = call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %4, i32 2)  ; CBufferLoadLegacy(handle,regIndex)
@@ -219,7 +219,7 @@ define void @DFT() {
   %123 = fpext half %122 to float
   %124 = extractvalue %dx.types.CBufRet.i32 %59, 3
   %125 = icmp eq i32 %124, 2
-  br i1 %125, label %126, label %132, !dx.controlflow.hints !16
+  br i1 %125, label %126, label %132, !dx.controlflow.hints !17
 
 ; <label>:126                                     ; preds = %111
   %127 = extractvalue %dx.types.CBufRet.i32 %112, 3
@@ -270,19 +270,21 @@ define void @DFT() {
   %170 = fmul fast float %158, %169
   %171 = extractvalue %dx.types.CBufRet.f32 %157, 1
   %172 = fcmp fast oeq float %171, 0.000000e+00
-  br i1 %172, label %179, label %173
+  br i1 %172, label %173, label %176, !dx.controlflow.hints !18
 
 ; <label>:173                                     ; preds = %132
-  %174 = fmul fast float %146, %146
-  %175 = fmul fast float %174, 0x400921FB60000000
-  %176 = fdiv fast float %175, %171
-  %177 = call float @dx.op.unary.f32(i32 12, float %176)  ; Cos(value)
-  %178 = call float @dx.op.unary.f32(i32 13, float %176)  ; Sin(value)
-  br label %179
+  %174 = fptrunc float %164 to half
+  call void @dx.op.rawBufferStore.f16(i32 140, %dx.types.Handle %2, i32 %154, i32 0, half %174, half undef, half undef, half undef, i8 1, i32 2)  ; RawBufferStore(uav,index,elementOffset,value0,value1,value2,value3,mask,alignment)
+  %175 = fptrunc float %170 to half
+  call void @dx.op.rawBufferStore.f16(i32 140, %dx.types.Handle %2, i32 %156, i32 0, half %175, half undef, half undef, half undef, i8 1, i32 2)  ; RawBufferStore(uav,index,elementOffset,value0,value1,value2,value3,mask,alignment)
+  br label %190
 
-; <label>:179                                     ; preds = %173, %132
-  %180 = phi float [ %177, %173 ], [ 1.000000e+00, %132 ]
-  %181 = phi float [ %178, %173 ], [ 0.000000e+00, %132 ]
+; <label>:176                                     ; preds = %132
+  %177 = fmul fast float %146, %146
+  %178 = fmul fast float %177, 0x400921FB60000000
+  %179 = fdiv fast float %178, %171
+  %180 = call float @dx.op.unary.f32(i32 12, float %179)  ; Cos(value)
+  %181 = call float @dx.op.unary.f32(i32 13, float %179)  ; Sin(value)
   %182 = fmul fast float %180, %164
   %183 = fmul fast float %181, %170
   %184 = fsub fast float %182, %183
@@ -295,7 +297,7 @@ define void @DFT() {
   call void @dx.op.rawBufferStore.f16(i32 140, %dx.types.Handle %2, i32 %156, i32 0, half %189, half undef, half undef, half undef, i8 1, i32 2)  ; RawBufferStore(uav,index,elementOffset,value0,value1,value2,value3,mask,alignment)
   br label %190
 
-; <label>:190                                     ; preds = %179, %0
+; <label>:190                                     ; preds = %176, %173, %0
   ret void
 }
 
@@ -345,16 +347,18 @@ attributes #2 = { nounwind }
 !11 = !{void ()* @DFT, !"DFT", null, !4, !12}
 !12 = !{i32 0, i64 8388656, i32 4, !13}
 !13 = !{i32 64, i32 1, i32 1}
-!14 = distinct !{!14, !"dx.controlflow.hints", i32 1}
+!14 = distinct !{!14, !"dx.controlflow.hints", i32 2}
 !15 = distinct !{!15, !"dx.controlflow.hints", i32 1}
-!16 = distinct !{!16, !"dx.controlflow.hints", i32 1}
+!16 = distinct !{!16, !"dx.controlflow.hints", i32 2}
+!17 = distinct !{!17, !"dx.controlflow.hints", i32 1}
+!18 = distinct !{!18, !"dx.controlflow.hints", i32 1}
 
 #endif
 
 const unsigned char g_DFT[] = {
-  0x44, 0x58, 0x42, 0x43, 0x0f, 0xc1, 0xea, 0x65, 0x6d, 0xe3, 0x8d, 0x13,
-  0x2c, 0xb2, 0x19, 0xb3, 0xd4, 0xb1, 0x94, 0xb9, 0x01, 0x00, 0x00, 0x00,
-  0xfc, 0x0b, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00,
+  0x44, 0x58, 0x42, 0x43, 0x12, 0x40, 0x8a, 0x15, 0xf2, 0x7d, 0x33, 0xd8,
+  0x35, 0x6a, 0x11, 0xd5, 0x43, 0xa1, 0x29, 0x3b, 0x01, 0x00, 0x00, 0x00,
+  0x3c, 0x0c, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00,
   0x48, 0x00, 0x00, 0x00, 0x58, 0x00, 0x00, 0x00, 0x68, 0x00, 0x00, 0x00,
   0x18, 0x01, 0x00, 0x00, 0x34, 0x01, 0x00, 0x00, 0x53, 0x46, 0x49, 0x30,
   0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00,
@@ -376,12 +380,12 @@ const unsigned char g_DFT[] = {
   0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00,
   0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
   0x00, 0x00, 0x00, 0x00, 0x48, 0x41, 0x53, 0x48, 0x14, 0x00, 0x00, 0x00,
-  0x00, 0x00, 0x00, 0x00, 0xe0, 0x8f, 0x21, 0x19, 0x9c, 0x48, 0xb0, 0xdb,
-  0x30, 0xbf, 0x21, 0xbd, 0x8c, 0x5b, 0x80, 0xdc, 0x44, 0x58, 0x49, 0x4c,
-  0xc0, 0x0a, 0x00, 0x00, 0x62, 0x00, 0x05, 0x00, 0xb0, 0x02, 0x00, 0x00,
+  0x00, 0x00, 0x00, 0x00, 0x6a, 0x1d, 0x88, 0xfe, 0xb1, 0x41, 0x77, 0x83,
+  0x2f, 0x5e, 0xe4, 0x9c, 0xa3, 0x30, 0xc5, 0x49, 0x44, 0x58, 0x49, 0x4c,
+  0x00, 0x0b, 0x00, 0x00, 0x62, 0x00, 0x05, 0x00, 0xc0, 0x02, 0x00, 0x00,
   0x44, 0x58, 0x49, 0x4c, 0x02, 0x01, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
-  0xa8, 0x0a, 0x00, 0x00, 0x42, 0x43, 0xc0, 0xde, 0x21, 0x0c, 0x00, 0x00,
-  0xa7, 0x02, 0x00, 0x00, 0x0b, 0x82, 0x20, 0x00, 0x02, 0x00, 0x00, 0x00,
+  0xe8, 0x0a, 0x00, 0x00, 0x42, 0x43, 0xc0, 0xde, 0x21, 0x0c, 0x00, 0x00,
+  0xb7, 0x02, 0x00, 0x00, 0x0b, 0x82, 0x20, 0x00, 0x02, 0x00, 0x00, 0x00,
   0x13, 0x00, 0x00, 0x00, 0x07, 0x81, 0x23, 0x91, 0x41, 0xc8, 0x04, 0x49,
   0x06, 0x10, 0x32, 0x39, 0x92, 0x01, 0x84, 0x0c, 0x25, 0x05, 0x08, 0x19,
   0x1e, 0x04, 0x8b, 0x62, 0x80, 0x18, 0x45, 0x02, 0x42, 0x92, 0x0b, 0x42,
@@ -441,7 +445,7 @@ const unsigned char g_DFT[] = {
   0x4a, 0xa0, 0x08, 0x8a, 0x61, 0x04, 0xa0, 0x30, 0x0a, 0x50, 0xa0, 0x10,
   0x0a, 0x30, 0x80, 0xb0, 0x11, 0x00, 0x0a, 0x0b, 0x1c, 0x10, 0x10, 0x81,
   0xc0, 0x19, 0x00, 0xea, 0x66, 0x00, 0x00, 0x00, 0x79, 0x18, 0x00, 0x00,
-  0x4f, 0x00, 0x00, 0x00, 0x1a, 0x03, 0x4c, 0x90, 0x46, 0x02, 0x13, 0x44,
+  0x52, 0x00, 0x00, 0x00, 0x1a, 0x03, 0x4c, 0x90, 0x46, 0x02, 0x13, 0x44,
   0x35, 0x18, 0x63, 0x0b, 0x73, 0x3b, 0x03, 0xb1, 0x2b, 0x93, 0x9b, 0x4b,
   0x7b, 0x73, 0x03, 0x99, 0x71, 0xb9, 0x01, 0x41, 0xa1, 0x0b, 0x3b, 0x9b,
   0x7b, 0x91, 0x2a, 0x62, 0x2a, 0x0a, 0x9a, 0x2a, 0xfa, 0x9a, 0xb9, 0x81,
@@ -458,16 +462,17 @@ const unsigned char g_DFT[] = {
   0x70, 0x26, 0x08, 0xc3, 0xb3, 0x61, 0xe0, 0x86, 0x61, 0x03, 0xa1, 0x68,
   0x5b, 0xb7, 0xa1, 0xc0, 0x32, 0xe0, 0xf2, 0x48, 0x91, 0xe1, 0xb9, 0x8c,
   0xbd, 0xb9, 0xd1, 0xc9, 0xbd, 0xb1, 0x99, 0xb1, 0xbd, 0xdd, 0xb9, 0xa0,
-  0xa5, 0xb9, 0xd1, 0xcd, 0xad, 0x18, 0xc2, 0x00, 0x0c, 0x86, 0x15, 0x83,
-  0x18, 0x80, 0xc1, 0xb0, 0x62, 0x18, 0x03, 0x30, 0x18, 0xaa, 0xb0, 0xb1,
-  0xd9, 0xb5, 0xb9, 0xa4, 0x91, 0x95, 0xb9, 0xd1, 0x4d, 0x09, 0x82, 0x2a,
-  0x64, 0x78, 0x2e, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x53, 0x02,
-  0xa2, 0x09, 0x19, 0x9e, 0x8b, 0x5d, 0x18, 0x9b, 0x5d, 0x99, 0xdc, 0x94,
-  0xc0, 0xa8, 0x43, 0x86, 0xe7, 0x32, 0x87, 0x16, 0x46, 0x56, 0x26, 0xd7,
-  0xf4, 0x46, 0x56, 0xc6, 0x36, 0x25, 0x40, 0xca, 0x90, 0xe1, 0xb9, 0xc8,
-  0x95, 0xcd, 0xbd, 0xd5, 0xc9, 0x8d, 0x95, 0xcd, 0x4d, 0x09, 0xac, 0x3a,
-  0x64, 0x78, 0x2e, 0x65, 0x6e, 0x74, 0x72, 0x79, 0x50, 0x6f, 0x69, 0x6e,
-  0x74, 0x73, 0x53, 0x02, 0x0f, 0x00, 0x00, 0x00, 0x79, 0x18, 0x00, 0x00,
+  0xa5, 0xb9, 0xd1, 0xcd, 0xad, 0x18, 0xc2, 0x00, 0x0c, 0x88, 0x15, 0x83,
+  0x18, 0x80, 0xc1, 0xb0, 0x62, 0x18, 0x03, 0x30, 0x20, 0x56, 0x0c, 0x64,
+  0x00, 0x06, 0xc3, 0x8a, 0xa1, 0x0c, 0xc0, 0x60, 0xa8, 0xc2, 0xc6, 0x66,
+  0xd7, 0xe6, 0x92, 0x46, 0x56, 0xe6, 0x46, 0x37, 0x25, 0x08, 0xaa, 0x90,
+  0xe1, 0xb9, 0xd8, 0x95, 0xc9, 0xcd, 0xa5, 0xbd, 0xb9, 0x4d, 0x09, 0x88,
+  0x26, 0x64, 0x78, 0x2e, 0x76, 0x61, 0x6c, 0x76, 0x65, 0x72, 0x53, 0x02,
+  0xa3, 0x0e, 0x19, 0x9e, 0xcb, 0x1c, 0x5a, 0x18, 0x59, 0x99, 0x5c, 0xd3,
+  0x1b, 0x59, 0x19, 0xdb, 0x94, 0x00, 0x29, 0x43, 0x86, 0xe7, 0x22, 0x57,
+  0x36, 0xf7, 0x56, 0x27, 0x37, 0x56, 0x36, 0x37, 0x25, 0xb0, 0xea, 0x90,
+  0xe1, 0xb9, 0x94, 0xb9, 0xd1, 0xc9, 0xe5, 0x41, 0xbd, 0xa5, 0xb9, 0xd1,
+  0xcd, 0x4d, 0x09, 0x3c, 0x00, 0x00, 0x00, 0x00, 0x79, 0x18, 0x00, 0x00,
   0x59, 0x00, 0x00, 0x00, 0x33, 0x08, 0x80, 0x1c, 0xc4, 0xe1, 0x1c, 0x66,
   0x14, 0x01, 0x3d, 0x88, 0x43, 0x38, 0x84, 0xc3, 0x8c, 0x42, 0x80, 0x07,
   0x79, 0x78, 0x07, 0x73, 0x98, 0x71, 0x0c, 0xe6, 0x00, 0x0f, 0xed, 0x10,
@@ -510,9 +515,9 @@ const unsigned char g_DFT[] = {
   0x13, 0x11, 0x7e, 0x51, 0xeb, 0x16, 0x20, 0x0d, 0x97, 0xef, 0x3c, 0xfe,
   0x74, 0x44, 0x04, 0x30, 0x88, 0x83, 0x8f, 0xdc, 0xb6, 0x09, 0x3c, 0xc3,
   0xe5, 0x3b, 0x8f, 0x4f, 0x35, 0x40, 0x84, 0xf9, 0xc5, 0x6d, 0x03, 0x00,
-  0x61, 0x20, 0x00, 0x00, 0x22, 0x01, 0x00, 0x00, 0x13, 0x04, 0x51, 0x2c,
+  0x61, 0x20, 0x00, 0x00, 0x2f, 0x01, 0x00, 0x00, 0x13, 0x04, 0x51, 0x2c,
   0x10, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x34, 0x94, 0x5d, 0x59,
-  0x0a, 0x94, 0x5c, 0xf9, 0x94, 0x43, 0x0d, 0x94, 0x46, 0x61, 0x0a, 0x94,
+  0x0a, 0x94, 0x5c, 0x61, 0x0a, 0x94, 0x4f, 0x39, 0xd4, 0x40, 0x69, 0x94,
   0x6e, 0x40, 0x19, 0x94, 0x02, 0x2d, 0x45, 0x50, 0x02, 0x64, 0x8c, 0x11,
   0xec, 0xfe, 0x28, 0xb3, 0x60, 0x30, 0x46, 0xb0, 0xfb, 0xa3, 0xcc, 0x82,
   0xc3, 0x18, 0xc1, 0xee, 0x8f, 0x32, 0x09, 0x06, 0x94, 0xcc, 0x00, 0x90,
@@ -525,87 +530,91 @@ const unsigned char g_DFT[] = {
   0x3c, 0x60, 0x30, 0x62, 0x70, 0x00, 0x20, 0x08, 0x06, 0xd3, 0x19, 0x60,
   0x42, 0x18, 0x8c, 0x26, 0x04, 0x40, 0x05, 0x03, 0x8c, 0x26, 0x0c, 0xc1,
   0x70, 0x83, 0x10, 0x90, 0xc1, 0x2c, 0x43, 0x00, 0x05, 0x23, 0x06, 0x07,
-  0x00, 0x82, 0x60, 0x30, 0xb1, 0x41, 0x77, 0x64, 0xa3, 0x09, 0xc1, 0x50,
-  0xc1, 0x1a, 0xe0, 0x68, 0x02, 0x22, 0x54, 0xe0, 0x69, 0xb9, 0x41, 0x70,
-  0x35, 0x80, 0x01, 0x54, 0x10, 0x06, 0x6a, 0x71, 0x10, 0x5c, 0x60, 0xc4,
-  0xe0, 0x00, 0x40, 0x10, 0x0c, 0xa6, 0x3a, 0x30, 0x03, 0xa8, 0x1b, 0x4d,
-  0x08, 0x82, 0xd1, 0x04, 0x41, 0xa8, 0x40, 0x90, 0x82, 0x82, 0xaa, 0x48,
-  0x98, 0x12, 0x88, 0xa9, 0xa1, 0xa8, 0x12, 0x1a, 0xac, 0x60, 0xb9, 0x5a,
-  0xd6, 0x00, 0xaa, 0x08, 0xb4, 0x86, 0x00, 0x2a, 0xa0, 0x60, 0x34, 0xe1,
-  0x02, 0x86, 0x1b, 0x82, 0x50, 0x00, 0x83, 0x11, 0x83, 0x03, 0x00, 0x41,
-  0x30, 0x98, 0x4a, 0xc1, 0x0e, 0xc0, 0x80, 0x0d, 0x46, 0x13, 0x02, 0x61,
-  0xb8, 0xc1, 0x08, 0xc8, 0xa0, 0x88, 0x40, 0x67, 0x19, 0x04, 0x22, 0x18,
-  0x31, 0x38, 0x00, 0x10, 0x04, 0x83, 0x29, 0x15, 0xf4, 0x80, 0x0c, 0x48,
-  0x61, 0x34, 0x21, 0x10, 0x2a, 0x50, 0x64, 0xc4, 0x40, 0x01, 0x40, 0x10,
-  0x0c, 0x1c, 0x57, 0xe0, 0x03, 0x35, 0x08, 0x4c, 0x21, 0x0e, 0x50, 0x61,
-  0x34, 0x21, 0x00, 0x2e, 0x30, 0x70, 0x34, 0x41, 0x19, 0x86, 0x1b, 0x02,
-  0x56, 0x00, 0x83, 0x59, 0x86, 0x81, 0x08, 0x46, 0x13, 0x90, 0xa1, 0x82,
-  0x03, 0x46, 0x0c, 0x14, 0x00, 0x04, 0xc1, 0xc0, 0xa9, 0x85, 0x51, 0x88,
-  0x83, 0xa0, 0x15, 0xf0, 0xe0, 0x15, 0x46, 0x13, 0x02, 0xe0, 0x02, 0x03,
-  0x67, 0x09, 0x88, 0x81, 0x0e, 0x03, 0x1a, 0x20, 0x81, 0x1d, 0x82, 0x81,
-  0x0e, 0x83, 0x18, 0xd8, 0x41, 0x60, 0x87, 0x60, 0xc4, 0xe0, 0x00, 0x40,
-  0x10, 0x0c, 0x26, 0x5c, 0x48, 0x85, 0x39, 0xa8, 0x85, 0xd1, 0x84, 0x20,
-  0x18, 0x6e, 0xc8, 0x02, 0x32, 0x98, 0x65, 0x28, 0x8e, 0x60, 0xc4, 0xe0,
-  0x00, 0x40, 0x10, 0x0c, 0xa6, 0x5d, 0x60, 0x05, 0x3b, 0xb8, 0x85, 0xd1,
-  0x84, 0x00, 0xa8, 0x60, 0x0c, 0x64, 0x34, 0x61, 0x08, 0x2a, 0xf0, 0xa4,
-  0x82, 0x01, 0x46, 0x13, 0x0c, 0xa1, 0x02, 0x33, 0x90, 0x1a, 0x02, 0x18,
-  0x31, 0x50, 0x00, 0x10, 0x04, 0x03, 0xc7, 0x1c, 0x68, 0x01, 0x14, 0x02,
-  0x5f, 0x48, 0x05, 0x70, 0x18, 0x4d, 0x08, 0x80, 0x0b, 0x0c, 0x1c, 0x4d,
-  0x78, 0x86, 0xe1, 0x86, 0x80, 0x1c, 0xc0, 0x60, 0x96, 0xc1, 0x38, 0x82,
-  0xd1, 0x04, 0x67, 0xa8, 0xe0, 0x80, 0x11, 0x03, 0x05, 0x00, 0x41, 0x30,
-  0x70, 0xda, 0x61, 0x17, 0x4e, 0x21, 0x28, 0x07, 0x58, 0x38, 0x87, 0xd1,
-  0x84, 0x00, 0xb8, 0xc0, 0xc0, 0x59, 0x82, 0x63, 0xa0, 0xc3, 0x80, 0x0c,
-  0xa8, 0xd0, 0x09, 0x62, 0xa0, 0xc3, 0x20, 0x0c, 0x9e, 0x28, 0x78, 0x82,
-  0x30, 0x41, 0x93, 0x8f, 0x09, 0x9a, 0x7c, 0x8c, 0xd8, 0xe4, 0x63, 0x44,
-  0x27, 0x9f, 0xe1, 0x06, 0x39, 0x70, 0x03, 0x32, 0xa8, 0x38, 0x08, 0x74,
-  0x96, 0x01, 0x51, 0x82, 0x11, 0x83, 0x03, 0x00, 0x41, 0x30, 0x98, 0xec,
-  0xe1, 0x1c, 0x62, 0x21, 0x1e, 0x46, 0x13, 0x02, 0xa1, 0x02, 0x3b, 0x90,
-  0x11, 0x03, 0x05, 0x00, 0x41, 0x30, 0x70, 0xf6, 0x21, 0x1d, 0x6e, 0x21,
-  0x98, 0x07, 0x5f, 0xa8, 0x87, 0xd1, 0x84, 0x00, 0xb8, 0xc0, 0xc0, 0xd1,
-  0x84, 0x3b, 0x18, 0x86, 0x1b, 0x82, 0x7c, 0x00, 0x83, 0x59, 0x86, 0x44,
-  0x09, 0x46, 0x13, 0x90, 0xa1, 0x82, 0x03, 0x46, 0x0c, 0x14, 0x00, 0x04,
-  0xc1, 0xc0, 0x11, 0x09, 0x78, 0xf0, 0x85, 0x40, 0x1f, 0xca, 0x81, 0x1f,
-  0x46, 0x13, 0x02, 0xe0, 0x02, 0x03, 0x67, 0x09, 0x94, 0x81, 0x0e, 0x03,
-  0x4a, 0x20, 0xc4, 0x34, 0x8e, 0x81, 0x0e, 0x83, 0x48, 0x4c, 0x03, 0x31,
-  0x8d, 0x63, 0xb8, 0x61, 0x14, 0xd8, 0x80, 0x0c, 0x66, 0x19, 0x96, 0x26,
-  0x18, 0x31, 0x38, 0x00, 0x10, 0x04, 0x83, 0xc9, 0x24, 0xee, 0x21, 0x1c,
-  0x44, 0x62, 0x34, 0x21, 0x00, 0x2a, 0x70, 0x05, 0x19, 0x4d, 0x18, 0x82,
-  0x0a, 0x50, 0x41, 0x2a, 0x18, 0x60, 0x34, 0xc1, 0x10, 0x2a, 0x88, 0x05,
-  0xa9, 0x21, 0x80, 0x11, 0x03, 0x05, 0x00, 0x41, 0x30, 0x70, 0x62, 0xe2,
-  0x1f, 0xd6, 0x21, 0x48, 0x09, 0x7a, 0x58, 0x89, 0xd1, 0x84, 0x00, 0xb8,
-  0xc0, 0xc0, 0xd1, 0x04, 0x3d, 0x18, 0x86, 0x1b, 0x82, 0x97, 0x00, 0x83,
-  0x59, 0x06, 0xa6, 0x09, 0x46, 0x13, 0x9c, 0xa1, 0x82, 0x03, 0x46, 0x0c,
-  0x14, 0x00, 0x04, 0xc1, 0xc0, 0xc1, 0x09, 0x93, 0x90, 0x87, 0x00, 0x26,
-  0xf6, 0x41, 0x26, 0x46, 0x13, 0x02, 0xe0, 0x02, 0x03, 0x67, 0x09, 0x9a,
-  0x81, 0x0e, 0x03, 0x62, 0xa0, 0x05, 0x3e, 0x94, 0x81, 0x0e, 0x83, 0x60,
-  0xe4, 0x63, 0x91, 0x0f, 0xc5, 0x04, 0x4c, 0x3e, 0x26, 0x60, 0xf2, 0x31,
-  0x21, 0x88, 0x8f, 0x15, 0x9a, 0x7c, 0xac, 0xe0, 0xe4, 0x63, 0x81, 0x00,
-  0x9f, 0x82, 0x87, 0x96, 0x80, 0x3a, 0x87, 0x40, 0x47, 0x13, 0xf8, 0x61,
-  0x18, 0x6e, 0x08, 0xc2, 0x02, 0x0c, 0xa6, 0x1b, 0x52, 0x02, 0x25, 0x82,
-  0x23, 0x8c, 0x32, 0x21, 0x90, 0xcf, 0xdd, 0x83, 0x51, 0x26, 0x04, 0xf4,
-  0x19, 0x31, 0x30, 0x00, 0x10, 0x04, 0x83, 0xa3, 0x2d, 0xc6, 0x22, 0x18,
-  0x31, 0x30, 0x00, 0x10, 0x04, 0x83, 0xc3, 0x2d, 0x6c, 0x42, 0x18, 0x31,
-  0x38, 0x00, 0x10, 0x04, 0x83, 0x89, 0x2d, 0x7a, 0xe2, 0x24, 0xc6, 0x62,
-  0x34, 0x21, 0x10, 0x2a, 0x28, 0x09, 0x19, 0x4d, 0x18, 0x86, 0x12, 0x02,
-  0x18, 0x31, 0x38, 0x00, 0x10, 0x04, 0x03, 0x0b, 0x2e, 0xc4, 0x82, 0x25,
-  0x7c, 0x62, 0x34, 0x21, 0x10, 0x2c, 0xb1, 0xe4, 0x63, 0x09, 0x25, 0x1f,
-  0x2b, 0x05, 0x52, 0x88, 0x8f, 0x05, 0x03, 0x7c, 0x2c, 0x18, 0xe2, 0x63,
-  0x46, 0x20, 0x1f, 0x7b, 0x32, 0xf9, 0xd8, 0xd3, 0xc9, 0xc7, 0x50, 0x21,
-  0x15, 0xe0, 0x63, 0xc1, 0x00, 0x1f, 0x0b, 0x06, 0xf8, 0x18, 0x13, 0xc8,
-  0x67, 0x34, 0xc1, 0x09, 0x86, 0x23, 0x82, 0x9f, 0x08, 0xbe, 0x59, 0x86,
-  0xc7, 0x09, 0x6c, 0xdb, 0xe4, 0x63, 0x01, 0x59, 0xc8, 0xc7, 0x02, 0x82,
-  0x3e, 0x23, 0x06, 0x06, 0x00, 0x82, 0x60, 0x70, 0x9c, 0x46, 0x5f, 0x04,
-  0x23, 0x06, 0x06, 0x00, 0x82, 0x60, 0x70, 0xa0, 0x06, 0x5c, 0x08, 0xb3,
-  0x04, 0xcf, 0x40, 0x85, 0x41, 0x38, 0xac, 0xd2, 0x0c, 0x54, 0x18, 0x84,
-  0xc3, 0x2a, 0x8d, 0x09, 0x90, 0x7c, 0x4c, 0x58, 0xe4, 0x63, 0x42, 0x10,
-  0x9f, 0x0b, 0x92, 0x1b, 0x31, 0x70, 0x00, 0x10, 0x04, 0x03, 0xa8, 0x35,
-  0xec, 0x22, 0x2d, 0x3c, 0xd3, 0x08, 0xda, 0xa2, 0x2d, 0xda, 0x22, 0x2e,
-  0x50, 0xc3, 0x0a, 0x4a, 0x3e, 0x76, 0x3c, 0xf2, 0x31, 0x21, 0x80, 0xcf,
-  0x05, 0xc9, 0x8d, 0x18, 0x38, 0x00, 0x08, 0x82, 0x01, 0x14, 0x1b, 0x7a,
-  0xd1, 0x16, 0x60, 0xa0, 0x1a, 0x41, 0x5c, 0xc4, 0x45, 0x5c, 0xd4, 0x05,
-  0x6b, 0xcc, 0x12, 0x40, 0x18, 0x10, 0x03, 0x00, 0x09, 0x00, 0x00, 0x00,
-  0x5b, 0x06, 0x34, 0x78, 0xc0, 0x60, 0xcb, 0xd0, 0x07, 0x4f, 0x18, 0x6c,
-  0x19, 0x58, 0xe1, 0x11, 0x83, 0x2d, 0xc3, 0x2e, 0x3c, 0x60, 0xb0, 0x65,
-  0x70, 0x87, 0x27, 0x0c, 0xb6, 0x0c, 0xfd, 0xf0, 0x88, 0x01, 0x00, 0x00,
-  0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00
+  0x00, 0x82, 0x60, 0x30, 0xb1, 0x41, 0x77, 0x60, 0xa3, 0x09, 0xc1, 0x50,
+  0xc1, 0x1a, 0xe0, 0x68, 0x02, 0x22, 0x54, 0xd0, 0x69, 0xb9, 0x41, 0x70,
+  0x35, 0x7c, 0x50, 0x01, 0x18, 0xa8, 0xc5, 0x41, 0x70, 0x81, 0x11, 0x83,
+  0x03, 0x00, 0x41, 0x30, 0x98, 0xea, 0xc0, 0x0c, 0xa0, 0x6e, 0x34, 0x21,
+  0x08, 0x46, 0x13, 0x04, 0xa1, 0x02, 0x41, 0x0a, 0x0a, 0xaa, 0x22, 0x61,
+  0x4a, 0x20, 0xa6, 0x86, 0xa2, 0x4a, 0x68, 0xb0, 0x82, 0xe5, 0x6a, 0x51,
+  0x03, 0xa8, 0x22, 0xd0, 0x1a, 0x02, 0xa8, 0x80, 0x82, 0xd1, 0x84, 0x0b,
+  0x18, 0x6e, 0x08, 0x42, 0x01, 0x0c, 0x46, 0x0c, 0x0e, 0x00, 0x04, 0xc1,
+  0x60, 0x2a, 0x05, 0x3b, 0x00, 0x03, 0x36, 0x18, 0x4d, 0x08, 0x84, 0xe1,
+  0x06, 0x23, 0x20, 0x83, 0x22, 0x02, 0x9d, 0x65, 0x10, 0x88, 0x60, 0xc4,
+  0xe0, 0x00, 0x40, 0x10, 0x0c, 0xa6, 0x54, 0xd0, 0x03, 0x32, 0x20, 0x85,
+  0xd1, 0x84, 0x40, 0xa8, 0x40, 0x91, 0x11, 0x03, 0x05, 0x00, 0x41, 0x30,
+  0x70, 0x5c, 0x81, 0x0f, 0xd4, 0x20, 0x30, 0x85, 0x38, 0x40, 0x85, 0xd1,
+  0x84, 0x00, 0xb8, 0xc0, 0xc0, 0xd1, 0x04, 0x65, 0x18, 0x6e, 0x08, 0x58,
+  0x01, 0x0c, 0x66, 0x19, 0x06, 0x22, 0x18, 0x4d, 0x40, 0x86, 0x0a, 0x0e,
+  0x18, 0x31, 0x50, 0x00, 0x10, 0x04, 0x03, 0xa7, 0x16, 0x46, 0x21, 0x0e,
+  0x82, 0x56, 0xc0, 0x83, 0x57, 0x18, 0x4d, 0x08, 0x80, 0x0b, 0x0c, 0x9c,
+  0x25, 0x20, 0x06, 0x3a, 0x0c, 0x68, 0x80, 0x04, 0x76, 0x08, 0x06, 0x3a,
+  0x0c, 0x62, 0x60, 0x07, 0x81, 0x1d, 0x82, 0x11, 0x83, 0x03, 0x00, 0x41,
+  0x30, 0x98, 0x70, 0x21, 0x15, 0xe6, 0xa0, 0x16, 0x46, 0x13, 0x82, 0x60,
+  0xb8, 0x21, 0x0b, 0xc8, 0x60, 0x96, 0xa1, 0x38, 0x82, 0x11, 0x83, 0x03,
+  0x00, 0x41, 0x30, 0x98, 0x76, 0x81, 0x15, 0xec, 0xe0, 0x16, 0x46, 0x13,
+  0x02, 0xa0, 0x82, 0x31, 0x90, 0xd1, 0x84, 0x21, 0xa8, 0xc0, 0x93, 0x0a,
+  0x06, 0x18, 0x4d, 0x30, 0x84, 0x0a, 0xcc, 0x40, 0x6a, 0x08, 0x60, 0xc4,
+  0x40, 0x01, 0x40, 0x10, 0x0c, 0x1c, 0x73, 0xa0, 0x05, 0x50, 0x08, 0x7c,
+  0x21, 0x15, 0xc0, 0x61, 0x34, 0x21, 0x00, 0x2e, 0x30, 0x70, 0x34, 0xe1,
+  0x19, 0x86, 0x1b, 0x02, 0x72, 0x00, 0x83, 0x59, 0x06, 0xe3, 0x08, 0x46,
+  0x13, 0x9c, 0xa1, 0x82, 0x03, 0x46, 0x0c, 0x14, 0x00, 0x04, 0xc1, 0xc0,
+  0x69, 0x87, 0x5d, 0x38, 0x85, 0xa0, 0x1c, 0x60, 0xe1, 0x1c, 0x46, 0x13,
+  0x02, 0xe0, 0x02, 0x03, 0x67, 0x09, 0x8e, 0x81, 0x0e, 0x03, 0x32, 0xa0,
+  0x42, 0x27, 0x88, 0x81, 0x0e, 0x83, 0x30, 0x78, 0xa2, 0xe0, 0x09, 0xc2,
+  0x04, 0x4d, 0x3e, 0x26, 0x68, 0xf2, 0x31, 0x62, 0x93, 0x8f, 0x11, 0x9d,
+  0x7c, 0x86, 0x1b, 0xe4, 0xc0, 0x0d, 0xc8, 0xa0, 0xe2, 0x20, 0xd0, 0x59,
+  0x06, 0x44, 0x09, 0x46, 0x0c, 0x0e, 0x00, 0x04, 0xc1, 0x60, 0xb2, 0x87,
+  0x73, 0x88, 0x85, 0x78, 0x18, 0x4d, 0x08, 0x84, 0x0a, 0xec, 0x40, 0x46,
+  0x0c, 0x14, 0x00, 0x04, 0xc1, 0xc0, 0xd9, 0x87, 0x74, 0xb8, 0x85, 0x60,
+  0x1e, 0x7c, 0xa1, 0x1e, 0x46, 0x13, 0x02, 0xe0, 0x02, 0x03, 0x47, 0x13,
+  0xee, 0x60, 0x18, 0x6e, 0x08, 0xf2, 0x01, 0x0c, 0x66, 0x19, 0x12, 0x25,
+  0x18, 0x4d, 0x40, 0x86, 0x0a, 0x0e, 0x18, 0x31, 0x50, 0x00, 0x10, 0x04,
+  0x03, 0x47, 0x24, 0xe0, 0xc1, 0x17, 0x02, 0x7d, 0x28, 0x07, 0x7e, 0x18,
+  0x4d, 0x08, 0x80, 0x0b, 0x0c, 0x9c, 0x25, 0x50, 0x06, 0x3a, 0x0c, 0x28,
+  0x81, 0x10, 0xd3, 0x38, 0x06, 0x3a, 0x0c, 0x22, 0x31, 0x0d, 0xc4, 0x34,
+  0x8e, 0xe1, 0x86, 0x51, 0x60, 0x03, 0x32, 0x98, 0x65, 0x58, 0x9a, 0x60,
+  0xc4, 0xe0, 0x00, 0x40, 0x10, 0x0c, 0x26, 0x93, 0xb8, 0x87, 0x70, 0x10,
+  0x89, 0xd1, 0x84, 0x00, 0xa8, 0xc0, 0x15, 0x64, 0x34, 0x61, 0x08, 0x2a,
+  0x40, 0x05, 0xa9, 0x60, 0x80, 0xd1, 0x04, 0x43, 0xa8, 0x20, 0x16, 0xa4,
+  0x86, 0x00, 0x46, 0x0c, 0x14, 0x00, 0x04, 0xc1, 0xc0, 0x89, 0x89, 0x7f,
+  0x58, 0x87, 0x20, 0x25, 0xe8, 0x61, 0x25, 0x46, 0x13, 0x02, 0xe0, 0x02,
+  0x03, 0x47, 0x13, 0xf4, 0x60, 0x18, 0x6e, 0x08, 0x5e, 0x02, 0x0c, 0x66,
+  0x19, 0x98, 0x26, 0x18, 0x4d, 0x70, 0x86, 0x0a, 0x0e, 0x18, 0x31, 0x50,
+  0x00, 0x10, 0x04, 0x03, 0x07, 0x27, 0x4c, 0x42, 0x1e, 0x02, 0x98, 0xd8,
+  0x07, 0x99, 0x18, 0x4d, 0x08, 0x80, 0x0b, 0x0c, 0x9c, 0x25, 0x68, 0x06,
+  0x3a, 0x0c, 0x88, 0x81, 0x16, 0xf8, 0x50, 0x06, 0x3a, 0x0c, 0x82, 0x91,
+  0x8f, 0x45, 0x3e, 0x14, 0x13, 0x30, 0xf9, 0x98, 0x80, 0xc9, 0xc7, 0x84,
+  0x20, 0x3e, 0x56, 0x68, 0xf2, 0xb1, 0x82, 0x93, 0x8f, 0x05, 0x02, 0x7c,
+  0x0a, 0x1e, 0x58, 0x02, 0xea, 0x1c, 0x02, 0x1d, 0x4d, 0xe0, 0x87, 0x61,
+  0xb8, 0x21, 0x08, 0x0b, 0x30, 0x98, 0x6e, 0x48, 0x09, 0x94, 0x08, 0x8e,
+  0x30, 0xca, 0x84, 0x40, 0x3e, 0x77, 0x0f, 0x46, 0x99, 0x10, 0xd0, 0x67,
+  0xc4, 0xc0, 0x00, 0x40, 0x10, 0x0c, 0x8e, 0xb6, 0x18, 0x8b, 0x60, 0xc4,
+  0xc0, 0x00, 0x40, 0x10, 0x0c, 0x0e, 0xb7, 0xa8, 0x09, 0x61, 0xc4, 0xe0,
+  0x00, 0x40, 0x10, 0x0c, 0x26, 0xb6, 0xe8, 0x89, 0x93, 0x18, 0x8b, 0xd1,
+  0x84, 0x40, 0xa8, 0xa0, 0x24, 0x64, 0x34, 0x61, 0x18, 0x4a, 0x08, 0x60,
+  0xc4, 0xe0, 0x00, 0x40, 0x10, 0x0c, 0x2c, 0xb8, 0x10, 0x0b, 0x96, 0xe8,
+  0x89, 0xd1, 0x84, 0x40, 0xb0, 0xc4, 0x92, 0x8f, 0x25, 0x94, 0x7c, 0xac,
+  0x14, 0x48, 0x21, 0x3e, 0x16, 0x0c, 0xf0, 0xb1, 0x60, 0x88, 0x8f, 0x19,
+  0x81, 0x7c, 0xec, 0xc9, 0xe4, 0x63, 0x4f, 0x27, 0x1f, 0x43, 0x85, 0x54,
+  0x80, 0x8f, 0x05, 0x03, 0x7c, 0x2c, 0x18, 0xe0, 0x63, 0x4c, 0x20, 0x9f,
+  0xd1, 0x04, 0x27, 0x18, 0x8e, 0x08, 0x7e, 0x22, 0xf8, 0x66, 0x19, 0x9c,
+  0x27, 0xb8, 0x24, 0xb9, 0x11, 0x03, 0x07, 0x00, 0x41, 0x30, 0x80, 0x46,
+  0x03, 0x2e, 0x7e, 0x82, 0xe2, 0x8b, 0x60, 0x2c, 0xc6, 0x62, 0x2c, 0xce,
+  0xc2, 0x2f, 0x8e, 0x48, 0x6e, 0xc4, 0xc0, 0x01, 0x40, 0x10, 0x0c, 0x20,
+  0xd2, 0x88, 0x0b, 0xb0, 0x98, 0xfa, 0x22, 0x20, 0x0b, 0xb2, 0x20, 0x0b,
+  0xb4, 0xf8, 0x8b, 0x59, 0x02, 0xc8, 0xba, 0x4e, 0x3e, 0x16, 0x98, 0x85,
+  0x7c, 0x2c, 0x30, 0xe8, 0x33, 0x62, 0x60, 0x00, 0x20, 0x08, 0x06, 0x47,
+  0x6a, 0xfc, 0x45, 0x30, 0x62, 0x60, 0x00, 0x20, 0x08, 0x06, 0x87, 0x6a,
+  0xc4, 0x85, 0x60, 0x02, 0x24, 0x1f, 0x13, 0x16, 0xf9, 0x98, 0x10, 0xc4,
+  0xe7, 0x82, 0xe4, 0x46, 0x0c, 0x1c, 0x00, 0x04, 0xc1, 0x00, 0x6a, 0x0d,
+  0xbd, 0x48, 0x0b, 0xcf, 0x34, 0x82, 0xb6, 0x68, 0x8b, 0xb6, 0x88, 0x0b,
+  0xd4, 0xb0, 0x82, 0x92, 0x8f, 0x1d, 0x8f, 0x7c, 0x4c, 0x08, 0xe0, 0x73,
+  0x41, 0x72, 0x23, 0x06, 0x0e, 0x00, 0x82, 0x60, 0x00, 0xc5, 0x86, 0x5f,
+  0xb4, 0x05, 0x18, 0xa8, 0x46, 0x10, 0x17, 0x71, 0x11, 0x17, 0x75, 0xc1,
+  0x1a, 0xb3, 0x04, 0x10, 0x06, 0xc4, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00,
+  0x5b, 0x06, 0x34, 0x78, 0xc0, 0x60, 0xcb, 0x10, 0x07, 0x4f, 0x18, 0x6c,
+  0x19, 0xfa, 0xe0, 0x11, 0x83, 0x2d, 0x03, 0x2b, 0x3c, 0x63, 0xb0, 0x65,
+  0xd8, 0x85, 0x07, 0x0c, 0xb6, 0x0c, 0xe4, 0xf0, 0x84, 0xc1, 0x96, 0xc1,
+  0x1d, 0x1e, 0x31, 0xd8, 0x32, 0xf4, 0xc3, 0x33, 0x06, 0x5b, 0x06, 0xb6,
+  0x78, 0xc8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00
 };
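
The regenerated g_DFT byte arrays above are the recompiled Stockham DFT compute-shader blobs (the BYTE array is the SM 5.x DXBC build, the unsigned char array in stockham_fp16.h is the SM 6.x DXIL build) embedded as C headers so the execution provider can create its compute pipeline straight from memory. As a rough, hypothetical sketch of how such an embedded blob is typically handed to D3D12 -- this is illustrative only, not the execution provider's actual plumbing, and root-signature setup plus error reporting are omitted:

#include <d3d12.h>
#include <wrl/client.h>

using Microsoft::WRL::ComPtr;

// Wrap a precompiled shader blob (e.g. g_DFT / sizeof(g_DFT)) into a compute PSO.
// Names and the surrounding setup are assumptions for illustration.
ComPtr<ID3D12PipelineState> CreateDftPipeline(
    ID3D12Device* device,
    ID3D12RootSignature* rootSignature,
    const void* blob,
    size_t blobSize)
{
    D3D12_COMPUTE_PIPELINE_STATE_DESC desc = {};
    desc.pRootSignature = rootSignature;
    desc.CS.pShaderBytecode = blob;   // points at the embedded byte array
    desc.CS.BytecodeLength = blobSize;

    ComPtr<ID3D12PipelineState> pso;
    if (FAILED(device->CreateComputePipelineState(&desc, IID_PPV_ARGS(&pso))))
    {
        return nullptr;
    }
    return pso;
}
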
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index 9c136ed8c9484..71fc8741bfdc8 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -435,6 +435,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Dropout);
 DML_OP_EXTERN_CREATION_FUNCTION(MatMul);
 DML_OP_EXTERN_CREATION_FUNCTION(FusedMatMul);
 DML_OP_EXTERN_CREATION_FUNCTION(FusedMatMulActivation);
+DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeMatMul);
 DML_OP_EXTERN_CREATION_FUNCTION(Cast);
 DML_OP_EXTERN_CREATION_FUNCTION(CastLike15);
 DML_OP_EXTERN_CREATION_FUNCTION(CastLike19);
@@ -503,6 +504,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(QLinearMatMul);
 DML_OP_EXTERN_CREATION_FUNCTION(QLinearConcat);
 DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeLinear);
 DML_OP_EXTERN_CREATION_FUNCTION(MatMulInteger);
+DML_OP_EXTERN_CREATION_FUNCTION(MatMulIntegerToFloat);
 DML_OP_EXTERN_CREATION_FUNCTION(ConvInteger);
 DML_OP_EXTERN_CREATION_FUNCTION(Trilu);
 
@@ -514,6 +516,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Resize19);
 
 DML_OP_EXTERN_CREATION_FUNCTION(Shape);
 DML_OP_EXTERN_CREATION_FUNCTION(Size);
+DML_OP_EXTERN_CREATION_FUNCTION(QAttention);
 DML_OP_EXTERN_CREATION_FUNCTION(Attention);
 DML_OP_EXTERN_CREATION_FUNCTION(MultiHeadAttention);
 DML_OP_EXTERN_CREATION_FUNCTION(NonZero);
@@ -535,6 +538,7 @@ DML_OP_EXTERN_QUERY_FUNCTION(Pad);
 DML_OP_EXTERN_QUERY_FUNCTION(LayerNormalization);
 DML_OP_EXTERN_QUERY_FUNCTION(SkipLayerNormalization);
 DML_OP_EXTERN_QUERY_FUNCTION(QLinearSigmoid);
+DML_OP_EXTERN_QUERY_FUNCTION(QAttention);
 DML_OP_EXTERN_QUERY_FUNCTION(Attention);
 
 constexpr static std::array<const char*, 1> typeNameListDefault = {"T"};
@@ -612,20 +616,35 @@ constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLayerN
 constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListShape = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
 constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListSize = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
 constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListQLinearSigmoid = {SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8};
+
+constexpr static std::array<SupportedTensorDataTypes, 4> supportedTypeListQAttention = {
+    SupportedTensorDataTypes::Ints8Bit,
+    SupportedTensorDataTypes::Ints8Bit,
+    SupportedTensorDataTypes::Float16to32,
+    SupportedTensorDataTypes::Int32
+};
+
 constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListAttention = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32};
 constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListRotaryEmbedding = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int64};
 constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListGroupNorm = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32};
 constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListNonZero = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit | SupportedTensorDataTypes::Bool};
 
 constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListQLinearMatMul = {
-    SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
-    SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
-    SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8
+    SupportedTensorDataTypes::Ints8Bit,
+    SupportedTensorDataTypes::Ints8Bit,
+    SupportedTensorDataTypes::Ints8Bit
 };
+
+constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListMatMulIntegerToFloat = {
+    SupportedTensorDataTypes::Ints8Bit,
+    SupportedTensorDataTypes::Ints8Bit,
+    SupportedTensorDataTypes::Float16to32
+};
+
 constexpr static std::array<SupportedTensorDataTypes, 4> supportedTypeListQLinearConv = {
-    SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
-    SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
-    SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
+    SupportedTensorDataTypes::Ints8Bit,
+    SupportedTensorDataTypes::Ints8Bit,
+    SupportedTensorDataTypes::Ints8Bit,
     SupportedTensorDataTypes::Int32
 };
 
@@ -1057,8 +1076,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
     {REG_INFO_MS(   1,  Gelu,                               typeNameListDefault,            supportedTypeListFloat16to32,           DmlGraphSupport::Supported)},
     {REG_INFO_MS(   1,  BiasGelu,                           typeNameListDefault,            supportedTypeListFloat16to32,           DmlGraphSupport::Supported)},
     {REG_INFO_MS(   1,  FusedMatMul,                        typeNameListDefault,            supportedTypeListFloat16to32,           DmlGraphSupport::Supported)},
+    {REG_INFO_MS(   1,  DynamicQuantizeMatMul,              typeNameListTwo,                supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)},
     {REG_INFO_MS(   1,  FusedMatMulActivation,              typeNameListDefault,            supportedTypeListFloat16to32,           DmlGraphSupport::Supported)},
     {REG_INFO_MS(   1,  QLinearSigmoid,                     typeNameListDefault,            supportedTypeListQLinearSigmoid,        DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQLinearSigmoid)},
+    {REG_INFO_MS(   1,  QAttention,                         typeNameListFour,               supportedTypeListQAttention,            DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQAttention)},
     {REG_INFO_MS(   1,  Attention,                          typeNameListAttention,          supportedTypeListAttention,             DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryAttention)},
     {REG_INFO_MS(   1,  MultiHeadAttention,                 typeNameListAttention,          supportedTypeListAttention,             DmlGraphSupport::Supported)},
     {REG_INFO_MS(   1,  RotaryEmbedding,                    typeNameListRotaryEmbedding,    supportedTypeListRotaryEmbedding,       DmlGraphSupport::Supported)},
@@ -1083,6 +1104,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
     {REG_INFO(     10,  QLinearConv,                        typeNameListFour,               supportedTypeListQLinearConv,           DmlGraphSupport::Supported)},
     {REG_INFO(     10,  QLinearMatMul,                      typeNameListThree,              supportedTypeListQLinearMatMul,         DmlGraphSupport::Supported)},
     {REG_INFO(     10,  MatMulInteger,                      typeNameListThree,              supportedTypeListInteger,               DmlGraphSupport::Supported)},
+    {REG_INFO_MS(   1,  MatMulIntegerToFloat,               typeNameListThree,              supportedTypeListMatMulIntegerToFloat,  DmlGraphSupport::Supported)},
     {REG_INFO(     10,  ConvInteger,                        typeNameListThree,              supportedTypeListInteger,               DmlGraphSupport::Supported)},
     {REG_INFO(     11,  DynamicQuantizeLinear,              typeNameListTwo,                supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)},
     {REG_INFO(      7,  LayerNormalization,                 typeNameListLayerNormContrib,   supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)},
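
The registration hunks above wire three DML contrib kernels into the operator table (DynamicQuantizeMatMul, MatMulIntegerToFloat, QAttention) and switch the existing QLinear type lists to the Ints8Bit shorthand. For MatMulIntegerToFloat the new type list pins both matrix inputs to 8-bit integers and the output to float16/32; semantically the op is an integer matmul whose accumulator is dequantized with the input scales and zero points. A minimal reference sketch of that computation follows -- row-major layout, per-tensor quantization, and the omission of the optional bias are all assumptions here, and this is not the DML kernel itself:

#include <cstdint>
#include <vector>

// Reference semantics only: int8 matmul dequantized to float.
std::vector<float> MatMulIntegerToFloatRef(
    const std::vector<int8_t>& a, const std::vector<int8_t>& b,
    int m, int k, int n,
    float aScale, int8_t aZeroPoint,
    float bScale, int8_t bZeroPoint)
{
    std::vector<float> y(static_cast<size_t>(m) * n, 0.0f);
    for (int i = 0; i < m; ++i)
    {
        for (int j = 0; j < n; ++j)
        {
            int32_t acc = 0;
            for (int p = 0; p < k; ++p)
            {
                acc += (static_cast<int32_t>(a[i * k + p]) - aZeroPoint) *
                       (static_cast<int32_t>(b[p * n + j]) - bZeroPoint);
            }
            y[static_cast<size_t>(i) * n + j] = acc * aScale * bScale;  // dequantize the accumulator
        }
    }
    return y;
}
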
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/Shaders/stockham.hlsl b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/Shaders/stockham.hlsl
index 01e62b0727520..c8a006c7e12e0 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/Shaders/stockham.hlsl
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/Shaders/stockham.hlsl
@@ -41,20 +41,21 @@ float2 ReadSourceValue(uint3 index)
     float2 value = float2(0, 0);
 
     bool hasWindow = HasWindow == 1;
-    [branch]
+    [flatten]
     if (hasWindow && index.y < (uint)WindowSizes[2])
     {
         uint windowIndexReal = index.y * WindowStrides[2];
         window_value.x = window[windowIndexReal];
 
         uint windowIndexImaginary = windowIndexReal + WindowStrides[3];
+        [branch]
         if (WindowSizes[3] == 2)
         {
             window_value.y = window[windowIndexImaginary];
         }
     }
 
-    [branch]
+    [flatten]
     if (index.y < (uint)InputSizes[1])
     {
         uint indexReal =
@@ -108,7 +109,7 @@ void DFT(uint3 dtid : SV_DispatchThreadId)
     uint index = StartIndex + dtid.x;
     if (index < ElementCount)
     {
-        uint halfTotalDFTLength = DFTLength / 2;
+        uint halfTotalDFTLength = DFTLength >> 1;
         uint N = 1U << DFTIteration;
         uint halfN = 1U << (DFTIteration - 1);
 
@@ -143,8 +144,16 @@ void DFT(uint3 dtid : SV_DispatchThreadId)
         unweighted.y = Scale * (inputEvenValue.y + (w.x * inputOddValue.y + w.y * inputOddValue.x));
 
         // When ChirpLength is 0, then chirp should evaluate to (1,0), which is a no-op.
-        float2 chirp = CalculateChirp(k, ChirpLength);
-        dst[outputIndex.x] = (TBUFFER)(unweighted.x * chirp.x - unweighted.y * chirp.y);
-        dst[outputIndex.y] = (TBUFFER)(unweighted.x * chirp.y + unweighted.y * chirp.x);
+        [branch]
+        if (ChirpLength == 0)
+        {
+            dst[outputIndex.x] = (TBUFFER)(unweighted.x);
+            dst[outputIndex.y] = (TBUFFER)(unweighted.y);
+        }
+        else {
+            float2 chirp = CalculateChirp(k, ChirpLength);
+            dst[outputIndex.x] = (TBUFFER)(unweighted.x * chirp.x - unweighted.y * chirp.y);
+            dst[outputIndex.y] = (TBUFFER)(unweighted.x * chirp.y + unweighted.y * chirp.x);
+        }
     }
 }
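
The stockham.hlsl change above swaps the two read-path attributes from [branch] to [flatten], adds an explicit [branch] around the imaginary-window read, and short-circuits the chirp weighting: per the existing comment, CalculateChirp yields (1, 0) when ChirpLength is 0, and multiplying by (1, 0) is a complex-multiply identity, so that path can store the unweighted values and skip the cos/sin work. A small standalone check of that identity, in plain C++ with hypothetical names rather than HLSL:

#include <cassert>
#include <utility>

using Complex = std::pair<float, float>;  // (real, imaginary)

Complex ComplexMul(Complex a, Complex b)
{
    return { a.first * b.first - a.second * b.second,
             a.first * b.second + a.second * b.first };
}

int main()
{
    const Complex unweighted = { 0.25f, -1.5f };
    const Complex chirp = { 1.0f, 0.0f };  // what the shader comment says CalculateChirp yields for ChirpLength == 0
    const Complex weighted = ComplexMul(unweighted, chirp);

    // Multiplying by (1, 0) is exact, so the ChirpLength == 0 branch can
    // write the unweighted result directly without calling CalculateChirp.
    assert(weighted.first == unweighted.first);
    assert(weighted.second == unweighted.second);
    return 0;
}

This matches the regenerated stockham_fp16 DXIL earlier in the diff, where the zero-divisor path stores the truncated halves directly and branches to the common exit block.
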
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Utility.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Utility.h
new file mode 100644
index 0000000000000..02166f992449e
--- /dev/null
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Utility.h
@@ -0,0 +1,141 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include <string>
+#include <string_view>
+#include <locale>
+#include <codecvt>
+
+
+namespace Dml
+{
+    static inline std::wstring ConvertToWString(std::string_view str)
+    {
+        std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>,wchar_t> g_converterToUtf16;
+        return g_converterToUtf16.from_bytes(str.data(), str.data() + str.size());
+    }
+
+    static inline std::wstring GetModelName(const onnxruntime::Path& modelPath)
+    {
+        if (modelPath.GetComponents().empty())
+        {
+            return L"";
+        }
+
+        const onnxruntime::PathString& pathString = modelPath.GetComponents().back();
+        size_t dotPosition = pathString.find_last_of('.');
+        if (dotPosition == std::string::npos)
+        {
+            return L"";
+        }
+
+        return pathString.substr(0, dotPosition);
+    }
+
+    static inline std::wstring GetSanitizedFileName(std::wstring_view name)
+    {
+        std::wstring newName(name);
+        for (wchar_t& c : newName)
+        {
+            switch (c)
+            {
+            case '\\':
+            case '/':
+            case '\"':
+            case '|':
+            case '<':
+            case '>':
+            case ':':
+            case '?':
+            case '*':
+                c = '_';
+                break;
+            }
+        }
+        return newName;
+    }
+
+    static inline std::string GetSanitizedFileName(std::string_view name)
+    {
+        std::string newName(name);
+        for (char& c : newName)
+        {
+            switch (c)
+            {
+            case '\\':
+            case '/':
+            case '\"':
+            case '|':
+            case '<':
+            case '>':
+            case ':':
+            case '?':
+            case '*':
+                c = '_';
+                break;
+            }
+        }
+        return newName;
+    }
+
+    static inline void WriteToFile(std::wstring_view directoryName, std::wstring_view fileName, std::uint8_t* data, size_t dataSize)
+    {
+        std::wstring sanitizedFileName = GetSanitizedFileName(fileName);
+        std::filesystem::create_directory(directoryName);
+        std::wstring fullSanitizedFileName = std::wstring(directoryName) +
+                                (directoryName.empty() ? L"" : L"/") +
+                                sanitizedFileName;
+        std::ofstream file(fullSanitizedFileName, std::ios::binary);
+        if (!file.is_open())
+        {
+            std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>,wchar_t> g_converterToUtf16;
+            std::stringstream errorMessage;
+            errorMessage << "File named: " << g_converterToUtf16.to_bytes(fileName.data(), fileName.data() + fileName.size()) << " could not be opened\n";
+            throw std::ios::failure(errorMessage.str());
+        }
+        file.write(reinterpret_cast<const char*>(data), dataSize);
+    }
+
+}
+
+namespace StringUtil
+{
+    struct NameAndIndex
+    {
+        const char* name; // Null terminated.
+        uint32_t index;
+    };
+
+    struct WideNameAndIndex
+    {
+        const wchar_t* name; // Null terminated.
+        uint32_t index;
+    };
+
+    inline std::optional<uint32_t> MapToIndex(std::string_view mode, gsl::span<const NameAndIndex> nameAndIndexList)
+    {
+        for (auto& nameAndIndex : nameAndIndexList)
+        {
+            if (strncmp(nameAndIndex.name, mode.data(), mode.size()) == 0)
+            {
+                return nameAndIndex.index;
+            }
+        }
+
+        return {};
+    }
+
+    inline std::optional<uint32_t> MapToIndex(std::wstring_view mode, gsl::span<const WideNameAndIndex> nameAndIndexList)
+    {
+        for (auto& nameAndIndex : nameAndIndexList)
+        {
+            if (wcsncmp(nameAndIndex.name, mode.data(), mode.size()) == 0)
+            {
+                return nameAndIndex.index;
+            }
+        }
+
+        return {};
+    }
+}
\ No newline at end of file
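
A brief, hypothetical usage sketch for the new helpers (the directory and file names below are invented; note that Utility.h itself relies on precomp.h, updated further down, to pull in <fstream> and <filesystem>):

// Hypothetical caller, for illustration only.
#include <cstdint>
#include <vector>
#include "Utility.h"

void DumpSerializedGraph(std::vector<std::uint8_t>& blob)
{
    // Characters such as '\\', '<', '>' and ':' are replaced with '_' so the
    // result is a valid Windows file name.
    const std::wstring fileName = Dml::GetSanitizedFileName(L"partition<0>:graph.bin");

    // Creates the directory if needed and throws std::ios::failure when the
    // file cannot be opened.
    Dml::WriteToFile(L"DmlSerializedGraphs", fileName, blob.data(), blob.size());
}
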
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h
index 83737d2ba4848..1a796b25c5d1f 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h
@@ -17,6 +17,8 @@
 #include <chrono>
 #include <variant>
 #include <cassert>
+#include <fstream>
+#include <filesystem>
 
 #include <wrl/client.h>
 #include <wrl/implements.h>
@@ -37,6 +39,7 @@
 #include <d3d12sdklayers.h>
 #include "External/D3DX12/d3dx12.h"
 #endif
+#include "core/common/flatbuffers.h"
 
 #include "GraphicsUnknownHelper.h"
 
@@ -53,6 +56,9 @@
 #include "External/DirectMLHelpers/SchemaHelpers.h"
 #include "External/DirectMLHelpers/GeneratedSchemaHelpers.h"
 #include "External/DirectMLHelpers/DirectMLX.h"
+#include "External/DirectMLHelpers/DmlSerializedGraphDesc.h"
+#include "External/DirectMLHelpers/DmlGraphSerialization.h"
+#include "External/DirectMLHelpers/DmlGraphDeserialization.h"
 
 using Microsoft::WRL::ComPtr;
 
@@ -67,3 +73,4 @@ using Microsoft::WRL::ComPtr;
 #include "TensorDesc.h"
 #include "DescriptorPool.h"
 #include "IExecutionProvider.h"
+#include "Utility.h"
\ No newline at end of file
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h
index 3bec8d3864cba..ac3a3eb1268b8 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h
@@ -10,18 +10,11 @@ struct DML_INPUT_GRAPH_EDGE_DESC;
 struct DML_OUTPUT_GRAPH_EDGE_DESC;
 struct DML_INTERMEDIATE_GRAPH_EDGE_DESC;
 
-// Either nodesAsOpDesc or nodesAsIDMLOperator is present.
-//  1) Operator kernels which implement operators using only a single DML operator will pass a DML_OPERATOR_DESC.
-//     These kernels pass DML_OPERATOR_DESC, because while building Dml graph (inside FusedGraphKernel.cpp) we can change the
-//     the flag of constant inputs to DML_TENSOR_FLAG_OWNED_BY_DML.
-//  2) Operator kernels which implement operators using DMLX graph, they will pass IDMLOperator and won't be able
-//     to use DML_TENSOR_FLAG_OWNED_BY_DML.
 struct MLOperatorGraphDesc
 {
     uint32_t nodeCount;
-    _Field_size_opt_(nodeCount) const DML_OPERATOR_DESC** nodesAsOpDesc;
-    _Field_size_opt_(nodeCount) IDMLOperator** nodesAsIDMLOperator;
-
+    _Field_size_opt_(nodeCount) const DML_OPERATOR_DESC** nodes;
+
     uint32_t inputEdgeCount;
     _Field_size_(inputEdgeCount) const DML_INPUT_GRAPH_EDGE_DESC* inputEdges;
 
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
index 317f5ebcbc3e1..acda1a516be09 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
@@ -2802,6 +2802,48 @@ namespace OperatorHelper
         m_qkvHiddenSizes = kernelInformation.GetAttributes().GetOptionalAttributeVectorInt32(AttrName::QkvHiddenSizes);
     }
 
+    std::vector<EdgeShapes> QAttentionHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
+    {
+        ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 5);
+
+        auto queryShape = shapeInfo.GetInputTensorShape(0);
+        ML_CHECK_VALID_ARGUMENT(queryShape.size() == 3);
+
+        auto weightShape = shapeInfo.GetInputTensorShape(1);
+        ML_CHECK_VALID_ARGUMENT(weightShape.size() == 2);
+        ML_CHECK_VALID_ARGUMENT(weightShape[1] % 3 == 0);
+
+        const uint32_t batchSize = queryShape[0];
+        const uint32_t sequenceLength = queryShape[1];
+        const uint32_t hiddenSize = weightShape[1] / 3;
+        const uint32_t headSize = hiddenSize / m_numHeads;
+
+        std::vector<EdgeShapes> outputShapes(2);
+
+        outputShapes[0] = EdgeShapes({batchSize, sequenceLength, hiddenSize});
+
+        uint32_t totalSequenceLength = sequenceLength;
+        if (shapeInfo.IsInputValid(8))
+        {
+            ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputTensorDimensionCount(8) == 5);
+            const uint32_t pastSequenceLength = shapeInfo.GetInputTensorShape(8)[3];
+            totalSequenceLength += pastSequenceLength;
+        }
+
+        if (shapeInfo.IsOutputValid(1))
+        {
+            ML_CHECK_VALID_ARGUMENT(shapeInfo.IsInputValid(8));
+            outputShapes[1] = EdgeShapes({2, batchSize, m_numHeads, totalSequenceLength, headSize});
+        }
+
+        return outputShapes;
+    }
+
+    void QAttentionHelper::Initialize(const IKernelInformationAdapter& kernelInformation)
+    {
+        m_numHeads = gsl::narrow_cast<uint32_t>(kernelInformation.GetAttributes().GetAttribute<int64_t>(AttrName::NumHeads));
+    }
+
     std::vector<EdgeShapes> SkipLayerNormHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
     {
         ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 3);
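
To make the shape arithmetic above concrete, here is a small standalone check with invented BERT-base-like sizes (the batch, sequence, head count and past length are illustrative only):

#include <cassert>
#include <cstdint>

int main()
{
    const std::uint32_t batchSize = 2;
    const std::uint32_t sequenceLength = 16;
    const std::uint32_t weightCols = 2304;       // weightShape[1]; must be divisible by 3
    const std::uint32_t numHeads = 12;           // num_heads attribute
    const std::uint32_t pastSequenceLength = 8;  // dimension [3] of the 5-D past input (index 8)

    const std::uint32_t hiddenSize = weightCols / 3;       // 768: the weight packs Q, K and V
    const std::uint32_t headSize = hiddenSize / numHeads;  // 64
    const std::uint32_t totalSequenceLength = sequenceLength + pastSequenceLength;  // 24

    // Output 0: [batchSize, sequenceLength, hiddenSize]
    // Output 1 (present): [2, batchSize, numHeads, totalSequenceLength, headSize]
    assert(hiddenSize == 768 && headSize == 64 && totalSequenceLength == 24);
    (void)batchSize;
    return 0;
}
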
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
index 1b2521a86613f..aff31bb3050a7 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
@@ -870,7 +870,6 @@ class QLinearMatMulHelper : public MatMulHelperBase
     QLinearMatMulHelper(const Info_t& info, const Shape_t& shape) : MatMulHelperBase(info, shape, 0, 3) {}
 };
 
-
 class TopKHelper
 {
     void Initialize(
@@ -1555,6 +1554,22 @@ class AttentionHelper
     std::vector<int32_t> m_qkvHiddenSizes;
 };
 
+class QAttentionHelper
+{
+public:
+    template <typename Info_t, typename Shape_t>
+    QAttentionHelper(const Info_t& info, const Shape_t& shapeInfo)
+    {
+        Initialize(KernelInformationAdapter(info));
+    }
+
+    std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
+
+private:
+    void Initialize(const IKernelInformationAdapter& kernelInformation);
+    uint32_t m_numHeads;
+};
+
 class SkipLayerNormHelper
 {
 public:
@@ -1700,6 +1715,7 @@ using ShapeInferenceHelper_Affine = GetOutputShapeAsInputShapeHelper;
 using ShapeInferenceHelper_QuantizeLinear = GetOutputShapeAsInputShapeHelper;
 using ShapeInferenceHelper_DequantizeLinear = GetOutputShapeAsInputShapeHelper;
 using ShapeInferenceHelper_QLinearSigmoid = GetOutputShapeAsInputShapeHelper;
+using ShapeInferenceHelper_QAttention = QAttentionHelper;
 using ShapeInferenceHelper_Attention = AttentionHelper;
 using ShapeInferenceHelper_MultiHeadAttention = MultiHeadAttentionHelper;
 using ShapeInferenceHelper_RotaryEmbedding = GetOutputShapeAsInputShapeHelper;
@@ -1776,6 +1792,8 @@ using ShapeInferenceHelper_Identity16 = GetOutputShapeAsInputShapeHelper;
 using ShapeInferenceHelper_Identity19 = GetOutputShapeAsInputShapeHelper;
 using ShapeInferenceHelper_MatMul = MatMulHelper;
 using ShapeInferenceHelper_MatMulInteger = MatMulHelper;
+using ShapeInferenceHelper_MatMulIntegerToFloat = MatMulHelper;
+using ShapeInferenceHelper_DynamicQuantizeMatMul = MatMulHelper;
 using ShapeInferenceHelper_QLinearMatMul = QLinearMatMulHelper;
 using ShapeInferenceHelper_QLinearAdd = GetBroadcastedOutputShapeHelper;
 using ShapeInferenceHelper_DynamicQuantizeLinear = GetOutputShapeAsInputShapeHelper;
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
index e725ba085113d..7492b729425a5 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
@@ -448,7 +448,9 @@ namespace OperatorHelper
         static const int sc_sinceVer_FusedMatMul = 1;
         static const int sc_sinceVer_FusedMatMulActivation = 1;
         static const int sc_sinceVer_QLinearSigmoid = 1;
+        static const int sc_sinceVer_QAttention = 1;
         static const int sc_sinceVer_Attention = 1;
+        static const int sc_sinceVer_MatMulIntegerToFloat = 1;
         static const int sc_sinceVer_MultiHeadAttention = 1;
         static const int sc_sinceVer_SkipLayerNormalization = 1;
         static const int sc_sinceVer_EmbedLayerNormalization = 1;
@@ -461,6 +463,7 @@ namespace OperatorHelper
         static const int sc_sinceVer_RotaryEmbedding = 1;
         static const int sc_sinceVer_QLinearAveragePool = 1;
         static const int sc_sinceVer_QLinearGlobalAveragePool = 1;
+        static const int sc_sinceVer_DynamicQuantizeMatMul = 1;
     } // namespace MsftOperatorSet1
 
 } // namespace OperatorHelper
diff --git a/onnxruntime/core/providers/dml/dml_session_options_config_keys.h b/onnxruntime/core/providers/dml/dml_session_options_config_keys.h
index d11fa7516e713..5b5f371f51616 100644
--- a/onnxruntime/core/providers/dml/dml_session_options_config_keys.h
+++ b/onnxruntime/core/providers/dml/dml_session_options_config_keys.h
@@ -21,3 +21,4 @@
 // "1": disabled (disallowed). Graph fusion will never be used.
 // The default value is "0"
 static const char* const kOrtSessionOptionsConfigDisableDmlGraphFusion = "ep.dml.disable_graph_fusion";
+static const char* const kOrtSessionOptionsConfigEnableGraphSerialization = "ep.dml.enable_graph_serialization";
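
A hedged sketch of how an application might turn the new option on. This header only defines the key string; the "1" value mirrors the 0/1 convention of the graph-fusion key above and is an assumption, as is the exact include path of the DML provider factory header:

#include <onnxruntime_cxx_api.h>
#include <dml_provider_factory.h>  // header location varies by package

Ort::SessionOptions MakeDmlSessionOptions()
{
    Ort::SessionOptions so;
    // Assumed to enable DML graph serialization; the default is not documented here.
    so.AddConfigEntry("ep.dml.enable_graph_serialization", "1");
    Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(so, /*device_id*/ 0));
    return so;
}
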
diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc
index 799d4172f2b64..038423104d92e 100644
--- a/onnxruntime/core/providers/js/js_execution_provider.cc
+++ b/onnxruntime/core/providers/js/js_execution_provider.cc
@@ -21,7 +21,6 @@
 #include "core/framework/kernel_registry.h"
 #include "core/graph/function_utils.h"
 #include "core/graph/indexed_sub_graph.h"
-#include "core/providers/shared/node_unit/node_unit.h"
 #include "data_transfer.h"
 
 namespace onnxruntime {
@@ -756,16 +755,16 @@ std::unique_ptr<onnxruntime::IDataTransfer> JsExecutionProvider::GetDataTransfer
 JsExecutionProvider::~JsExecutionProvider() {
 }
 
-Status JsExecutionProvider::OnRunStart() {
-  if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured()) {
+Status JsExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) {
+  if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) {
     LOGS(*GetLogger(), INFO) << "Capturing the webgpu graph for this model";
     EM_ASM({ Module.jsepCaptureBegin(); });
   }
   return Status::OK();
 }
 
-Status JsExecutionProvider::OnRunEnd(bool sync_stream) {
-  if (IsGraphCaptureEnabled() && !IsGraphCaptured()) {
+Status JsExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) {
+  if (IsGraphCaptureEnabled() && !IsGraphCaptured(0)) {
     if (IsGraphCaptureAllowed()) {
       EM_ASM({ Module.jsepCaptureEnd(); });
       is_graph_captured_ = true;
@@ -781,12 +780,12 @@ bool JsExecutionProvider::IsGraphCaptureEnabled() const {
   return enable_graph_capture_;
 }
 
-bool JsExecutionProvider::IsGraphCaptured() const {
+bool JsExecutionProvider::IsGraphCaptured(int) const {
   return is_graph_captured_;
 }
 
-Status JsExecutionProvider::ReplayGraph() {
-  ORT_ENFORCE(IsGraphCaptured());
+Status JsExecutionProvider::ReplayGraph(int) {
+  ORT_ENFORCE(IsGraphCaptured(0));
   EM_ASM({ Module.jsepReplay(); });
   return Status::OK();
 }
diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h
index 91a3256ec2bd5..efacf510e75df 100644
--- a/onnxruntime/core/providers/js/js_execution_provider.h
+++ b/onnxruntime/core/providers/js/js_execution_provider.h
@@ -59,12 +59,12 @@ class JsExecutionProvider : public IExecutionProvider {
 
   std::vector<AllocatorPtr> CreatePreferredAllocators() override;
 
-  Status OnRunStart() override;
-  Status OnRunEnd(bool sync_stream) override;
+  Status OnRunStart(const onnxruntime::RunOptions& run_options) override;
+  Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override;
 
   bool IsGraphCaptureEnabled() const override;
-  bool IsGraphCaptured() const override;
-  Status ReplayGraph() override;
+  bool IsGraphCaptured(int graph_annotation_id) const override;
+  Status ReplayGraph(int graph_annotation_id) override;
 
  private:
   bool IsGraphCaptureAllowed() const;
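
The interface now threads a graph_annotation_id through IsGraphCaptured/ReplayGraph and passes RunOptions into the run callbacks; the JS provider keeps a single captured graph and ignores the id, hence the literal 0 in the .cc changes above. A toy, self-contained sketch (hypothetical class, not ORT code) of the capture-then-replay lifecycle this interface models:

#include <iostream>

// Toy provider: records the graph on the first eligible run, replays it afterwards.
class ToyCapturingProvider {
 public:
  void OnRunStart() {
    if (enabled_ && allowed_ && !captured_) {
      std::cout << "capture begin\n";
      capturing_ = true;
    }
  }
  void OnRunEnd() {
    if (capturing_) {
      std::cout << "capture end\n";
      capturing_ = false;
      captured_ = true;
    }
  }
  bool IsGraphCaptured(int /*graph_annotation_id*/) const { return captured_; }
  void ReplayGraph(int /*graph_annotation_id*/) const { std::cout << "replay\n"; }

 private:
  bool enabled_ = true;   // capture requested via provider options
  bool allowed_ = true;   // the model/backend supports capture
  bool capturing_ = false;
  bool captured_ = false;
};

int main() {
  ToyCapturingProvider ep;
  for (int run = 0; run < 3; ++run) {
    if (ep.IsGraphCaptured(0)) {  // later runs replay the recorded graph
      ep.ReplayGraph(0);
      continue;
    }
    ep.OnRunStart();  // first run: recording begins
    // ... kernels execute and are recorded ...
    ep.OnRunEnd();    // first run: recording ends
  }
  return 0;
}
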
diff --git a/onnxruntime/core/providers/js/operators/where.cc b/onnxruntime/core/providers/js/operators/where.cc
index 2f8f5e275aa98..dcdf9bee2f783 100644
--- a/onnxruntime/core/providers/js/operators/where.cc
+++ b/onnxruntime/core/providers/js/operators/where.cc
@@ -6,18 +6,19 @@
 namespace onnxruntime {
 namespace js {
 
-#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS)      \
-  ONNX_OPERATOR_KERNEL_EX(                                          \
-      OP_TYPE,                                                      \
-      kOnnxDomain,                                                  \
-      VERSION,                                                      \
-      kJsExecutionProvider,                                         \
-      KernelDefBuilder()                                            \
-          .TypeConstraint("T",                                      \
-                          {DataTypeImpl::GetTensorType<float>(),    \
-                           DataTypeImpl::GetTensorType<int32_t>(),  \
-                           DataTypeImpl::GetTensorType<uint32_t>(), \
-                           DataTypeImpl::GetTensorType<bool>()}),   \
+#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS)       \
+  ONNX_OPERATOR_KERNEL_EX(                                           \
+      OP_TYPE,                                                       \
+      kOnnxDomain,                                                   \
+      VERSION,                                                       \
+      kJsExecutionProvider,                                          \
+      KernelDefBuilder()                                             \
+          .TypeConstraint("T",                                       \
+                          {DataTypeImpl::GetTensorType<float>(),     \
+                           DataTypeImpl::GetTensorType<MLFloat16>(), \
+                           DataTypeImpl::GetTensorType<int32_t>(),   \
+                           DataTypeImpl::GetTensorType<uint32_t>(),  \
+                           DataTypeImpl::GetTensorType<bool>()}),    \
       KERNEL_CLASS);
 
 #define REG_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \
@@ -29,6 +30,7 @@ namespace js {
       KernelDefBuilder()                                                                  \
           .TypeConstraint("T",                                                            \
                           {DataTypeImpl::GetTensorType<float>(),                          \
+                           DataTypeImpl::GetTensorType<MLFloat16>(),                      \
                            DataTypeImpl::GetTensorType<int32_t>(),                        \
                            DataTypeImpl::GetTensorType<uint32_t>(),                       \
                            DataTypeImpl::GetTensorType<bool>()}),                         \
diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
index 40e76a0a67782..50782569ee80a 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
+++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
@@ -1383,11 +1383,11 @@ Status MIGraphXExecutionProvider::Sync() const {
   return Status::OK();
 }
 
-Status MIGraphXExecutionProvider::OnRunStart() {
+Status MIGraphXExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) {
   return Status::OK();
 }
 
-Status MIGraphXExecutionProvider::OnRunEnd(bool) {
+Status MIGraphXExecutionProvider::OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) {
   auto status = hipStreamQuery(stream_);
 
   if (status != hipSuccess) {
diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h
index d582338c7e067..c3617f409e72c 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h
+++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h
@@ -56,9 +56,9 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
 #ifdef MIGRAPHX_STREAM_SYNC
   Status Sync() const override;
 
-  Status OnRunStart() override;
+  Status OnRunStart(const onnxruntime::RunOptions& run_options) override;
 
-  Status OnRunEnd(bool sync_stream) override;
+  Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override;
 #endif
 
   std::vector<std::unique_ptr<ComputeCapability>>
diff --git a/onnxruntime/core/providers/migraphx/ort_trt_int8_cal_table.fbs.h b/onnxruntime/core/providers/migraphx/ort_trt_int8_cal_table.fbs.h
index 9639040f772da..a2721f6a5b44f 100644
--- a/onnxruntime/core/providers/migraphx/ort_trt_int8_cal_table.fbs.h
+++ b/onnxruntime/core/providers/migraphx/ort_trt_int8_cal_table.fbs.h
@@ -4,7 +4,7 @@
 #define ONNXRUNTIME_CORE_PROVIDERS_MIGRAPHX_ORT_TRT_INT8_CAL_TABLE_FBS_H_
 
 #include <vector>
-#include "flatbuffers/flatbuffers.h"
+#include "core/common/flatbuffers.h"
 
 namespace CalTableFlatBuffers {
 
diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc
index 0b32508a5bb38..745504ca04941 100644
--- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc
+++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc
@@ -11,6 +11,7 @@
 
 #include "core/common/logging/logging.h"
 #include "core/common/safeint.h"
+#include "core/framework/node_unit.h"
 #include "core/framework/tensorprotoutils.h"
 #include "core/graph/graph_viewer.h"
 #include "core/graph/graph.h"
@@ -18,7 +19,6 @@
 #include "core/providers/common.h"
 #include "core/providers/nnapi/nnapi_builtin/builders/op_builder.h"
 #include "core/providers/nnapi/nnapi_builtin/builders/op_builder_factory.h"
-#include "core/providers/shared/node_unit/node_unit.h"
 #include "core/providers/shared/utils/utils.h"
 
 namespace onnxruntime {
diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h
index 6a54bf7bdb938..0c0bc7b2e4674 100644
--- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h
+++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h
@@ -4,7 +4,7 @@
 #pragma once
 
 #include "core/common/common.h"
-#include "core/providers/shared/node_unit/node_unit.h"
+#include "core/framework/node_unit.h"
 #include "core/providers/nnapi/nnapi_builtin/builders/model_builder.h"
 #include "core/providers/nnapi/nnapi_builtin/builders/op_builder.h"
 #include "core/providers/nnapi/nnapi_builtin/builders/op_builder_factory.h"
diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc
index b2225643b788e..edee298ad1ccf 100644
--- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc
+++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc
@@ -67,7 +67,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
 
   int32_t num_outputs;
   if (node_unit.SinceVersion() >= 18) {
-    num_outputs = SafeInt<int32_t>(*helper.GetInt("num_outputs"));
+    num_outputs = SafeInt<int32_t>(*helper.GetInt64("num_outputs"));
   } else {
     num_outputs = SafeInt<int32_t>(node_unit.Outputs().size());
   }
@@ -127,7 +127,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const No
   } else {
     uint32_t num_outputs;
     if (node_unit.SinceVersion() >= 18) {
-      auto num_outputs_attr = helper.GetInt("num_outputs");
+      auto num_outputs_attr = helper.GetInt64("num_outputs");
       if (!num_outputs_attr.has_value()) {
         LOGS_DEFAULT(VERBOSE) << "No 'num_outputs' provided. For split 18+, num_outputs is a required attribute.";
         return false;
diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc
index 6962a7be94bb6..d0ae32378379d 100644
--- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc
+++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc
@@ -11,17 +11,19 @@
 #include "core/common/safeint.h"
 #include "core/common/status.h"
 #include "core/framework/execution_provider.h"
+#include "core/framework/node_unit.h"
 #include "core/framework/tensorprotoutils.h"
 #include "core/graph/graph_viewer.h"
+#include "core/optimizer/initializer.h"
+#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
+#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
 #include "core/providers/common.h"
 #include "core/providers/nnapi/nnapi_builtin/nnapi_api_helper.h"
-#include "core/providers/shared/node_unit/node_unit.h"
-#include "core/providers/shared/utils/utils.h"
 #include "core/providers/nnapi/nnapi_builtin/builders/helper.h"
 #include "core/providers/nnapi/nnapi_builtin/builders/op_builder.h"
 #include "core/providers/nnapi/nnapi_builtin/builders/op_builder_factory.h"
 #include "core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.h"
-#include "core/optimizer/initializer.h"
+#include "core/providers/shared/utils/utils.h"
 
 using namespace android::nn::wrapper;
 
@@ -119,7 +121,7 @@ const NodeUnit& ModelBuilder::GetNodeUnit(const Node* node) const {
 }
 
 void ModelBuilder::PreprocessNodeUnits() {
-  std::tie(node_unit_holder_, node_unit_map_) = GetAllNodeUnits(graph_viewer_);
+  std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_);
 }
 
 // Help to get all quantized operators' input and the NodeUnit(s) using the input
@@ -664,7 +666,7 @@ int32_t ModelBuilder::FindActivation(const NodeUnit& node_unit) {
 
   int32_t fuse_code = ANEURALNETWORKS_FUSED_NONE;
   bool fuse_code_assigned_from_activation = false;
-  for (auto it = node_unit.OutputEdgesBegin(0), end = node_unit.OutputEdgesEnd(0); it != end; ++it) {
+  for (auto it = node_unit.OutputEdgesBegin(), end = node_unit.OutputEdgesEnd(); it != end; ++it) {
     const auto& dst_node = it->GetNode();
     const auto* dst_input = dst_node.InputDefs()[it->GetDstArgIndex()];
 
diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc
index a066c64dac67d..dab7bccf43396 100644
--- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc
+++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc
@@ -21,7 +21,6 @@
 #include "core/optimizer/initializer.h"
 #include "core/providers/common.h"
 #include "core/providers/shared/utils/utils.h"
-#include "core/providers/shared/node_unit/node_unit.h"
 #include "core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h"
 
 namespace onnxruntime::nnapi::op_builder_helpers {
@@ -965,6 +964,18 @@ Status AddMinMaxOperator(ModelBuilder& model_builder, const NodeUnit& node_unit,
   return Status::OK();
 }
 
+// NOTE: Skipping the Reshape results in invalid output on some Snapdragon chipsets. While the NNAPI spec says the
+// input to FullyConnected can be > 2D, those chipsets don't handle this correctly.
+//
+// CanSkipReshape could potentially be re-enabled in the future if we no longer want to support those old chipsets.
+// However, on newer chipsets the Reshape may not run on the CPU, so there may not be a performance issue to avoid,
+// which would make CanSkipReshape redundant anyway.
+//
+// Known bad chipsets: Qualcomm Snapdragon 850, 855, 865, 870.
+//
+// See https://github.com/microsoft/onnxruntime/issues/19518
+
+/*
 // We can skip the Reshape if all the output edges satisfies both the following conditions
 // 1. The output of the reshape/flatten is not an output of the graph
 // 2. The output of the reshape/flatten is the input 0 of one or more GEMM/Matmul operators,
@@ -977,7 +988,7 @@ Status AddMinMaxOperator(ModelBuilder& model_builder, const NodeUnit& node_unit,
 // between NNAPI CPU impl and Hardware Accelerator impl and will speed up the execution
 // If we are going to skip the reshape, we will still add correct shape and operand type for the output in
 // onnxruntime::nnapi::Model.
-bool CanSkipReshape(const ModelBuilder& model_builder, const NodeUnit& node_unit,
+static bool CanSkipReshape(const ModelBuilder& model_builder, const NodeUnit& node_unit,
                     size_t input_rank, size_t output_rank) {
   // Since we know this is a Reshape NodeUnit, so we can safely assume there is only 1 output
   // and the node_unit has only one output node.
@@ -1039,33 +1050,37 @@ bool CanSkipReshape(const ModelBuilder& model_builder, const NodeUnit& node_unit
                         << node_unit.Name() << "] with output, " << output_name;
   return true;
 }
+*/
 
 Status AddReshapeOperator(ModelBuilder& model_builder,
                           const NodeUnit& node_unit,
                           const std::string& input,
                           const std::vector<int32_t>& shape) {
   auto& shaper(model_builder.GetShaper());
-  const auto& operand_indices(model_builder.GetOperandIndices());
   const auto& operand_types(model_builder.GetOperandTypes());
   const auto& output = node_unit.Outputs()[0].node_arg.Name();
 
   const auto input_shape = shaper[input];
   const auto output_shape = shaper[output];
-  const auto input_rank = input_shape.size();
-  const auto output_rank = output_shape.size();
 
   // For reshape, the output type should be the same as the input type except the shape is different
   auto output_operand_type = operand_types.at(input);
   output_operand_type.SetDimensions(output_shape);
 
+  /* See CanSkipReshape definition above for explanation of why this is disabled.
   // Since Reshape is not running using hardware in NNAPI for some CPU (e.g. Qualcomm SD for now)
   // We will try to see if we can skip the Reshape to prevent context switching between
   // NNAPI CPU impl and NNAPI hardware accelerator impl
   if (CanSkipReshape(model_builder, node_unit, input_rank, output_rank)) {
-    // Since reshape can be skipped, only register the dimension and type, with same index and new name
+    const auto& operand_indices(model_builder.GetOperandIndices());
+    const auto input_rank = input_shape.size();
+    const auto output_rank = output_shape.size();
+    // Since reshape can be skipped, only register the dimension and type, with same index and new name.
+    // This essentially redirects the downstream operator builders to the input of the skipped Reshape node,
+    // but with the output shape of the Reshape node.
     model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type);
-  } else {
-    // We still need to perform a reshape here
+  } else */
+  {
     std::string shape_name = model_builder.GetUniqueName(node_unit.Name() + input + "newshape");
     ORT_RETURN_IF_ERROR(op_builder_helpers::AddNnapiReshape(model_builder, input, shape_name, shape, output));
   }
diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h
index 7ccf4c1ef7555..0844857a06d61 100644
--- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h
+++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h
@@ -7,12 +7,12 @@
 #include <vector>
 
 #include "core/common/common.h"
+#include "core/framework/node_unit.h"
 #include "core/providers/common.h"
 #include "core/providers/nnapi/nnapi_builtin/builders/helper.h"
 #include "core/providers/nnapi/nnapi_builtin/builders/model_builder.h"
 #include "core/providers/nnapi/nnapi_builtin/builders/op_builder.h"
 #include "core/providers/nnapi/nnapi_builtin/nnapi_lib/NeuralNetworksWrapper.h"
-#include "core/providers/shared/node_unit/node_unit.h"
 
 namespace onnxruntime::nnapi::op_builder_helpers {
 
@@ -181,9 +181,6 @@ Status AddMinMaxOperator(ModelBuilder& model_builder, const NodeUnit& node_unit,
 Status AddReshapeOperator(ModelBuilder& model_builder, const NodeUnit& node_unit,
                           const std::string& input, const std::vector<int32_t>& shape);
 
-bool CanSkipReshape(const ModelBuilder& model_builder, const NodeUnit& node_unit,
-                    size_t input_rank, size_t output_rank);
-
 Status GetAxesForSqueezeAndUnSqueeze(ModelBuilder& model_builder, const NodeUnit& node_unit,
                                      std::vector<int32_t>& axes);
 
diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc
index b04703d7611ee..4d2888222ff0f 100644
--- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc
+++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc
@@ -7,7 +7,10 @@
 #include "core/common/logging/logging.h"
 #include "core/common/string_utils.h"
 #include "core/framework/compute_capability.h"
+#include "core/framework/node_unit.h"
 #include "core/graph/graph_viewer.h"
+#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
+#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
 #include "core/platform/env.h"
 #include "core/providers/common.h"
 #include "core/providers/nnapi/nnapi_builtin/builders/helper.h"
@@ -17,7 +20,6 @@
 #include "core/providers/nnapi/nnapi_builtin/nnapi_api_helper.h"
 #include "core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.h"
 #include "core/providers/partitioning_utils.h"
-#include "core/providers/shared/node_unit/node_unit.h"
 #include "core/session/onnxruntime_cxx_api.h"
 
 namespace onnxruntime {
@@ -119,7 +121,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
   std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
   std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
 
-  std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer);
+  std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
 
   // This holds the result of whether a NodeUnit is supported or not,
   // to prevent nodes in a NodeUnit to be checked for multiple times
@@ -181,7 +183,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
   };
 
   result = utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed,
-                                            gen_metadef_name, NNAPI, kNnapiExecutionProvider);
+                                            gen_metadef_name, NNAPI, kNnapiExecutionProvider, &node_unit_map);
 
   // Generally, NNAPI supports sub-graphs with at least one non-constant initializer input and one output.
   // So far, we have a few cases that sub-graph has zero valid inputs, like `CastLike`
diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc
index 330b464ffd1bb..3252603e33389 100644
--- a/onnxruntime/core/providers/openvino/backend_manager.cc
+++ b/onnxruntime/core/providers/openvino/backend_manager.cc
@@ -1,8 +1,9 @@
-// Copyright (C) 2019-2022 Intel Corporation
+// Copyright (C) Intel Corporation
 // Licensed under the MIT License
 
 #include <fstream>
 #include <utility>
+#include <exception>
 
 #include "core/providers/shared_library/provider_api.h"
 #include "contexts.h"
@@ -24,15 +25,6 @@ BackendManager::BackendManager(const GlobalContext& global_context,
   global_context_ = global_context;
 
   auto prec_str = GetGlobalContext().precision_str;
-  if (prec_str == "FP32") {
-    subgraph_context_.precision = "FP32";
-  } else if (prec_str == "FP16") {
-    subgraph_context_.precision = "FP16";
-  } else if (prec_str == "U8") {
-    subgraph_context_.precision = "U8";
-  } else {
-    throw std::string("Invalid OpenVINO Precision type: " + prec_str);
-  }
 
   // Save the indexes of graph inputs among fused_node's inputDefs
   // (which also contains initializers).
@@ -47,7 +39,7 @@ BackendManager::BackendManager(const GlobalContext& global_context,
   for (auto input : graph_inputs) {
     auto it = subgraph_context_.input_names.find(input->Name());
     if (it == subgraph_context_.input_names.end()) {
-      throw std::string("Input not found in the input defs list");
+      ORT_THROW("Input not found in the input defs list");
     }
     int index = it->second;
     subgraph_context_.input_indexes.push_back(index);
@@ -61,6 +53,7 @@ BackendManager::BackendManager(const GlobalContext& global_context,
   }
   subgraph_context_.subgraph_name = fused_node.Name();
   model_proto_ = GetModelProtoFromFusedNode(fused_node, subgraph, logger);
+  std::string device_type = openvino_ep::BackendManager::GetGlobalContext().device_type;
 
   if (ModelHasSymbolicInputDims(subgraph)) {
     subgraph_context_.has_dynamic_input_shape = true;
@@ -75,7 +68,7 @@ BackendManager::BackendManager(const GlobalContext& global_context,
                                                           GetGlobalContext(),
                                                           subgraph_context_);
         } catch (std::string const& msg) {
-          throw msg;
+          ORT_THROW(msg);
         }
         LOGS_DEFAULT(INFO) << "[OpenVINO-EP] "
                            << "Backend created for graph " << subgraph_context_.subgraph_name;
@@ -87,12 +80,29 @@ BackendManager::BackendManager(const GlobalContext& global_context,
                        << subgraph_context_.subgraph_name;
 
     subgraph_context_.has_dynamic_input_shape = false;
+
+    // The OV NPU plugin is supported, with fallback to OV CPU upon compilation failure.
     try {
       concrete_backend_ = BackendFactory::MakeBackend(*model_proto_,
                                                       GetGlobalContext(),
                                                       subgraph_context_);
-    } catch (std::string const& msg) {
-      throw msg;
+    } catch (const OnnxRuntimeException& ex) {
+      if (device_type.find("NPU") != std::string::npos) {
+        LOGS_DEFAULT(WARNING) << ex.what();
+        LOGS_DEFAULT(WARNING) << "Model compilation failed at OV NPU."
+                              << "Falling back to OV CPU for execution";
+        GetGlobalContext().device_type = "CPU";
+        GetGlobalContext().precision_str = "FP32";
+        try {
+          concrete_backend_ = BackendFactory::MakeBackend(*model_proto_,
+                                                          GetGlobalContext(),
+                                                          subgraph_context_);
+        } catch (std::string const& msg) {
+          ORT_THROW(msg);
+        }
+      } else {
+        ORT_THROW(ex.what());
+      }
     }
   }
 }
@@ -254,8 +264,13 @@ void BackendManager::Compute(OrtKernelContext* context) {
     LOGS_DEFAULT(INFO) << "Start Compute";
   }
 #endif
+  // OV NPU doesn't support dynamically shaped model inference.
+  // If disable_dynamic_shapes is set to true, a dynamic model is executed by rewriting it
+  // to a statically shaped model at runtime based on the input shape.
+  // disable_dynamic_shapes is always set to true for the OV NPU plugin.
   bool use_dynamic_backend = true;
-  if (!GetGlobalContext().disable_dynamic_shapes && subgraph_context_.has_dynamic_input_shape &&
+  if (subgraph_context_.has_dynamic_input_shape &&
+      !GetGlobalContext().disable_dynamic_shapes &&
       (GetGlobalContext().device_type.find("CPU") != std::string::npos ||
        GetGlobalContext().device_type.find("GPU") != std::string::npos)) {
     concrete_backend_->Infer(context);
@@ -263,12 +278,11 @@ void BackendManager::Compute(OrtKernelContext* context) {
   } else if (use_dynamic_backend && subgraph_context_.has_dynamic_input_shape) {
     std::vector<std::vector<int64_t>> tensor_shapes = GetInputTensorShapes(ctx);
     auto key = MakeMapKeyString(tensor_shapes, GetGlobalContext().device_type);
-
     std::shared_ptr<IBackend> dynamic_backend;
     auto search = backend_map_.find(key);
     if (search == backend_map_.end()) {
       LOGS_DEFAULT(INFO) << "[OpenVINO-EP] "
-                         << "Creating concrete backend for key: " << key;
+                         << "Creating dynamic backend for key: " << key;
       LOGS_DEFAULT(INFO) << "[OpenVINO-EP] "
                          << "Backend created for graph " << subgraph_context_.subgraph_name;
       auto modelproto_with_concrete_shapes = ReWriteInputShapeInfo(*model_proto_, tensor_shapes);
@@ -276,8 +290,22 @@ void BackendManager::Compute(OrtKernelContext* context) {
         dynamic_backend = BackendFactory::MakeBackend(*modelproto_with_concrete_shapes,
                                                       GetGlobalContext(),
                                                       subgraph_context_);
-      } catch (std::string const& msg) {
-        throw msg;
+      } catch (const OnnxRuntimeException& ex) {
+        if (GetGlobalContext().device_type.find("NPU") != std::string::npos) {
+          LOGS_DEFAULT(WARNING) << ex.what();
+          LOGS_DEFAULT(WARNING) << "Model compilation failed at OV NPU. "
+                                << "Falling back to OV CPU for execution";
+          GetGlobalContext().device_type = "CPU";
+          GetGlobalContext().precision_str = "FP32";
+          key = MakeMapKeyString(tensor_shapes, GetGlobalContext().device_type);
+          try {
+            dynamic_backend = BackendFactory::MakeBackend(*modelproto_with_concrete_shapes,
+                                                          GetGlobalContext(),
+                                                          subgraph_context_);
+          } catch (std::string const& msg) {
+            ORT_THROW(msg);
+          }
+        }
       }
       backend_map_.insert({key, dynamic_backend});
     } else {
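
The dynamic-shape path above compiles one backend per concrete set of input shapes and memoizes it in backend_map_, keyed by a string built from the shapes and the device. A standalone sketch of that caching pattern with toy types (MakeKey stands in for MakeMapKeyString, ToyBackend for the compiled backend):

#include <cstdint>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

struct ToyBackend {
  std::string compiled_for;  // stands in for a network compiled for one shape set
};

// Builds a cache key from the concrete input shapes and the target device.
static std::string MakeKey(const std::vector<std::vector<std::int64_t>>& shapes,
                           const std::string& device) {
  std::ostringstream key;
  key << device;
  for (const auto& shape : shapes) {
    key << '|';
    for (auto dim : shape) key << dim << ',';
  }
  return key.str();
}

static std::shared_ptr<ToyBackend> GetOrCreateBackend(
    std::map<std::string, std::shared_ptr<ToyBackend>>& cache,
    const std::vector<std::vector<std::int64_t>>& shapes,
    const std::string& device) {
  const std::string key = MakeKey(shapes, device);
  auto it = cache.find(key);
  if (it != cache.end()) {
    return it->second;  // a backend compiled for this exact shape set already exists
  }
  auto backend = std::make_shared<ToyBackend>(ToyBackend{key});  // stands in for MakeBackend(...)
  cache.emplace(key, backend);
  return backend;
}

int main() {
  std::map<std::string, std::shared_ptr<ToyBackend>> backend_map;
  auto a = GetOrCreateBackend(backend_map, {{1, 3, 224, 224}}, "CPU");
  auto b = GetOrCreateBackend(backend_map, {{1, 3, 224, 224}}, "CPU");
  return a == b ? 0 : 1;  // same shapes and device -> same cached backend
}
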
diff --git a/onnxruntime/core/providers/openvino/backend_manager.h b/onnxruntime/core/providers/openvino/backend_manager.h
index 59bda7ca640ee..376ebea225a2b 100644
--- a/onnxruntime/core/providers/openvino/backend_manager.h
+++ b/onnxruntime/core/providers/openvino/backend_manager.h
@@ -1,4 +1,4 @@
-// Copyright (C) 2019-2022 Intel Corporation
+// Copyright (C) Intel Corporation
 // Licensed under the MIT License
 
 #pragma once
diff --git a/onnxruntime/core/providers/openvino/backend_utils.cc b/onnxruntime/core/providers/openvino/backend_utils.cc
index 50c839017df2a..32b5ad7d5b66d 100644
--- a/onnxruntime/core/providers/openvino/backend_utils.cc
+++ b/onnxruntime/core/providers/openvino/backend_utils.cc
@@ -1,4 +1,4 @@
-// Copyright (C) 2019-2022 Intel Corporation
+// Copyright (C) Intel Corporation
 // Licensed under the MIT License
 
 #include <algorithm>
@@ -11,12 +11,7 @@
 #include "core/providers/shared_library/provider_api.h"
 #include "backend_utils.h"
 
-#if defined(OV_API_20)
 using Exception = ov::Exception;
-#else
-using Exception = InferenceEngine::details::InferenceEngineException;
-using WaitMode = InferenceEngine::IInferRequest::WaitMode;
-#endif
 
 namespace onnxruntime {
 namespace openvino_ep {
@@ -47,7 +42,6 @@ struct static_cast_int64 {
 
 std::shared_ptr<OVNetwork>
 CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext& global_context,
-              const SubGraphContext& subgraph_context,
               std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map) {
   if (IsCILogEnabled()) {
     std::cout << "CreateNgraphFunc" << std::endl;
@@ -55,28 +49,6 @@ CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext
   const std::string model = model_proto.SerializeAsString();
   try {
     auto cnn_network = global_context.ie_core.ReadModel(model, global_context.onnx_model_path_name);
-    if ((subgraph_context.precision == "FP16") &&
-        (global_context.device_type.find("NPU") == std::string::npos)) {
-      // FP16 transformations
-      ov::pass::ConvertFP32ToFP16 pass_obj;
-      pass_obj.run_on_model(cnn_network);
-      cnn_network->validate_nodes_and_infer_types();
-
-      auto proc = ov::preprocess::PrePostProcessor(cnn_network);
-      for (size_t i = 0; i < cnn_network->inputs().size(); i++) {
-        if (cnn_network->inputs()[i].get_element_type() == ov::element::f16) {
-          proc.input(i).tensor().set_element_type(ov::element::f32);
-          proc.input(i).preprocess().convert_element_type(ov::element::f16);
-        }
-      }
-
-      for (size_t i = 0; i < cnn_network->outputs().size(); i++) {
-        if (cnn_network->outputs()[i].get_element_type() == ov::element::f16) {
-          proc.output(i).postprocess().convert_element_type(ov::element::f32);
-        }
-      }
-      cnn_network = proc.build();
-    }
 
     // Check for Constant Folding
     if (!global_context.is_wholly_supported_graph) {
@@ -103,7 +75,7 @@ CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext
 #endif
     return cnn_network;
   } catch (std::string const& msg) {
-    throw msg;
+    ORT_THROW(msg);
   }
 }
 
@@ -127,7 +99,7 @@ GetOutputTensor(Ort::KernelContext& context, size_t batch_size,
   }
   auto it = output_names.find(output_name);
   if (it == output_names.end()) {
-    throw std::string(log_tag + "Output names mismatch between OpenVINO and ONNX");
+    ORT_THROW(log_tag + "Output names mismatch between OpenVINO and ONNX");
   }
   int index = it->second;
   return context.GetOutput(index, output_shape.get(), num_dims);
@@ -145,7 +117,7 @@ GetOutputTensor(Ort::KernelContext& context,
 
   auto it = output_names.find(output_name);
   if (it == output_names.end()) {
-    throw std::string(log_tag + "Output names mismatch between OpenVINO and ONNX");
+    ORT_THROW(log_tag + "Output names mismatch between OpenVINO and ONNX");
   }
   int index = it->second;
   auto shape = node->get_shape();
@@ -204,7 +176,7 @@ void FillOutputsWithConstantData(std::shared_ptr<ov::Node> node, Ort::UnownedVal
       break;
     }
     default:
-      throw std::string(log_tag + "Unsupported output data type");
+      ORT_THROW(log_tag + "Unsupported output data type");
   }
 }
 
@@ -232,7 +204,7 @@ void FillInputBlob(OVTensorPtr inputBlob, size_t batch_slice_idx,
   auto tensor = context.GetInput(subgraph_context.input_names.at(input_name));
   auto mem_info = tensor.GetTensorMemoryInfo();
   if (mem_info.GetAllocatorName() == OpenVINO_GPU) {
-    throw std::string(log_tag + "IO Buffering is not enabled, Please enable Input on CPU");
+    ORT_THROW(log_tag + "IO Buffering is not enabled, Please enable Input on CPU");
   }
   // Copy input data into OpenVINO's input buffer
   const char* tensor_data = tensor.GetTensorData<char>();
diff --git a/onnxruntime/core/providers/openvino/backend_utils.h b/onnxruntime/core/providers/openvino/backend_utils.h
index 82b0351e87da5..93fa874774469 100644
--- a/onnxruntime/core/providers/openvino/backend_utils.h
+++ b/onnxruntime/core/providers/openvino/backend_utils.h
@@ -1,4 +1,4 @@
-// Copyright (C) 2019-2022 Intel Corporation
+// Copyright (C) Intel Corporation
 // Licensed under the MIT License
 
 #pragma once
@@ -65,7 +65,6 @@ void FillOutputBlob(OVTensorPtr outputBlob, Ort::UnownedValue& output_tensor,
 std::shared_ptr<OVNetwork>
 CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto,
               const GlobalContext& global_context,
-              const SubGraphContext& subgraph_context,
               std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map);
 
 void printPerformanceCounts(const std::vector<OVProfilingInfo>& performanceMap,
diff --git a/onnxruntime/core/providers/openvino/backends/backend_factory.cc b/onnxruntime/core/providers/openvino/backends/backend_factory.cc
index c586dd8b38af9..a0f4ce8f843b0 100644
--- a/onnxruntime/core/providers/openvino/backends/backend_factory.cc
+++ b/onnxruntime/core/providers/openvino/backends/backend_factory.cc
@@ -1,4 +1,4 @@
-// Copyright (C) 2019-2022 Intel Corporation
+// Copyright (C) Intel Corporation
 // Licensed under the MIT License
 
 #include <memory>
@@ -24,11 +24,11 @@ BackendFactory::MakeBackend(const ONNX_NAMESPACE::ModelProto& model_proto,
     try {
       concrete_backend_ = std::make_shared<BasicBackend>(model_proto, global_context, subgraph_context);
     } catch (std::string const& msg) {
-      throw msg;
+      ORT_THROW(msg);
     }
     return concrete_backend_;
   } else {
-    throw std::string("[OpenVINO-EP] Backend factory error: Unknown backend type: " + type);
+    ORT_THROW("[OpenVINO-EP] Backend factory error: Unknown backend type: " + type);
   }
 }
 }  // namespace openvino_ep
diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc
index 0779940983aea..69d234a7c55ef 100644
--- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc
+++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc
@@ -1,4 +1,4 @@
-// Copyright (C) 2019-2022 Intel Corporation
+// Copyright (C) Intel Corporation
 // Licensed under the MIT License
 
 #include <map>
@@ -79,20 +79,20 @@ BasicBackend::BasicBackend(const ONNX_NAMESPACE::ModelProto& model_proto,
                                                            subgraph_context_.subgraph_name);
         LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
       } else {
-        ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_);
+        ie_cnn_network_ = CreateOVModel(model_proto, global_context_, const_outputs_map_);
         exe_network_ = global_context_.ie_core.LoadNetwork(
             ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name);
         LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
       }
 #endif
     } else {
-      ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_);
+      ie_cnn_network_ = CreateOVModel(model_proto, global_context_, const_outputs_map_);
       exe_network_ = global_context_.ie_core.LoadNetwork(
           ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name);
       LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
     }
   } catch (const char* msg) {
-    throw(msg);
+    ORT_THROW(msg);
   }
 
   inferRequestsQueue_ = std::unique_ptr<InferRequestsQueue>(new InferRequestsQueue(exe_network_, 1));
@@ -125,21 +125,17 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
   if (global_context_.device_type.find("NPU") != std::string::npos) {
     std::pair<std::string, ov::Any> device_property;
     device_property = std::make_pair("NPU_COMPILER_TYPE", "DRIVER");
+
+    const std::string env_npu_compiler_type = onnxruntime::GetEnvironmentVar("ORT_OPENVINO_NPU_COMPILER_TYPE");
+    if (!env_npu_compiler_type.empty()) {
+      device_property = std::make_pair("NPU_COMPILER_TYPE", env_npu_compiler_type);
+    }
     device_config.emplace(ov::device::properties("NPU", device_property));
   }
 }
 
 void BasicBackend::EnableCaching() {
   if (!global_context_.cache_dir.empty()) {
-    if (global_context_.is_wholly_supported_graph) {
-#if defined(OPENVINO_2022_3)
-#if defined(_WIN32) || defined(WIN32) || defined(__CYGWIN__) || defined(__MINGW32__) || defined(__BORLANDC__)
-      _putenv_s("OV_GPU_CACHE_MODEL", "1");
-#else
-      setenv("OV_GPU_CACHE_MODEL", "1", 1);
-#endif
-#endif
-    }
     LOGS_DEFAULT(INFO) << log_tag << "Enables Caching";
     global_context_.ie_core.SetCache(global_context_.cache_dir);
   }
@@ -162,7 +158,7 @@ void BasicBackend::EnableStreams() {
       (global_context_.device_type.find("HETERO") != std::string::npos) ||
       (global_context_.device_type.find("AUTO") != std::string::npos)) {
     if (global_context_.num_streams != 1) {
-      throw(log_tag + "Cannot set NUM_STREAMS to " + std::to_string(global_context_.num_streams) + " for device " + global_context_.device_type);
+      ORT_THROW(log_tag + "Cannot set NUM_STREAMS to " + std::to_string(global_context_.num_streams) + " for device " + global_context_.device_type);
     }
     // Do nothing
   } else {
@@ -198,9 +194,9 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque
       if (input_names.find(onnx_input_name) != input_names.end()) {
         input_name = onnx_input_name;
       } else {
-        throw(log_tag +
-              "Input names mismatch between OpenVINO and ONNX. " + onnx_input_name +
-              " doesn't exist in the list of OpenVINO input tensor names");
+        ORT_THROW(log_tag +
+                  "Input names mismatch between OpenVINO and ONNX. " + onnx_input_name +
+                  " doesn't exist in the list of OpenVINO input tensor names");
       }
       size_t batch_slice_idx = 0;
       if (subgraph_context_.has_dynamic_input_shape &&
@@ -232,14 +228,14 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque
         try {
           infer_request->SetTensor(input_name, tensor_ptr);
         } catch (const char* msg) {
-          throw(msg);
+          ORT_THROW(msg);
         }
       } else {
         OVTensorPtr graph_input_blob;
         try {
           graph_input_blob = infer_request->GetTensor(input_name);
         } catch (const char* msg) {
-          throw(msg);
+          ORT_THROW(msg);
         }
         FillInputBlob(graph_input_blob, batch_slice_idx, input_name, context, subgraph_context_);
       }
@@ -248,7 +244,7 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque
     // Start Async inference
     infer_request->StartAsync();
   } catch (const char* msg) {
-    throw(msg);
+    ORT_THROW(msg);
   }
 }
 
@@ -274,10 +270,10 @@ void BasicBackend::StartRemoteAsyncInference(Ort::KernelContext& context, OVInfe
       if (input_names.find(onnx_input_name) != input_names.end()) {
         input_name = onnx_input_name;
       } else {
-        throw(log_tag +
-              "Input names mismatch between OpenVINO and ONNX. " +
-              onnx_input_name +
-              " doesn't exist in the list of OpenVINO input tensor names");
+        ORT_THROW(log_tag +
+                  "Input names mismatch between OpenVINO and ONNX. " +
+                  onnx_input_name +
+                  " doesn't exist in the list of OpenVINO input tensor names");
       }
       input_idx++;
       // Kernel Context Input Buffer
@@ -322,7 +318,7 @@ void BasicBackend::StartRemoteAsyncInference(Ort::KernelContext& context, OVInfe
         }
       }
       if (!output_name_found) {
-        throw std::string(
+        ORT_THROW(
             log_tag +
             "Output names mismatch between OpenVINO and ONNX. [ONNX Output: ] " +
             onnx_output_name + " doesn't exist in the list of OpenVINO output tensor names");
@@ -344,7 +340,7 @@ void BasicBackend::StartRemoteAsyncInference(Ort::KernelContext& context, OVInfe
         try {
           infer_request->SetTensor(output_name, tensor_ptr);
         } catch (const char* msg) {
-          throw(msg);
+          ORT_THROW(msg);
         }
       }
     }
@@ -352,7 +348,7 @@ void BasicBackend::StartRemoteAsyncInference(Ort::KernelContext& context, OVInfe
     // Start Async inference
     infer_request->StartAsync();
   } catch (const char* msg) {
-    throw(msg);
+    ORT_THROW(msg);
   }
 }
 #endif
@@ -382,17 +378,18 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe
         }
       }
       if (!output_name_found) {
-        throw(log_tag +
-              "Output names mismatch between OpenVINO and ONNX. "
-              "[ONNX Output: ] " +
-              onnx_output_name +
-              " doesn't exist in the "
-              "list of OpenVINO output tensor names");
+        ORT_THROW(
+            log_tag +
+            "Output names mismatch between OpenVINO and ONNX. "
+            "[ONNX Output: ] " +
+            onnx_output_name +
+            " doesn't exist in the "
+            "list of OpenVINO output tensor names");
       }
       try {
         graph_output_blob = infer_request->GetTensor(output_name);
       } catch (const char* msg) {
-        throw(msg);
+        ORT_THROW(msg);
       }
       size_t batch_size = 1;
       auto output_tensor =
@@ -413,14 +410,14 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe
         auto output_tensor = GetOutputTensor(context, out_name, subgraph_context_.output_names, node);
         auto mem_info = output_tensor.GetTensorMemoryInfo();
         if (mem_info.GetAllocatorName() == OpenVINO_GPU) {
-          throw(log_tag + "IO Buffering is not supported for constant subgraphs");
+          ORT_THROW(log_tag + "IO Buffering is not supported for constant subgraphs");
         } else {
           FillOutputsWithConstantData(node, output_tensor);
         }
       }
     }
   } catch (const char* msg) {
-    throw(msg);
+    ORT_THROW(msg);
   }
 }
 
@@ -440,7 +437,7 @@ void BasicBackend::Infer(OrtKernelContext* ctx) {
         auto output_tensor = GetOutputTensor(context, out_name, subgraph_context_.output_names, node);
         FillOutputsWithConstantData(node, output_tensor);
       } catch (std::string const& msg) {
-        throw msg;
+        ORT_THROW(msg);
       }
     }
     // Get Output tensors
@@ -461,26 +458,26 @@ void BasicBackend::Infer(OrtKernelContext* ctx) {
       try {
         StartRemoteAsyncInference(context, infer_request);
       } catch (std::string const& msg) {
-        throw msg;
+        ORT_THROW(msg);
       }
     } else {
       try {
         StartAsyncInference(context, infer_request);
       } catch (std::string const& msg) {
-        throw msg;
+        ORT_THROW(msg);
       }
     }
 #else
     try {
       StartAsyncInference(context, infer_request);
-    } catch (std::string const& msg) {
-      throw msg;
+    } catch (const std::runtime_error& e) {
+      ORT_THROW(log_tag + " Exception at StartAsyncInference: " + e.what());
     }
 #endif
     try {
       CompleteAsyncInference(context, infer_request);
-    } catch (std::string const& msg) {
-      throw msg;
+    } catch (const std::runtime_error& e) {
+      ORT_THROW(log_tag + " Exception at CompleteAsyncInference: " + e.what());
     }
 
     // Get Output tensors
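The hunks above replace `throw msg` and `throw std::string(...)` with ORT_THROW so failures inside the OpenVINO EP surface as ordinary C++ exceptions rather than raw strings, which only a matching `catch (const char*)` or `catch (const std::string&)` handler can observe. The standalone sketch below illustrates that difference; `ThrowAsException` is a hypothetical stand-in for an exception-throwing helper and is not the real ORT_THROW macro.

```cpp
#include <iostream>
#include <stdexcept>
#include <string>

// Hypothetical helper standing in for an exception-throwing macro; not the real ORT_THROW.
[[noreturn]] void ThrowAsException(const std::string& msg) {
  throw std::runtime_error(msg);
}

int main() {
  // A thrown std::string is invisible to generic catch (const std::exception&) handlers;
  // only a handler for the exact string type catches it.
  try {
    throw std::string("Input names mismatch");
  } catch (const std::exception&) {
    std::cout << "never reached for a thrown std::string\n";
  } catch (const std::string& s) {
    std::cout << "caught raw string: " << s << "\n";
  }

  // Throwing a std::exception-derived type keeps the error visible to ordinary handlers.
  try {
    ThrowAsException("Input names mismatch");
  } catch (const std::exception& e) {
    std::cout << "caught exception: " << e.what() << "\n";
  }
  return 0;
}
```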
diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h
index aa96dadbf0e2d..3502f660bbb20 100644
--- a/onnxruntime/core/providers/openvino/backends/basic_backend.h
+++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h
@@ -1,4 +1,4 @@
-// Copyright (C) 2019-2022 Intel Corporation
+// Copyright (C) Intel Corporation
 // Licensed under the MIT License
 
 #pragma once
diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h
index 5f19c71683f24..8701d9f676ffd 100644
--- a/onnxruntime/core/providers/openvino/contexts.h
+++ b/onnxruntime/core/providers/openvino/contexts.h
@@ -1,4 +1,4 @@
-// Copyright (C) 2019-2022 Intel Corporation
+// Copyright (C) Intel Corporation
 // Licensed under the MIT License
 
 #pragma once
@@ -31,6 +31,7 @@ struct GlobalContext {
   int onnx_opset_version;
   void* context = 0;
   bool use_api_2;
+  std::vector<int> OpenVINO_Version = {};  // OV major and OV minor version from OV headers
 };
 
 // Holds context specific to subgraph.
@@ -44,7 +45,6 @@ struct SubGraphContext {
   std::vector<int> input_indexes;
   std::unordered_map<std::string, int> input_names;
   std::unordered_map<std::string, int> output_names;
-  std::string precision;
 };
 
 }  // namespace openvino_ep
diff --git a/onnxruntime/core/providers/openvino/ibackend.h b/onnxruntime/core/providers/openvino/ibackend.h
index 8aacce19c14d5..ece855c6167c6 100644
--- a/onnxruntime/core/providers/openvino/ibackend.h
+++ b/onnxruntime/core/providers/openvino/ibackend.h
@@ -1,4 +1,4 @@
-// Copyright (C) 2019-2022 Intel Corporation
+// Copyright (C) Intel Corporation
 // Licensed under the MIT License
 
 #pragma once
diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc
index e3948cc94b348..913440d2fb6ea 100644
--- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc
+++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc
@@ -1,4 +1,4 @@
-// Copyright (C) 2019-2022 Intel Corporation
+// Copyright (C) Intel Corporation
 // Licensed under the MIT License
 
 #include "core/providers/shared_library/provider_api.h"
@@ -6,6 +6,7 @@
 #include "contexts.h"
 #include "backend_manager.h"
 #include "ov_versions/capability.h"
+#include "openvino/core/version.hpp"
 
 #define MEMCPY_S(dest, src, destsz, srcsz) memcpy(dest, src, std::min(destsz, srcsz))
 
@@ -25,6 +26,7 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const OpenVINOExecutionProv
   global_context_->enable_opencl_throttling = info.enable_opencl_throttling_;
   global_context_->disable_dynamic_shapes = info.disable_dynamic_shapes_;
   global_context_->num_of_threads = info.num_of_threads_;
+  global_context_->OpenVINO_Version = {OPENVINO_VERSION_MAJOR, OPENVINO_VERSION_MINOR};
 
   // to check if target device is available
   // using ie_core capability GetAvailableDevices to fetch list of devices plugged in
@@ -50,8 +52,7 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const OpenVINOExecutionProv
               device_found = true;
               break;
             }
-            if ((info.device_type_.find("NPU") != std::string::npos) &&
-                (info.precision_ == "FP16" || info.precision_ == "U8")) {
+            if (info.device_type_.find("NPU") != std::string::npos) {
               device_found = true;
               break;
             }
@@ -113,27 +114,10 @@ OpenVINOExecutionProvider::GetCapability(const GraphViewer& graph_viewer,
   global_context_->onnx_opset_version =
       graph_viewer.DomainToVersionMap().at(kOnnxDomain);
 
-#if defined(OPENVINO_2023_0)
   openvino_ep::GetCapability obj(graph_viewer,
                                  global_context_->device_type,
-                                 global_context_->precision_str, "V_2023_0");
+                                 global_context_->precision_str);
   result = obj.Execute();
-#elif defined(OPENVINO_2023_1)
-  openvino_ep::GetCapability obj(graph_viewer,
-                                 global_context_->device_type,
-                                 global_context_->precision_str, "V_2023_1");
-  result = obj.Execute();
-#elif defined(OPENVINO_2023_2)
-  openvino_ep::GetCapability obj(graph_viewer,
-                                 global_context_->device_type,
-                                 global_context_->precision_str, "V_2023_2");
-  result = obj.Execute();
-#elif defined(OPENVINO_2023_3)
-  openvino_ep::GetCapability obj(graph_viewer,
-                                 global_context_->device_type,
-                                 global_context_->precision_str, "V_2023_3");
-  result = obj.Execute();
-#endif
 
   global_context_->is_wholly_supported_graph = obj.IsWhollySupportedGraph();
 
diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.h b/onnxruntime/core/providers/openvino/openvino_execution_provider.h
index b0c92828d8a38..b0dc881c36f33 100644
--- a/onnxruntime/core/providers/openvino/openvino_execution_provider.h
+++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.h
@@ -1,4 +1,4 @@
-// Copyright (C) 2019-2022 Intel Corporation
+// Copyright (C) Intel Corporation
 // Licensed under the MIT License
 
 #pragma once
@@ -20,7 +20,7 @@ static void print_build_options() {
             << "you want to build"
             << std::endl;
   std::cout << "The different hardware devices that can be added with HETERO/MULTI/AUTO build "
-            << "are ['CPU','GPU']"
+            << "are ['CPU','GPU','NPU']"
             << std::endl;
   std::cout << "An example of how to specify the HETERO or MULTI or AUTO build type. "
             << "Ex: HETERO:GPU,CPU  Ex: MULTI:GPU,CPU Ex: AUTO:GPU,CPU"
@@ -48,7 +48,7 @@ static std::vector<std::string> parseDevices(const std::string& device_string) {
     print_build_options();
     ORT_THROW("Invalid device string: " + device_string);
   }
-  std::vector<std::string> dev_options = {"CPU", "GPU"};
+  std::vector<std::string> dev_options = {"CPU", "GPU", "NPU"};
   for (std::string dev : devices) {
     if (!std::count(dev_options.begin(), dev_options.end(), dev)) {
       print_build_options();
@@ -98,12 +98,9 @@ struct OpenVINOExecutionProviderInfo {
 #elif defined OPENVINO_CONFIG_GPU_FP16
       device_type_ = "GPU";
       precision_ = "FP16";
-#elif defined OPENVINO_CONFIG_NPU_FP16
+#elif defined OPENVINO_CONFIG_NPU
       device_type_ = "NPU";
-      precision_ = "FP16";
-#elif defined OPENVINO_CONFIG_NPU_U8
-      device_type_ = "NPU";
-      precision_ = "U8";
+      precision_ = "";
 #elif defined OPENVINO_CONFIG_HETERO || defined OPENVINO_CONFIG_MULTI || defined OPENVINO_CONFIG_AUTO
 #ifdef DEVICE_NAME
 #define DEVICE DEVICE_NAME
@@ -142,12 +139,9 @@ struct OpenVINOExecutionProviderInfo {
     } else if (dev_type == "GPU.1_FP16") {
       device_type_ = "GPU.1";
       precision_ = "FP16";
-    } else if (dev_type == "NPU_FP16") {
-      device_type_ = "NPU";
-      precision_ = "FP16";
-    } else if (dev_type == "NPU_U8") {
+    } else if (dev_type == "NPU") {
       device_type_ = "NPU";
-      precision_ = "U8";
+      precision_ = "";
     } else if (dev_type.find("HETERO") == 0 || dev_type.find("MULTI") == 0) {
       std::vector<std::string> devices = parseDevices(dev_type);
       precision_ = "FP16";
diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc
index 068456777bece..17511c54aab86 100644
--- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc
+++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc
@@ -1,4 +1,4 @@
-// Copyright (C) 2019-2022 Intel Corporation
+// Copyright (C) Intel Corporation
 // Licensed under the MIT License
 
 #include "core/providers/shared_library/provider_api.h"
@@ -78,7 +78,6 @@ struct OpenVINO_Provider : Provider {
                                             // with this value at runtime.
     bool enable_opencl_throttling = false;  // [enable_opencl_throttling]: Enables OpenCL queue throttling for GPU
                                             // device (Reduces CPU Utilization when using GPU)
-    bool disable_dynamic_shapes = false;    // [disable_dynamic_shapes]:  Execute model with default static shape for optimal performance.
     void* context = nullptr;
 
     if (provider_options_map.find("device_type") != provider_options_map.end()) {
@@ -86,7 +85,7 @@ struct OpenVINO_Provider : Provider {
 
       std::set<std::string> ov_supported_device_types = {"CPU_FP32", "CPU_FP16", "GPU_FP32",
                                                          "GPU.0_FP32", "GPU.1_FP32", "GPU_FP16",
-                                                         "GPU.0_FP16", "GPU.1_FP16"};
+                                                         "GPU.0_FP16", "GPU.1_FP16", "NPU"};
       if (!((ov_supported_device_types.find(device_type) != ov_supported_device_types.end()) ||
             (device_type.find("HETERO:") == 0) ||
             (device_type.find("MULTI:") == 0) ||
@@ -94,7 +93,7 @@ struct OpenVINO_Provider : Provider {
         ORT_THROW(
             "[ERROR] [OpenVINO] You have selcted wrong configuration value for the key 'device_type'. "
             "Select from 'CPU_FP32', 'CPU_FP16', 'GPU_FP32', 'GPU.0_FP32', 'GPU.1_FP32', 'GPU_FP16', "
-            "'GPU.0_FP16', 'GPU.1_FP16' or from"
+            "'GPU.0_FP16', 'GPU.1_FP16', 'NPU' or from"
             " HETERO/MULTI/AUTO options available. \n");
       }
     }
@@ -147,12 +146,24 @@ struct OpenVINO_Provider : Provider {
       bool_flag = "";
     }
 
+    // [disable_dynamic_shapes]: Rewrite dynamically shaped models to static shape at runtime and execute.
+    // Always true for NPU plugin.
+    bool disable_dynamic_shapes = false;
+    if (device_type.find("NPU") != std::string::npos) {
+      disable_dynamic_shapes = true;
+    }
     if (provider_options_map.find("disable_dynamic_shapes") != provider_options_map.end()) {
       bool_flag = provider_options_map.at("disable_dynamic_shapes");
       if (bool_flag == "true" || bool_flag == "True")
         disable_dynamic_shapes = true;
-      else if (bool_flag == "false" || bool_flag == "False")
-        disable_dynamic_shapes = false;
+      else if (bool_flag == "false" || bool_flag == "False") {
+        if (device_type.find("NPU") != std::string::npos) {
+          disable_dynamic_shapes = true;
+          LOGS_DEFAULT(INFO) << "[OpenVINO-EP] The value for the key 'disable_dynamic_shapes' will be set to TRUE for the NPU backend.\n";
+        } else {
+          disable_dynamic_shapes = false;
+        }
+      }
     }
     return std::make_shared<OpenVINOProviderFactory>(const_cast<char*>(device_type.c_str()),
                                                      enable_npu_fast_compile,
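The new provider-option handling defaults disable_dynamic_shapes to true whenever the device string contains NPU, and keeps it true even if the user passes "false" (logging an INFO message instead). A standalone sketch of that coercion, assuming string-valued options as they arrive through the provider options map:

```cpp
#include <iostream>
#include <string>

// Illustrative sketch of the coercion added in openvino_provider_factory.cc:
// NPU always runs with static shapes, so a user-supplied "false" is overridden.
bool ResolveDisableDynamicShapes(const std::string& device_type,
                                 const std::string& user_value /* "" if unset */) {
  bool disable = device_type.find("NPU") != std::string::npos;  // default true on NPU
  if (user_value == "true" || user_value == "True") {
    disable = true;
  } else if (user_value == "false" || user_value == "False") {
    // On NPU the request is ignored and the flag stays true.
    disable = (device_type.find("NPU") != std::string::npos);
  }
  return disable;
}

int main() {
  std::cout << ResolveDisableDynamicShapes("NPU", "false") << "\n";       // 1: forced on
  std::cout << ResolveDisableDynamicShapes("GPU_FP16", "false") << "\n";  // 0
  std::cout << ResolveDisableDynamicShapes("CPU_FP32", "") << "\n";       // 0: default off
  return 0;
}
```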
diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc
index ea481791111fc..d7c6654c90f81 100644
--- a/onnxruntime/core/providers/openvino/ov_interface.cc
+++ b/onnxruntime/core/providers/openvino/ov_interface.cc
@@ -1,4 +1,4 @@
-// Copyright (C) 2019-2022 Intel Corporation
+// Copyright (C) Intel Corporation
 // Licensed under the MIT License
 
 #include "ov_interface.h"
@@ -8,12 +8,7 @@
 #include "core/providers/shared_library/provider_api.h"
 #include "backend_utils.h"
 
-#if defined(OV_API_20)
 using Exception = ov::Exception;
-#else
-using Exception = InferenceEngine::details::InferenceEngineException;
-using WaitMode = InferenceEngine::IInferRequest::WaitMode;
-#endif
 
 namespace onnxruntime {
 namespace openvino_ep {
@@ -36,9 +31,9 @@ std::shared_ptr<OVNetwork> OVCore::ReadModel(const std::string& model, const std
     }
     return FE->convert(inputModel);
   } catch (const Exception& e) {
-    throw std::string(log_tag + "[OpenVINO-EP] Exception while Reading network: " + std::string(e.what()));
+    ORT_THROW(log_tag + "[OpenVINO-EP] Exception while Reading network: " + std::string(e.what()));
   } catch (...) {
-    throw std::string(log_tag + "[OpenVINO-EP] Unknown exception while Reading network");
+    ORT_THROW(log_tag + "[OpenVINO-EP] Unknown exception while Reading network");
   }
 }
 
@@ -81,9 +76,9 @@ OVExeNetwork OVCore::LoadNetwork(std::shared_ptr<OVNetwork>& ie_cnn_network,
     OVExeNetwork exe(obj);
     return exe;
   } catch (const Exception& e) {
-    throw std::string(log_tag + " Exception while Loading Network for graph: " + name + e.what());
+    ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what());
   } catch (...) {
-    throw std::string(log_tag + " Exception while Loading Network for graph " + name);
+    ORT_THROW(log_tag + " Exception while Loading Network for graph " + name);
   }
 }
 
@@ -113,9 +108,9 @@ OVExeNetwork OVCore::LoadNetwork(std::shared_ptr<OVNetwork>& model, OVRemoteCont
     auto obj = oe.compile_model(model, *context);
     return OVExeNetwork(obj);
   } catch (const Exception& e) {
-    throw std::string(log_tag + " Exception while Loading Network for graph: " + name + e.what());
+    ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what());
   } catch (...) {
-    throw std::string(log_tag + " Exception while Loading Network for graph " + name);
+    ORT_THROW(log_tag + " Exception while Loading Network for graph " + name);
   }
 }
 #endif
@@ -135,9 +130,9 @@ OVInferRequest OVExeNetwork::CreateInferRequest() {
     OVInferRequest inf_obj(infReq);
     return inf_obj;
   } catch (const Exception& e) {
-    throw std::string(log_tag + "Exception while creating InferRequest object: " + e.what());
+    ORT_THROW(log_tag + "Exception while creating InferRequest object: " + e.what());
   } catch (...) {
-    throw std::string(log_tag + "Exception while creating InferRequest object.");
+    ORT_THROW(log_tag + "Exception while creating InferRequest object.");
   }
 }
 
@@ -147,9 +142,9 @@ OVTensorPtr OVInferRequest::GetTensor(const std::string& input_name) {
     OVTensorPtr blob = std::make_shared<OVTensor>(tobj);
     return blob;
   } catch (const Exception& e) {
-    throw std::string(log_tag + " Cannot access IE Blob for input: " + input_name + e.what());
+    ORT_THROW(log_tag + " Cannot access IE Blob for input: " + input_name + e.what());
   } catch (...) {
-    throw std::string(log_tag + " Cannot access IE Blob for input: " + input_name);
+    ORT_THROW(log_tag + " Cannot access IE Blob for input: " + input_name);
   }
 }
 
@@ -157,9 +152,9 @@ void OVInferRequest::SetTensor(const std::string& name, OVTensorPtr& blob) {
   try {
     ovInfReq.set_tensor(name, *(blob.get()));
   } catch (const Exception& e) {
-    throw std::string(log_tag + " Cannot set Remote Blob for output: " + name + e.what());
+    ORT_THROW(log_tag + " Cannot set Remote Blob for output: " + name + e.what());
   } catch (...) {
-    throw std::string(log_tag + " Cannot set Remote Blob for output: " + name);
+    ORT_THROW(log_tag + " Cannot set Remote Blob for output: " + name);
   }
 }
 
@@ -167,9 +162,9 @@ void OVInferRequest::StartAsync() {
   try {
     ovInfReq.start_async();
   } catch (const Exception& e) {
-    throw std::string(log_tag + " Couldn't start Inference: " + e.what());
+    ORT_THROW(log_tag + " Couldn't start Inference: " + e.what());
   } catch (...) {
-    throw std::string(log_tag + " In Error Couldn't start Inference");
+    ORT_THROW(log_tag + " In Error Couldn't start Inference");
   }
 }
 
@@ -177,9 +172,9 @@ void OVInferRequest::Infer() {
   try {
     ovInfReq.infer();
   } catch (const Exception& e) {
-    throw std::string(log_tag + " Couldn't start Inference: " + e.what());
+    ORT_THROW(log_tag + " Couldn't start Inference: " + e.what());
   } catch (...) {
-    throw std::string(log_tag + " In Error Couldn't start Inference");
+    ORT_THROW(log_tag + " In Error Couldn't start Inference");
   }
 }
 
@@ -187,9 +182,9 @@ void OVInferRequest::WaitRequest() {
   try {
     ovInfReq.wait();
   } catch (const Exception& e) {
-    throw std::string(log_tag + " Wait Model Failed: " + e.what());
+    ORT_THROW(log_tag + " Wait Model Failed: " + e.what());
   } catch (...) {
-    throw std::string(log_tag + " Wait Mode Failed");
+    ORT_THROW(log_tag + " Wait Mode Failed");
   }
 }
 
diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h
index cf4d867d4df55..2a13fafb99fd3 100644
--- a/onnxruntime/core/providers/openvino/ov_interface.h
+++ b/onnxruntime/core/providers/openvino/ov_interface.h
@@ -1,4 +1,4 @@
-// Copyright (C) 2019-2022 Intel Corporation
+// Copyright (C) Intel Corporation
 // Licensed under the MIT License
 
 #pragma once
@@ -6,14 +6,11 @@
 #include <vector>
 #include <memory>
 
-#define OV_API_20
 #include "openvino/openvino.hpp"
 #include "openvino/pass/convert_fp32_to_fp16.hpp"
 #include "openvino/frontend/manager.hpp"
 
 #ifdef IO_BUFFER_ENABLED
-#include <gpu/gpu_context_api_ocl.hpp>
-#include <gpu/gpu_config.hpp>
 #include <openvino/runtime/intel_gpu/ocl/ocl.hpp>
 #endif
 
diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc
index 11c8a1629b073..3970bf6ff68a7 100644
--- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc
+++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc
@@ -1,4 +1,4 @@
-// Copyright (C) 2019-2022 Intel Corporation
+// Copyright (C) Intel Corporation
 // Licensed under the MIT License
 
 #include "core/providers/shared_library/provider_api.h"
@@ -6,6 +6,7 @@
 #include "../backend_manager.h"
 #include "capability.h"
 #include "utils.h"
+#include "openvino/core/version.hpp"
 
 #if defined(_MSC_VER)
 #pragma warning(disable : 4244 4245 5208)
@@ -25,20 +26,22 @@ namespace openvino_ep {
 // Constructor
 GetCapability::GetCapability(const GraphViewer& graph_viewer_param,
                              const std::string device_type_param,
-                             const std::string device_precision,
-                             const std::string version_param)
+                             const std::string device_precision)
     : graph_viewer_(graph_viewer_param), device_type_(device_type_param), device_precision_(device_precision) {
-  if (version_param == "V_2023_0") {
-    data_ops_ = new DataOps(graph_viewer_, V_2023_0, device_type_, device_precision_);
-  } else if (version_param == "V_2023_1") {
-    data_ops_ = new DataOps(graph_viewer_, V_2023_1, device_type_, device_precision_);
-  } else if (version_param == "V_2023_2") {
-    data_ops_ = new DataOps(graph_viewer_, V_2023_2, device_type_, device_precision_);
-  } else if (version_param == "V_2023_3") {
-    data_ops_ = new DataOps(graph_viewer_, V_2023_3, device_type_, device_precision_);
-  } else {
-    data_ops_ = new DataOps(graph_viewer_, V_2023_3, device_type_, device_precision_);
+  if (device_type_.find("NPU") != std::string::npos) {
+    device_type_ = "CPU_FP32";
   }
+#if OPENVINO_VERSION_MAJOR == 2023 && OPENVINO_VERSION_MINOR == 1
+  data_ops_ = new DataOps(graph_viewer_, V_2023_1, device_type_, device_precision_);
+#elif OPENVINO_VERSION_MAJOR == 2023 && OPENVINO_VERSION_MINOR == 2
+  data_ops_ = new DataOps(graph_viewer_, V_2023_2, device_type_, device_precision_);
+#elif OPENVINO_VERSION_MAJOR == 2023 && OPENVINO_VERSION_MINOR == 3
+  data_ops_ = new DataOps(graph_viewer_, V_2023_3, device_type_, device_precision_);
+#elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 0
+  data_ops_ = new DataOps(graph_viewer_, V_2024_0, device_type_, device_precision_);
+#else
+  data_ops_ = new DataOps(graph_viewer_, V_2024_0, device_type_, device_precision_);
+#endif
 }
 
 std::vector<std::unique_ptr<ComputeCapability>> GetCapability::Execute() {
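This hunk drops the build-time OPENVINO_2023_x defines in favor of the OPENVINO_VERSION_MAJOR / OPENVINO_VERSION_MINOR macros from openvino/core/version.hpp, falling back to the newest known release in the #else branch. A simplified, self-contained sketch of that selection pattern (the macro defaults below are hypothetical so the snippet compiles without the OpenVINO headers):

```cpp
#include <iostream>

// Hypothetical stand-ins for the macros provided by openvino/core/version.hpp.
#ifndef OPENVINO_VERSION_MAJOR
#define OPENVINO_VERSION_MAJOR 2024
#define OPENVINO_VERSION_MINOR 0
#endif

// Mirrors the pattern used in capability.cc: pick a versionNum at compile time and
// default to the newest known version when the installed headers are newer than this code.
#if OPENVINO_VERSION_MAJOR == 2023 && OPENVINO_VERSION_MINOR == 1
static const char* kSelectedVersion = "V_2023_1";
#elif OPENVINO_VERSION_MAJOR == 2023 && OPENVINO_VERSION_MINOR == 2
static const char* kSelectedVersion = "V_2023_2";
#elif OPENVINO_VERSION_MAJOR == 2023 && OPENVINO_VERSION_MINOR == 3
static const char* kSelectedVersion = "V_2023_3";
#else
static const char* kSelectedVersion = "V_2024_0";
#endif

int main() {
  std::cout << "DataOps would be built for " << kSelectedVersion << "\n";
  return 0;
}
```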
diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.h b/onnxruntime/core/providers/openvino/ov_versions/capability.h
index 2040634cc45d9..d9fe5a95ef833 100644
--- a/onnxruntime/core/providers/openvino/ov_versions/capability.h
+++ b/onnxruntime/core/providers/openvino/ov_versions/capability.h
@@ -1,4 +1,4 @@
-// Copyright (C) 2019-2022 Intel Corporation
+// Copyright (C) Intel Corporation
 // Licensed under the MIT License
 
 #pragma once
@@ -21,8 +21,7 @@ class GetCapability {
  public:
   GetCapability(const GraphViewer& graph_viewer_param,
                 const std::string device_type_param,
-                const std::string precision,
-                const std::string version_param);
+                const std::string precision);
   virtual std::vector<std::unique_ptr<ComputeCapability>> Execute();
   bool IsWhollySupportedGraph() {
     return is_wholly_supported_graph_;
diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc
index e829bf377b195..c7c3e93595719 100644
--- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc
+++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc
@@ -1,4 +1,4 @@
-// Copyright (C) 2019-2022 Intel Corporation
+// Copyright (C) Intel Corporation
 // Licensed under the MIT License
 
 #include <unordered_set>
@@ -14,6 +14,7 @@
 #include "data_ops.h"
 #include "capability.h"
 #include "utils.h"
+#include "../ov_interface.h"
 
 #if defined(_MSC_VER)
 #pragma warning(disable : 4244 4245 5208)
@@ -36,6 +37,7 @@ namespace openvino_ep {
 std::set<std::string> ops_supported_only_in_model = {
     "Add",
     "Cast",
+    "Celu",
     "Concat",
     "ConstantOfShape",
     "DequantizeLinear",
@@ -46,6 +48,7 @@ std::set<std::string> ops_supported_only_in_model = {
     "EyeLike",
     "GatherElements",
     "GatherND",
+    "GridSample",
     "Identity",
     "LayerNormalization",
     "Loop",
@@ -72,293 +75,171 @@ std::set<std::string> ops_supported_only_in_model = {
 std::set<std::string> ops_supported_as_function = {
     "LessOrEqual",
     "GreaterOrEqual",
-    "LayerNormalization"};
+    "LayerNormalization",
+    "Celu"};
 
 std::vector<SupportedOp> supported_op_mode = {
     {"Abs", V_2020_4, {"CPU", "GPU"}},
-    {"Abs", V_2023_0, {"NPU"}},
     {"Acos", V_2020_4, {"CPU"}},
     {"Acos", V_2022_1, {"GPU"}},
-    {"Acos", V_2023_1, {"NPU"}},
     {"Acosh", V_2020_4, {"CPU"}},
     {"Acosh", V_2022_1, {"GPU"}},
-    {"Acosh", V_2023_1, {"NPU"}},
     {"Add", V_2020_4, {"CPU", "GPU"}},
-    {"Add", V_2023_0, {"NPU"}},
     {"And", V_2020_4, {"CPU", "GPU"}},
-    {"And", V_2023_1, {"NPU"}},
     {"ArgMax", V_2020_4, {"CPU"}},
     {"ArgMax", V_2021_1, {"GPU"}},
     {"ArgMin", V_2020_4, {"CPU"}},
     {"ArgMin", V_2022_1, {"GPU"}},
     {"Asin", V_2020_4, {"CPU", "GPU"}},
-    {"Asin", V_2023_1, {"NPU"}},
     {"Asinh", V_2020_4, {"CPU", "GPU"}},
-    {"Asinh", V_2023_1, {"NPU"}},
     {"Atan", V_2020_4, {"CPU", "GPU"}},
-    {"Atan", V_2023_1, {"NPU"}},
     {"Atanh", V_2020_4, {"CPU"}},
     {"Atanh", V_2022_1, {"GPU"}},
-    {"Atanh", V_2023_1, {"NPU"}},
     {"AveragePool", V_2020_4, {"CPU", "GPU"}},
-    {"AveragePool", V_2023_0, {"NPU"}},
     {"BatchNormalization", V_2020_4, {"CPU", "GPU"}},
-    {"BatchNormalization", V_2023_0, {"NPU"}},
     {"BitShift", V_2022_1, {"CPU"}},
-    {"BitShift", V_2023_1, {"NPU"}},
     {"Cast", V_2020_4, {"CPU", "GPU"}},
-    {"Cast", V_2023_0, {"NPU"}},
-    {"CastLike", V_2023_1, {"CPU", "GPU", "NPU"}},
+    {"CastLike", V_2023_1, {"CPU", "GPU"}},
     {"Ceil", V_2020_4, {"GPU"}},
     {"Ceil", V_2021_4, {"CPU"}},
-    {"Ceil", V_2023_1, {"NPU"}},
     {"Celu", V_2022_1, {"CPU", "GPU"}},
     {"Clip", V_2020_4, {"CPU", "GPU"}},
-    {"Clip", V_2023_0, {"NPU"}},
     {"Compress", V_2023_1, {"CPU", "GPU"}},
     {"Concat", V_2020_4, {"CPU", "GPU"}},
-    {"Concat", V_2023_0, {"NPU"}},
     {"Constant", V_2020_4, {"CPU", "GPU"}},
-    {"Constant", V_2023_0, {"NPU"}},
     {"ConstantOfShape", V_2020_4, {"CPU", "GPU"}},
-    {"ConstantOfShape", V_2023_0, {"NPU"}},  // Gets mapped to broadcast op in the plugin.
     {"Conv", V_2020_4, {"CPU", "GPU"}},
-    {"Conv", V_2023_0, {"NPU"}},
     {"ConvInteger", V_2022_1, {"CPU", "GPU"}},
-    {"ConvInteger", V_2023_1, {"NPU"}},
     {"ConvTranspose", V_2020_4, {"CPU", "GPU"}},
-    {"ConvTranspose", V_2023_1, {"NPU"}},
     {"Cos", V_2020_4, {"CPU"}},
     {"Cos", V_2022_1, {"GPU"}},
-    {"Cos", V_2023_0, {"NPU"}},
     {"Cosh", V_2020_4, {"CPU"}},
     {"Cosh", V_2022_1, {"GPU"}},
-    {"Cosh", V_2023_1, {"NPU"}},
     {"CumSum", V_2022_1, {"CPU", "GPU"}},
-    {"CumSum", V_2023_0, {"NPU"}},
     {"DepthToSpace", V_2020_4, {"CPU", "GPU"}},
-    {"DepthToSpace", V_2023_0, {"NPU"}},
     {"DequantizeLinear", V_2021_4, {"CPU", "GPU"}},
-    {"DequantizeLinear", V_2023_0, {"NPU"}},
     {"Div", V_2020_4, {"CPU", "GPU"}},
-    {"Div", V_2023_0, {"NPU"}},
     {"Dropout", V_2020_4, {"CPU", "GPU"}},
-    {"Dropout", V_2023_0, {"NPU"}},
     {"Elu", V_2020_4, {"CPU", "GPU"}},
-    {"Elu", V_2023_0, {"NPU"}},
     {"Einsum", V_2023_1, {"CPU", "GPU"}},
     {"Equal", V_2020_4, {"CPU", "GPU"}},
-    {"Equal", V_2023_0, {"NPU"}},  // Added for whisper decoder model.
     {"Erf", V_2020_4, {"CPU", "GPU"}},
-    {"Erf", V_2023_0, {"NPU"}},
     {"Exp", V_2020_4, {"CPU", "GPU"}},
-    {"Exp", V_2023_0, {"NPU"}},
     {"Expand", V_2022_1, {"CPU", "GPU"}},
-    {"Expand", V_2023_0, {"NPU"}},  // Gets mapped to broadcast op and multiply op in the plugin.
     {"EyeLike", V_2022_1, {"CPU"}},
-    {"EyeLike", V_2023_0, {"NPU"}},  // NoOP
     {"Flatten", V_2020_4, {"CPU", "GPU"}},
-    {"Flatten", V_2023_0, {"NPU"}},
     {"Floor", V_2020_4, {"CPU", "GPU"}},
-    {"Floor", V_2023_1, {"NPU"}},
     {"Gather", V_2020_4, {"CPU", "GPU"}},
-    {"Gather", V_2023_0, {"NPU"}},
     {"GatherElements", V_2022_2, {"CPU", "GPU"}},
-    {"GatherElements", V_2023_1, {"NPU"}},
     {"GatherND", V_2021_4, {"CPU", "GPU"}},
-    {"GatherND", V_2023_1, {"NPU"}},
+    {"Gelu", V_2023_1, {"CPU", "GPU"}},
     {"Gemm", V_2020_4, {"CPU", "GPU"}},
-    {"Gemm", V_2023_0, {"NPU"}},
     {"GlobalAveragePool", V_2020_4, {"CPU", "GPU"}},
-    {"GlobalAveragePool", V_2023_0, {"NPU"}},
     {"GlobalLpPool", V_2020_4, {"CPU", "GPU"}},
-    {"GlobalLpPool", V_2023_1, {"NPU"}},
     {"GlobalMaxPool", V_2022_1, {"CPU", "GPU"}},
-    {"GlobalMaxPool", V_2023_1, {"NPU"}},
     {"Greater", V_2020_4, {"CPU", "GPU"}},
-    {"Greater", V_2023_0, {"NPU"}},
     {"GreaterOrEqual", V_2022_1, {"CPU", "GPU"}},
-    {"GreaterOrEqual", V_2023_0, {"NPU"}},
     {"GridSample", V_2022_3, {"CPU"}},
     {"GridSample", V_2023_0, {"GPU"}},
-    {"GridSample", V_2023_1, {"NPU"}},
-    {"HardMax", V_2023_1, {"CPU", "GPU", "NPU"}},
+    {"HardMax", V_2023_1, {"CPU", "GPU"}},
     {"Identity", V_2020_4, {"CPU", "GPU"}},
-    {"Identity", V_2023_0, {"NPU"}},  // NoOP
     {"If", V_2022_3, {"CPU", "GPU"}},
-    {"If", V_2023_1, {"NPU"}},
     {"ImageScaler", V_2022_1, {"CPU", "GPU"}},
-    {"ImageScaler", V_2023_0, {"NPU"}},
     {"InstanceNormalization", V_2020_4, {"CPU", "GPU"}},
-    {"InstanceNormalization", V_2023_0, {"NPU"}},
     {"HardSigmoid", V_2020_4, {"CPU", "GPU"}},
-    {"HardSigmoid", V_2023_1, {"NPU"}},
     {"HardMax", V_2022_1, {"CPU", "GPU"}},
+    {"LayerNormalization", V_2023_0, {"CPU", "GPU"}},
     {"LeakyRelu", V_2020_4, {"CPU", "GPU"}},
-    {"LeakyRelu", V_2023_0, {"NPU"}},
     {"Less", V_2020_4, {"CPU", "GPU"}},
-    {"Less", V_2023_0, {"NPU"}},  // Added for whisper decoder model.
     {"LessOrEqual", V_2022_1, {"CPU", "GPU"}},
-    {"LessOrEqual", V_2023_0, {"NPU"}},
     {"Log", V_2020_4, {"CPU", "GPU"}},
-    {"Log", V_2023_0, {"NPU"}},
     {"LogSoftMax", V_2022_1, {"CPU", "GPU"}},
     {"Loop", V_2021_4, {"CPU", "GPU"}},
-    {"LpNormalization", V_2023_1, {"CPU", "GPU", "NPU"}},
-    {"LpPool", V_2023_1, {"CPU", "GPU", "NPU"}},
+    {"LpNormalization", V_2023_1, {"CPU", "GPU"}},
     {"LRN", V_2020_4, {"CPU", "GPU"}},
-    {"LRN", V_2023_0, {"NPU"}},
     {"LSTM", V_2020_4, {"CPU", "GPU"}},
-    {"LSTM", V_2023_1, {"NPU"}},
     {"MatMul", V_2020_4, {"CPU", "GPU"}},
-    {"MatMul", V_2023_0, {"NPU"}},
     {"MatMulInteger", V_2022_1, {"CPU"}},
-    {"MatMulInteger", V_2023_1, {"NPU"}},
     {"Max", V_2020_4, {"CPU", "GPU"}},
-    {"Max", V_2023_0, {"NPU"}},
     {"MaxPool", V_2020_4, {"CPU", "GPU"}},
-    {"MaxPool", V_2023_0, {"NPU"}},
     {"Mean", V_2020_4, {"CPU", "GPU"}},
-    {"Mean", V_2023_0, {"NPU"}},
     {"MeanVarianceNormalization", V_2022_1, {"CPU", "GPU"}},
-    {"MeanVarianceNormalization", V_2023_1, {"NPU"}},
     {"Min", V_2020_4, {"CPU", "GPU"}},
-    {"Min", V_2023_0, {"NPU"}},
     {"Mod", V_2022_1, {"CPU", "GPU"}},
     {"Mul", V_2020_4, {"CPU", "GPU"}},
-    {"Mul", V_2023_0, {"NPU"}},
     {"Neg", V_2020_4, {"CPU", "GPU"}},
-    {"Neg", V_2023_0, {"NPU"}},
     {"NonMaxSuppression", V_2021_1, {"CPU", "GPU"}},
-    {"NonMaxSuppression", V_2023_1, {"NPU"}},
     {"NonZero", V_2021_1, {"CPU"}},
     {"NonZero", V_2023_0, {"GPU"}},
     {"Not", V_2021_1, {"CPU", "GPU"}},
     {"Not", V_2020_4, {"CPU", "GPU"}},
-    {"Not", V_2023_1, {"NPU"}},
     {"OneHot", V_2020_4, {"CPU", "GPU"}},
-    {"OneHot", V_2023_1, {"NPU"}},
     {"Or", V_2022_1, {"CPU", "GPU"}},
-    {"Or", V_2023_1, {"NPU"}},
     {"Pad", V_2020_4, {"CPU", "GPU"}},
-    {"Pad", V_2023_0, {"NPU"}},
     {"Pow", V_2020_4, {"CPU", "GPU"}},
-    {"Pow", V_2023_0, {"NPU"}},
     {"PRelu", V_2020_4, {"CPU", "GPU"}},
-    {"PRelu", V_2023_0, {"NPU"}},
     {"QLinearMatMul", V_2022_3, {"CPU"}},
-    // {"QLinearMatMul", V_2023_1, {"NPU"}},
     {"QuantizeLinear", V_2021_4, {"CPU", "GPU"}},
-    {"QuantizeLinear", V_2023_0, {"NPU"}},
     {"RNN", V_2023_1, {"CPU", "GPU"}},
     {"RandomNormalLike", V_2023_0, {"CPU", "GPU"}},
     {"RandomNormalLike", V_2023_0, {"CPU", "GPU"}},
-    {"RandomNormalLike", V_2023_1, {"NPU"}},
     {"RandomNormal", V_2023_0, {"CPU", "GPU"}},
-    {"RandomNormal", V_2023_1, {"NPU"}},
     {"Range", V_2022_1, {"CPU", "GPU"}},
-    {"Range", V_2023_0, {"NPU"}},
     {"Reciprocal", V_2020_4, {"CPU", "GPU"}},
-    {"Reciprocal", V_2023_0, {"NPU"}},
     {"ReduceL1", V_2022_1, {"CPU", "GPU"}},
-    {"ReduceL1", V_2023_1, {"NPU"}},
     {"ReduceL2", V_2022_1, {"CPU", "GPU"}},
-    {"ReduceL2", V_2023_1, {"NPU"}},
     {"ReduceLogSum", V_2020_4, {"CPU"}},
     {"ReduceLogSum", V_2022_1, {"CPU", "GPU"}},
-    {"ReduceLogSum", V_2023_1, {"NPU"}},
     {"ReduceLogSumExp", V_2022_1, {"CPU", "GPU"}},
-    {"ReduceLogSumExp", V_2023_1, {"NPU"}},
     {"ReduceMax", V_2020_4, {"CPU", "GPU"}},
-    {"ReduceMax", V_2023_1, {"NPU"}},
     {"ReduceMean", V_2020_4, {"CPU", "GPU"}},
-    {"ReduceMean", V_2023_0, {"NPU"}},
     {"ReduceMin", V_2020_4, {"CPU", "GPU"}},
-    {"ReduceMin", V_2023_1, {"NPU"}},
     {"ReduceProd", V_2020_4, {"CPU"}},
     {"ReduceProd", V_2022_1, {"GPU"}},
-    {"ReduceProd", V_2023_1, {"NPU"}},
     {"ReduceSum", V_2020_4, {"CPU", "GPU"}},
-    // {"ReduceSum", V_2023_1, {"NPU"}},
     {"ReduceSumSquare", V_2020_4, {"CPU"}},
     {"ReduceSumSquare", V_2022_1, {"CPU", "GPU"}},
-    {"ReduceSumSquare", V_2023_1, {"NPU"}},
     {"Relu", V_2020_4, {"CPU", "GPU"}},
-    {"Relu", V_2023_0, {"NPU"}},
     {"Resize", V_2020_4, {"CPU"}},
     {"Resize", V_2022_1, {"GPU"}},
-    {"Resize", V_2023_1, {"NPU"}},
     {"Reshape", V_2020_4, {"CPU", "GPU"}},
-    {"Reshape", V_2023_0, {"NPU"}},
     {"ReverseSequence", V_2022_1, {"CPU", "GPU"}},
     {"RoiAlign", V_2021_1, {"CPU", "GPU"}},
-    {"RoiAlign", V_2023_1, {"NPU"}},
     {"Round", V_2021_4, {"CPU", "GPU"}},
-    {"Round", V_2023_1, {"NPU"}},
     {"Scatter", V_2022_1, {"CPU", "GPU"}},
-    {"Scatter", V_2023_1, {"NPU"}},
     {"ScatterElements", V_2022_1, {"CPU", "GPU"}},
-    {"ScatterElements", V_2023_1, {"NPU"}},
     {"ScatterND", V_2022_1, {"CPU", "GPU"}},
-    {"ScatterND", V_2023_1, {"NPU"}},
     {"Selu", V_2020_4, {"CPU", "GPU"}},
-    {"Selu", V_2023_1, {"NPU"}},
     {"Shape", V_2020_4, {"CPU", "GPU"}},
-    {"Shape", V_2023_0, {"NPU"}},
     {"Shrink", V_2022_1, {"CPU", "GPU"}},
-    {"Shrink", V_2023_0, {"NPU"}},
     {"Sigmoid", V_2020_4, {"CPU", "GPU"}},
-    {"Sigmoid", V_2023_0, {"NPU"}},
     {"Sign", V_2020_4, {"CPU"}},
     {"Sign", V_2022_1, {"GPU"}},
-    {"Sign", V_2023_0, {"NPU"}},
     {"Sin", V_2022_1, {"CPU", "GPU"}},
-    {"Sin", V_2023_0, {"NPU"}},
     {"Sinh", V_2020_4, {"CPU"}},
-    {"Sinh", V_2023_1, {"NPU"}},
     {"Size", V_2022_1, {"CPU", "GPU"}},
-    {"Size", V_2023_1, {"NPU"}},
     {"Slice", V_2020_4, {"CPU", "GPU"}},
-    {"Slice", V_2023_0, {"NPU"}},
     {"Softmax", V_2020_4, {"CPU", "GPU"}},
-    {"Softmax", V_2023_0, {"NPU"}},
     {"Softplus", V_2022_1, {"CPU", "GPU"}},
-    {"Softplus", V_2023_0, {"NPU"}},
     {"Softsign", V_2022_1, {"CPU", "GPU"}},
     {"SpaceToDepth", V_2020_4, {"CPU", "GPU"}},
-    {"SpaceToDepth", V_2023_0, {"NPU"}},
     {"Split", V_2020_4, {"CPU", "GPU"}},
-    {"Split", V_2023_0, {"NPU"}},
     {"Sqrt", V_2020_4, {"CPU", "GPU"}},
-    {"Sqrt", V_2023_0, {"NPU"}},
     {"Squeeze", V_2020_4, {"CPU", "GPU"}},
-    {"Squeeze", V_2023_0, {"NPU"}},
     {"Softsign", V_2020_4, {"CPU"}},
     {"Sub", V_2020_4, {"CPU", "GPU"}},
-    {"Sub", V_2023_0, {"NPU"}},
     {"Sum", V_2020_4, {"CPU", "GPU"}},
-    {"Sum", V_2023_0, {"NPU"}},
     {"Tan", V_2020_4, {"CPU", "GPU"}},
-    {"Tan", V_2023_1, {"NPU"}},
     {"Tanh", V_2020_4, {"CPU", "GPU"}},
-    {"Tanh", V_2023_0, {"NPU"}},
     {"ThresholdedRelu", V_2022_1, {"CPU", "GPU"}},
-    {"ThresholdedRelu", V_2023_0, {"NPU"}},
     {"Tile", V_2021_3, {"CPU", "GPU"}},
-    {"Tile", V_2023_0, {"NPU"}},
     {"Transpose", V_2020_4, {"CPU", "GPU"}},
-    {"Transpose", V_2023_0, {"NPU"}},
     {"Trilu", V_2023_0, {"CPU", "GPU"}},
-    {"Trilu", V_2023_1, {"NPU"}},
     {"TopK", V_2020_4, {"CPU", "GPU"}},
-    {"TopK", V_2023_0, {"NPU"}},
     {"Upsample", V_2020_4, {"CPU", "GPU"}},
     {"Unsqueeze", V_2020_4, {"CPU", "GPU"}},
-    {"Unsqueeze", V_2023_0, {"NPU"}},
     {"Where", V_2022_1, {"CPU", "GPU"}},
-    {"Where", V_2023_0, {"NPU"}},  // Added for whisper decoder model.
     {"Xor", V_2022_1, {"CPU", "GPU"}},
-    {"Xor", V_2023_1, {"NPU"}},
 };
 
 void DataOps::populate_types_supported() {
@@ -370,6 +251,8 @@ void DataOps::populate_types_supported() {
       std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32));
   supported_types_initializer_.insert(
       std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64));
+  supported_types_initializer_.insert(
+      std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16));
   supported_types_initializer_.insert(
       std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16));
   supported_types_initializer_.insert(
@@ -387,6 +270,8 @@ void DataOps::populate_types_supported() {
       std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8));
   supported_types_npu_.insert(
       std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16));
+  supported_types_npu_.insert(
+      std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16));
   supported_types_npu_.insert(
       std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32));
   supported_types_npu_.insert(
@@ -402,6 +287,8 @@ void DataOps::populate_types_supported() {
       std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32));
   supported_types_cpu_.insert(
       std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16));
+  supported_types_cpu_.insert(
+      std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16));
   supported_types_cpu_.insert(
       std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8));
   supported_types_cpu_.insert(
@@ -437,13 +324,12 @@ void DataOps::populate_op_mode_supported() {
   no_dimension_supported_.push_back({"DequantizeLinear", V_2021_4, {"All"}});
   no_dimension_supported_.push_back({"Equal", V_2022_1, {"CPU"}});
   no_dimension_supported_.push_back({"Equal", V_2023_0, {"GPU"}});
+  no_dimension_supported_.push_back({"Expand", V_2023_3, {"CPU"}});
   no_dimension_supported_.push_back({"Floor", V_2020_4, {"All"}});
   no_dimension_supported_.push_back({"Gather", V_2020_4, {"All"}});
-  no_dimension_supported_.push_back({"Greater", V_2023_0, {"NPU"}});
   no_dimension_supported_.push_back({"Identity", V_2023_0, {"All"}});
   no_dimension_supported_.push_back({"Less", V_2022_1, {"CPU"}});
   no_dimension_supported_.push_back({"Loop", V_2021_4, {"All"}});
-  no_dimension_supported_.push_back({"Max", V_2023_0, {"NPU"}});
   no_dimension_supported_.push_back({"Min", V_2020_4, {"All"}});
   no_dimension_supported_.push_back({"Mul", V_2020_4, {"All"}});
   no_dimension_supported_.push_back({"Neg", V_2023_0, {"CPU", "GPU"}});
@@ -476,9 +362,8 @@ void DataOps::populate_op_mode_supported() {
   {
     UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3},
                              [this](const Node* node, const InitializedTensorSet&) {
-                               // Abs is not supproted with INT8 or INT32 as input data type on GPU and NPU
-                               if ((device_id_.find("GPU") != std::string::npos) ||
-                                   (device_id_.find("NPU") != std::string::npos)) {
+                               // Abs is not supported with INT8 or INT32 as input data type on GPU
+                               if ((device_id_.find("GPU") != std::string::npos)) {
                                  for (size_t i = 0; i < node->InputDefs().size(); i++) {
                                    if (node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() ==
                                            ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8 ||
@@ -706,7 +591,7 @@ void DataOps::populate_op_mode_supported() {
     op_list_.insert({"PRelu", obj});
   }
   {
-    UnsupportedOpMode obj = {{V_2023_0, V_2023_1, V_2023_2, V_2023_3},
+    UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0},
                              [this](const Node* node, const InitializedTensorSet&) {
                                const auto& input_arg = node->InputDefs()[1];
                                auto shape = input_arg->Shape();
@@ -821,7 +706,7 @@ void DataOps::populate_op_mode_supported() {
     op_list_.insert({"Squeeze", obj});
   }
   {
-    UnsupportedOpMode obj = {{V_2023_0, V_2023_1, V_2023_2, V_2023_3},
+    UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0},
                              [this](const Node* node, const InitializedTensorSet&) {
                                // If the operator is unsqueeze
                                // If axes is an input, then we cannot produce a static graph.
@@ -836,7 +721,7 @@ void DataOps::populate_op_mode_supported() {
     op_list_.insert({"Unsqueeze", obj});
   }
   {
-    UnsupportedOpMode obj = {{V_2023_0, V_2023_1, V_2023_2, V_2023_3},
+    UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0},
                              [this](const Node* node, const InitializedTensorSet&) {
                                // check for attributes
                                auto& upsample_attr = node->GetAttributes();
@@ -961,7 +846,7 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) {
   } else {
     auto dtype = type_proto->tensor_type().elem_type();
 
-    if (device_id_.find("NPU") != std::string::npos || device_id_.find("HETERO") != std::string::npos ||
+    if (device_id_.find("HETERO") != std::string::npos ||
         device_id_.find("MULTI") != std::string::npos || device_id_.find("AUTO") != std::string::npos) {
       for (auto const& var : supported_types_npu_) {
         if ((var.first <= version_id_) &&
@@ -1063,8 +948,7 @@ bool DataOps::dimension_unsupported(const Node* node) {
   return true;
 }
 
-bool DataOps::node_is_supported(const std::map<std::string, std::set<std::string>>& op_map,
-                                const NodeIndex node_idx) {
+bool DataOps::node_is_supported(const NodeIndex node_idx) {
   const auto& node = graph_viewer_.GetNode(node_idx);
   const auto& optype = node->OpType();
 
@@ -1174,37 +1058,14 @@ bool DataOps::node_is_supported(const std::map<std::string, std::set<std::string
     return false;
   }
 
-  // Check 3b
-  const auto opset = op_map.find(domain);
-  const auto op_fun = ops_supported_as_function.find(node->OpType());
-  if (opset == op_map.end()) {
-#ifndef NDEBUG
-    if (openvino_ep::backend_utils::IsDebugEnabled()) {
-      std::cout << "Failed in Unsupported onnx model domain" << std::endl;
-    }
-#endif
-    return false;
-  }
-  if (opset->second.find(optype) == opset->second.end() && op_fun == ops_supported_as_function.end()) {
-#ifndef NDEBUG
-    if (openvino_ep::backend_utils::IsDebugEnabled()) {
-      std::cout << "The operator is not available in OpenVINO ngraph operators list"
-                << "nor the operator is a special ONNX function"
-                << std::endl;
-    }
-#endif
-    return false;
-  }
   return true;
 }
 
 std::vector<NodeIndex> DataOps::GetUnsupportedNodeIndices(std::unordered_set<std::string>& ng_required_initializers) {
-  const auto ng_supported_ops = GetNgSupportedOps(GetOnnxOpSet(graph_viewer_));
-
   std::vector<NodeIndex> unsupported_nodes_idx;
 
   for (const auto& node_idx : graph_viewer_.GetNodesInTopologicalOrder()) {
-    if (node_is_supported(ng_supported_ops, node_idx)) {
+    if (node_is_supported(node_idx)) {
       // Collect inputs that are initializers
       graph_viewer_.GetNode(node_idx)->ForEachDef([&ng_required_initializers, this](const NodeArg& node_arg,
                                                                                     bool is_input) {
diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h
index 87688601ad692..0990904908111 100644
--- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h
+++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h
@@ -1,4 +1,4 @@
-// Copyright (C) 2019-2022 Intel Corporation
+// Copyright (C) Intel Corporation
 // Licensed under the MIT License
 
 #pragma once
@@ -26,7 +26,8 @@ enum versionNum {
   V_2023_0,
   V_2023_1,
   V_2023_2,
-  V_2023_3
+  V_2023_3,
+  V_2024_0
 };
 
 using VersionNum = enum versionNum;
@@ -67,9 +68,7 @@ class DataOps {
   bool dimension_unsupported(const Node* node);
   bool unsupported_op_mode(const Node* node);
   bool type_is_supported(const NodeArg* node_arg, bool is_initializer);
-  bool node_is_supported(const std::map<std::string,
-                                        std::set<std::string>>& op_map,
-                         const NodeIndex node_idx);
+  bool node_is_supported(const NodeIndex node_idx);
 
  public:
   DataOps(const GraphViewer& graph_viewer_param, VersionNum ver, const std::string dev_id, const std::string device_precision)
diff --git a/onnxruntime/core/providers/openvino/ov_versions/utils.cc b/onnxruntime/core/providers/openvino/ov_versions/utils.cc
index ee0bfddb7dc83..c5ed29df487b4 100644
--- a/onnxruntime/core/providers/openvino/ov_versions/utils.cc
+++ b/onnxruntime/core/providers/openvino/ov_versions/utils.cc
@@ -1,4 +1,4 @@
-// Copyright (C) 2019-2022 Intel Corporation
+// Copyright (C) Intel Corporation
 // Licensed under the MIT License
 
 #include "core/providers/shared_library/provider_api.h"
@@ -11,14 +11,6 @@
 #pragma GCC diagnostic ignored "-Wunused-parameter"
 #endif
 
-#include "openvino/core/deprecated.hpp"
-#define IN_OV_COMPONENT
-#define NGRAPH_LEGACY_HEADER_INCLUDED
-#include <ngraph/frontend/onnx_import/onnx.hpp>
-
-#undef NGRAPH_LEGACY_HEADER_INCLUDED
-#undef IN_OV_COMPONENT
-
 #if defined(_MSC_VER)
 #pragma warning(default : 4244 4245)
 #elif __GNUC__
@@ -95,20 +87,6 @@ int GetOnnxOpSet(const GraphViewer& graph_viewer) {
   return dm_to_ver.at(kOnnxDomain);
 }
 
-std::map<std::string, std::set<std::string>> GetNgSupportedOps(const int onnx_opset) {
-  std::map<std::string, std::set<std::string>> ng_supported_ops;
-  OPENVINO_SUPPRESS_DEPRECATED_START
-  ng_supported_ops.emplace(kOnnxDomain, ngraph::onnx_import::get_supported_operators(onnx_opset, kOnnxDomain));
-
-  const std::set<std::string> ng_disabled_ops = {"LSTM"};  // Place-holder for ops not supported.
-
-  for (const auto& disabled_op : ng_disabled_ops) {
-    ng_supported_ops.at(kOnnxDomain).erase(disabled_op);
-  }
-  OPENVINO_SUPPRESS_DEPRECATED_END
-  return ng_supported_ops;
-}
-
 /**
  * Returns a vector clusters(or node_idx). For each unsupported node, the graph is split into 3 parts.
  * supported_cluster + (UNsupported_node + rest_of_the_graph). This functions returns vector of all supported_clusters by nGraph
diff --git a/onnxruntime/core/providers/openvino/ov_versions/utils.h b/onnxruntime/core/providers/openvino/ov_versions/utils.h
index b3edeef88dfec..34aa762ba9b67 100644
--- a/onnxruntime/core/providers/openvino/ov_versions/utils.h
+++ b/onnxruntime/core/providers/openvino/ov_versions/utils.h
@@ -1,4 +1,4 @@
-// Copyright (C) 2019-2022 Intel Corporation
+// Copyright (C) Intel Corporation
 // Licensed under the MIT License
 #pragma once
 
diff --git a/onnxruntime/core/providers/partitioning_utils.cc b/onnxruntime/core/providers/partitioning_utils.cc
index d537a4cf58b2d..c45f5cd0848dd 100644
--- a/onnxruntime/core/providers/partitioning_utils.cc
+++ b/onnxruntime/core/providers/partitioning_utils.cc
@@ -1,6 +1,9 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+// QDQ models require graph modification at runtime, so we know this infrastructure is not used in a minimal build
+#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
+
 #include "core/providers/partitioning_utils.h"
 
 #include <algorithm>
@@ -10,6 +13,7 @@
 
 #include "core/framework/compute_capability.h"
 #include "core/framework/execution_provider.h"
+#include "core/framework/node_unit.h"
 #include "core/graph/graph_viewer.h"
 #include "core/providers/common.h"
 
@@ -76,6 +80,11 @@ When selecting the next node to process, we first take:
 The remaining unsupported nodes mark the border of the current group so they will be processed later when we consider
 the next group.
 
+If node_unit_map is provided, we process NodeUnit instances (a logical 'Node' that can be a single node or a
+QDQ node group) instead of individual Node instances. As an EP must take complete NodeUnit instances (i.e. it
+must not break up a QDQ node group by taking a subset of nodes in it), this granularity of processing is valid.
+It is required to ensure we do not break up a QDQ node unit during partitioning.
+
 @param graph_viewer GraphViewer that IExecutionProvider::GetCapability is called with.
 @param is_node_supported_fn Callback to check whether a node is supported.
 @param on_group_closed_fn Callback to indicate a completed partition node group.
@@ -88,6 +97,7 @@ std::vector<std::vector<const Node*>> CreateSupportedPartitionNodeGroups(
     const IsNodeSupportedFn& is_node_supported_fn,
     const OnGroupClosedFn& on_group_closed_fn,
     const std::string& execution_provider_type,
+    const std::unordered_map<const Node*, const NodeUnit*>* node_unit_map,
     bool debug_output) {
 #ifdef NDEBUG
   ORT_UNUSED_PARAMETER(debug_output);
@@ -111,7 +121,18 @@ std::vector<std::vector<const Node*>> CreateSupportedPartitionNodeGroups(
   // initialize in-degrees and find root nodes
   for (const auto& node_index : graph_viewer.GetNodesInTopologicalOrder()) {
     const auto& node = *graph_viewer.GetNode(node_index);
-    const auto node_input_edge_count = node.GetInputEdgesCount();
+    auto node_input_edge_count = node.GetInputEdgesCount();
+
+    if (node_unit_map != nullptr) {
+      const auto& node_unit = node_unit_map->at(&node);
+      if (&node_unit->GetNode() != &node) {
+        // only process the target node
+        continue;
+      }
+
+      node_input_edge_count = node_unit->InputEdgeCount();
+    }
+
     in_degree.insert({node.Index(), node_input_edge_count});
     if (node_input_edge_count == 0) {
       nodes_to_process.push_back(&node);
@@ -151,6 +172,8 @@ std::vector<std::vector<const Node*>> CreateSupportedPartitionNodeGroups(
     }
   };
 
+  size_t num_nodes_processed = 0;
+
   while (!nodes_to_process.empty() || !nodes_to_process_with_next_group.empty()) {
     if (nodes_to_process.empty()) {
       // we have processed all the nodes that we can while building this partition node group, start a new one
@@ -162,9 +185,13 @@ std::vector<std::vector<const Node*>> CreateSupportedPartitionNodeGroups(
     const Node& node = *nodes_to_process.front();
     nodes_to_process.pop_front();
 
+    const NodeUnit* node_unit = node_unit_map ? node_unit_map->at(&node) : nullptr;
+    const bool is_qdq_node_unit = node_unit && node_unit->UnitType() == NodeUnit::Type::QDQGroup;
+
     // a node that is already assigned to an EP other than current EP is unsupported
-    const bool is_node_supported =
-        (node.GetExecutionProviderType().empty() || node.GetExecutionProviderType() == execution_provider_type) && is_node_supported_fn(node);
+    const bool is_node_supported = (node.GetExecutionProviderType().empty() ||
+                                    node.GetExecutionProviderType() == execution_provider_type) &&
+                                   is_node_supported_fn(node);
 
     if (!is_node_supported && Contains(supported_group_border, &node)) {
       // an unsupported node on the border will be processed after the current partition node group
@@ -173,34 +200,62 @@ std::vector<std::vector<const Node*>> CreateSupportedPartitionNodeGroups(
     }
 
     if (is_node_supported) {
-      // add node to the partition node group
-      supported_group.push_back(&node);
+      if (is_qdq_node_unit) {
+        // add DQ -> node -> Q for the node unit. must be in topological order
+        for (const auto& dq : node_unit->GetDQNodes()) {
+          supported_group.push_back(dq);
+        }
 
-      // remove node from the border and add its outputs to the border
+        supported_group.push_back(&node);
+
+        for (const auto& q : node_unit->GetQNodes()) {
+          supported_group.push_back(q);
+        }
+      } else {
+        supported_group.push_back(&node);
+      }
+
+      // remove node from the border
       supported_group_border.erase(&node);
+    }
 
-      std::for_each(
-          node.OutputNodesBegin(), node.OutputNodesEnd(),
-          [&supported_group_border](const Node& output) {
-            supported_group_border.insert(&output);
-          });
+    // For each downstream node:
+    //   1: add the downstream node to the border if the current node is supported
+    //   2: adjust in-degrees of the nodes consuming the current node's outputs, and add any new nodes to process
+    const auto process_downstream_node = [&](const Node& downstream_node) {
+      if (is_node_supported) {
+        supported_group_border.insert(&downstream_node);
+      }
+
+      auto& downstream_node_in_degree = in_degree[downstream_node.Index()];
+      --downstream_node_in_degree;
+
+      if (downstream_node_in_degree == 0) {
+        nodes_to_process.push_back(&downstream_node);
+      }
+    };
+
+    if (node_unit_map) {
+      std::for_each(node_unit->OutputEdgesBegin(), node_unit->OutputEdgesEnd(),
+                    [&](const Node::EdgeEnd& edge_end) {
+                      const Node& n = edge_end.GetNode();
+                      const NodeUnit& downstream_node_unit = *node_unit_map->at(&n);
+                      const Node& output = downstream_node_unit.GetNode();
+
+                      process_downstream_node(output);
+                    });
+    } else {
+      std::for_each(node.OutputNodesBegin(), node.OutputNodesEnd(), process_downstream_node);
     }
 
-    // adjust in-degrees of the node outputs and add any new nodes to process
-    std::for_each(
-        node.OutputNodesBegin(), node.OutputNodesEnd(),
-        [&](const Node& output) {
-          auto& output_node_in_degree = in_degree[output.Index()];
-          --output_node_in_degree;
-
-          if (output_node_in_degree == 0) {
-            nodes_to_process.push_back(&output);
-          }
-        });
+    ++num_nodes_processed;
   }
 
   close_group();
 
+  ORT_ENFORCE(num_nodes_processed == in_degree.size(),
+              "Processed ", num_nodes_processed, " nodes. Expected to process ", in_degree.size());
+
   return supported_groups;
 }
 }  // namespace
@@ -318,11 +373,13 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer,
                           const GenerateMetadefNameFn& generate_metadef_name_fn,
                           const std::string& execution_provider_name,
                           const std::string& execution_provider_type,
+                          const std::unordered_map<const Node*, const NodeUnit*>* node_unit_map,
                           bool debug_output) {
   const auto groups = CreateSupportedPartitionNodeGroups(graph_viewer,
                                                          is_node_supported_fn,
                                                          on_partition_closed_fn,
                                                          execution_provider_type,
+                                                         node_unit_map,
                                                          debug_output);
 
   std::vector<std::unique_ptr<ComputeCapability>> partitions{};
@@ -346,6 +403,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer,
                           const GenerateMetadefNameFn& generate_metadef_name_fn,
                           const std::string& execution_provider_name,
                           const std::string& execution_provider_type,
+                          const std::unordered_map<const Node*, const NodeUnit*>* node_unit_map,
                           bool debug_output) {
   const auto excluded_nodes = CreateExcludedNodeSet(graph_viewer, stop_ops);
   const bool check_excluded_nodes = !excluded_nodes.empty();
@@ -360,8 +418,11 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer,
       generate_metadef_name_fn,
       execution_provider_name,
       execution_provider_type,
+      node_unit_map,
       debug_output);
 }
 
 }  // namespace utils
 }  // namespace onnxruntime
+
+#endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
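
The partitioning change above turns the node walk into a NodeUnit-aware, in-degree driven (Kahn-style) traversal and adds a completeness check. A minimal, self-contained sketch of that traversal pattern, with plain integer ids standing in for the Node/NodeUnit pointers (illustrative only, not the ORT implementation):

#include <cassert>
#include <cstddef>
#include <deque>
#include <vector>

int main() {
  // adjacency[i] lists the downstream nodes of node i (a small DAG).
  const std::vector<std::vector<int>> adjacency = {{1, 2}, {3}, {3}, {}};

  std::vector<std::size_t> in_degree(adjacency.size(), 0);
  for (const auto& outputs : adjacency)
    for (int dst : outputs) ++in_degree[dst];

  std::deque<int> nodes_to_process;
  for (std::size_t i = 0; i < adjacency.size(); ++i)
    if (in_degree[i] == 0) nodes_to_process.push_back(static_cast<int>(i));

  std::size_t num_nodes_processed = 0;
  while (!nodes_to_process.empty()) {
    const int node = nodes_to_process.front();
    nodes_to_process.pop_front();

    // Adjust in-degrees of downstream nodes and queue any that become ready.
    for (int downstream : adjacency[node])
      if (--in_degree[downstream] == 0) nodes_to_process.push_back(downstream);

    ++num_nodes_processed;
  }

  // Mirrors the new ORT_ENFORCE: every node must be visited exactly once,
  // otherwise the in-degree bookkeeping (e.g. the NodeUnit redirection) is off.
  assert(num_nodes_processed == adjacency.size());
  return 0;
}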
diff --git a/onnxruntime/core/providers/partitioning_utils.h b/onnxruntime/core/providers/partitioning_utils.h
index 136725c2f7250..c3f6b104e3f6a 100644
--- a/onnxruntime/core/providers/partitioning_utils.h
+++ b/onnxruntime/core/providers/partitioning_utils.h
@@ -3,6 +3,9 @@
 
 #pragma once
 
+// QDQ models require graph modification at runtime, so we know this infrastructure is not used in a minimal build
+#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
+
 #include <functional>
 #include <memory>
 #include <unordered_set>
@@ -14,8 +17,9 @@
 namespace onnxruntime {
 struct ComputeCapability;
 class GraphViewer;
-class NodeArg;
 class Node;
+class NodeArg;
+class NodeUnit;
 
 namespace utils {
 
@@ -56,6 +60,8 @@ Create the supported partitions for the execution provider.
 @param generate_metadef_name_fn Callback to create the name for the MetaDef.
 @param execution_provider_name Name of execution provider creating the ComputeCapability instance.
 @param execution_provider_type ExecutionProviderType of the EP creating this ComputeCapability instance.
+@param node_unit_map Map of each Node in the graph_viewer to its NodeUnit. Provide if EP handles QDQ format models.
+                     Should be created by EP calling GetAllNodeUnits.
 @param debug_output Print diagnostic output about the partitions and reasons for partition breaks.
                     No-op in a release build.
 
@@ -68,6 +74,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer,
                           const GenerateMetadefNameFn& generate_metadef_name_fn,
                           const std::string& execution_provider_name,
                           const std::string& execution_provider_type,
+                          const std::unordered_map<const Node*, const NodeUnit*>* node_unit_map = nullptr,
                           bool debug_output = false);
 
 /**
@@ -79,6 +86,8 @@ Create the supported partitions for the execution provider.
 @param generate_metadef_name Functor to create the name for the MetaDef.
 @param execution_provider_name Name of execution provider creating the ComputeCapability instance.
 @param execution_provider_type ExecutionProviderType of the EP creating this ComputeCapability instance.
+@param node_unit_map Map of each Node in the graph_viewer to its NodeUnit. Provide if EP handles QDQ format models.
+                     Should be created by EP calling GetAllNodeUnits.
 @param debug_output Print diagnostic output about the partitions and reasons for partition breaks.
                     No-op in a release build.
 
@@ -91,6 +100,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer,
                           const GenerateMetadefNameFn& generate_metadef_name,
                           const std::string& execution_provider_name,
                           const std::string& execution_provider_type,
+                          const std::unordered_map<const Node*, const NodeUnit*>* node_unit_map = nullptr,
                           bool debug_output = false);
 
 /**
@@ -125,3 +135,5 @@ InlinedHashSet<const Node*> CreateExcludedNodeSet(const GraphViewer& graph_viewe
                                                   const std::unordered_set<std::string>& stop_ops);
 }  // namespace utils
 }  // namespace onnxruntime
+
+#endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
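
The new node_unit_map parameter lets the partitioner treat a QDQ group as a single unit: when a downstream node belongs to a NodeUnit, the traversal resolves it to the unit's target node instead. A small stand-alone sketch of that redirection, with integer ids as stand-ins for the Node/NodeUnit pointers (an assumption for illustration, not the real types):

#include <iostream>
#include <unordered_map>
#include <vector>

int main() {
  // node -> target node of its unit. Nodes 1 and 2 form one QDQ-style unit
  // whose target is node 2; nodes 0 and 3 are their own single-node units.
  const std::unordered_map<int, int> node_to_unit_target = {{0, 0}, {1, 2}, {2, 2}, {3, 3}};
  const std::vector<std::vector<int>> adjacency = {{1}, {2}, {3}, {}};

  // Resolve the downstream nodes of node 0 at unit granularity.
  for (int downstream : adjacency[0]) {
    const int unit_target = node_to_unit_target.at(downstream);
    std::cout << "downstream node " << downstream
              << " is processed via its unit's target node " << unit_target << "\n";
  }
  return 0;
}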
diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
index c2e71081b898e..2d8ec295d613b 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
@@ -151,12 +151,14 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
 Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
                                const onnxruntime::PathString& ctx_onnx_model_path,
                                QnnBackendManager* qnn_backend_manager,
-                               std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models) {
+                               std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models,
+                               const logging::Logger& logger) {
   Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager, qnn_models);
 
   // This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model
   if (!status.IsOK()) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContextModel. ", status.ErrorMessage());
+    LOGS(logger, ERROR) << "Failed to load from EpContext model. " << status.ErrorMessage();
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContext model. ", status.ErrorMessage());
   }
 
   return Status::OK();
diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
index b1360b4e576fa..7d56b45a1dbcd 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
@@ -56,7 +56,8 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
 Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
                                const onnxruntime::PathString& ctx_onnx_model_path,
                                QnnBackendManager* qnn_backend_manager,
-                               std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models);
+                               std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models,
+                               const logging::Logger& logger);
 
 Status CreateEPContextNodes(Model* model,
                             unsigned char* buffer,
diff --git a/onnxruntime/core/providers/qnn/builder/op_builder.h b/onnxruntime/core/providers/qnn/builder/op_builder.h
index 018d9a2797a66..05398c3f22ea2 100644
--- a/onnxruntime/core/providers/qnn/builder/op_builder.h
+++ b/onnxruntime/core/providers/qnn/builder/op_builder.h
@@ -4,7 +4,7 @@
 #pragma once
 
 #include "core/graph/graph_viewer.h"
-#include "core/providers/shared/node_unit/node_unit.h"
+#include "core/framework/node_unit.h"
 #include "core/providers/shared/utils/utils.h"
 
 namespace onnxruntime {
diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h
index d95e2baa9457f..4a9106f0c06af 100644
--- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h
+++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h
@@ -94,5 +94,28 @@ void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_r
 
 void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
 
+struct HandleConvertResult {
+  Status status;                // Indicates an unexpected error. Check if q_node_unit != nullptr to determine
+                                // whether a DQ -> Q sequence was successfully merged into a Convert.
+  const NodeUnit* q_node_unit;  // Non-null if successfully merged DQ -> Q sequence.
+                                // Set to nullptr if this node unit could not be merged into a Convert.
+};
+
+/**
+ * Tries to merge a DQ -> Q sequence into a QNN Convert operator. The DQ -> Q must be converting from
+ * one quantization type (e.g., uint8_t) to another (e.g., uint16_t).
+ *
+ * \param qnn_model_wrapper The QNN model that is being built.
+ * \param maybe_dq_node_unit The node unit that could potentially start the DQ -> Q sequence.
+ * \param logger The logger.
+ * \param do_op_validation True if should call QNN operator validation APIs.
+ * \return A qnn::HandleConvertResult object that indicates success/failure and provides a pointer
+ *         to the Q node unit that was successfully merged with the provided DQ node unit.
+ */
+HandleConvertResult TryHandleConvertSequence(QnnModelWrapper& qnn_model_wrapper,
+                                             const NodeUnit& maybe_dq_node_unit,
+                                             const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map,
+                                             const logging::Logger& logger,
+                                             bool do_op_validation);
 }  // namespace qnn
 }  // namespace onnxruntime
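
The HandleConvertResult contract separates hard failures from a simple non-match: a failed status is an error, while an OK status with a null q_node_unit just means "no fusable DQ -> Q pair here, use the normal path". A toy sketch of how a caller is expected to interpret it, using simplified stand-in types with the same names (assumption: illustrative signatures only):

#include <iostream>
#include <string>

struct Status {
  bool ok = true;
  std::string message;
};

struct NodeUnit {
  std::string name;
};

struct HandleConvertResult {
  Status status;
  const NodeUnit* q_node_unit = nullptr;  // non-null only when a DQ -> Q pair was merged
};

// Toy stand-in: returns an OK/empty result for "not a match".
HandleConvertResult TryHandleConvertSequence(bool looks_like_dq_q, const NodeUnit* q_unit) {
  if (!looks_like_dq_q) return {};  // OK status, nullptr: fall back to the regular builder
  return {Status{}, q_unit};        // OK status, merged pair
}

int main() {
  const NodeUnit q{"Q_node_unit"};
  const auto result = TryHandleConvertSequence(true, &q);

  if (!result.status.ok) {
    std::cerr << "unexpected error: " << result.status.message << "\n";
    return 1;
  }
  if (result.q_node_unit != nullptr) {
    std::cout << "merged DQ -> Q into Convert; also mark " << result.q_node_unit->name << " as handled\n";
  } else {
    std::cout << "not a convert sequence; use the regular op builder\n";
  }
  return 0;
}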
diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc
index 0a9f9889ad2d8..dc99687e78d30 100644
--- a/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc
+++ b/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc
@@ -36,6 +36,27 @@ class ClipOpBuilder : public BaseOpBuilder {
   Status ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const;
 };
 
+static Status ProcessClipMinMax(QnnModelWrapper& qnn_model_wrapper,
+                                const NodeUnitIODef& input,
+                                float& float_value) {
+  TensorInfo input_info = {};
+  std::vector<uint8_t> val_bytes;
+  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(input, input_info));
+  assert(input_info.is_initializer);  // Checked by ExplicitOpCheck().
+  if (QNN_DATATYPE_FLOAT_16 == input_info.qnn_data_type) {
+    ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_info.initializer_tensor, val_bytes));
+    MLFloat16 fp16_value = *reinterpret_cast<const MLFloat16*>(val_bytes.data());
+    float_value = fp16_value.ToFloat();
+  } else {
+    ORT_RETURN_IF_NOT(QNN_DATATYPE_FLOAT_32 == input_info.qnn_data_type,
+                      "QNN EP: The min/max input of the Clip operator must be of type float32 or float16.");
+    ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_info.initializer_tensor, val_bytes));
+    float_value = *reinterpret_cast<const float*>(val_bytes.data());
+  }
+
+  return Status::OK();
+}
+
 Status ClipOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const {
   if (node_unit.Inputs().size() > 1) {
     const auto& min_input_name = node_unit.Inputs()[1].node_arg.Name();
@@ -75,54 +96,36 @@ Status ClipOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
   const Qnn_DataType_t qnn_data_type = QNN_DATATYPE_FLOAT_32;
   std::vector<std::string> param_tensor_names;
 
-  auto get_f32_from_bytes = [](const std::vector<uint8_t>& bytes, float default_val) -> float {
-    return bytes.empty() ? default_val : *reinterpret_cast<const float*>(bytes.data());
-  };
-
   // Set the 'min' parameter.
-  {
-    std::vector<uint8_t> min_val_bytes;
-
-    if (num_inputs > 1 && !inputs[1].node_arg.Name().empty()) {
-      TensorInfo min_input_info = {};
-      ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[1], min_input_info));
-      ORT_RETURN_IF_NOT(min_input_info.qnn_data_type == qnn_data_type,
-                        "QNN EP: The 'min' input of the Clip operator must be of type float32.");
-      assert(min_input_info.is_initializer);  // Checked by ExplicitOpCheck().
-      ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*min_input_info.initializer_tensor, min_val_bytes));
-    }
+  Qnn_Scalar_t min_qnn_scalar = QNN_SCALAR_INIT;
+  min_qnn_scalar.dataType = qnn_data_type;
 
-    Qnn_Scalar_t min_qnn_scalar = QNN_SCALAR_INIT;
-    min_qnn_scalar.dataType = qnn_data_type;
-    min_qnn_scalar.floatValue = get_f32_from_bytes(min_val_bytes, std::numeric_limits<float>::lowest());
-    QnnParamWrapper min_value_param(node_unit.Index(), node_unit.Name(), QNN_OP_RELU_MIN_MAX_PARAM_MIN_VALUE,
-                                    min_qnn_scalar);
-    param_tensor_names.push_back(min_value_param.GetParamTensorName());
-    qnn_model_wrapper.AddParamWrapper(std::move(min_value_param));
+  if (num_inputs > 1 && !inputs[1].node_arg.Name().empty()) {
+    ORT_RETURN_IF_ERROR(ProcessClipMinMax(qnn_model_wrapper, inputs[1], min_qnn_scalar.floatValue));
+  } else {
+    min_qnn_scalar.floatValue = std::numeric_limits<float>::lowest();
   }
 
+  QnnParamWrapper min_value_param(node_unit.Index(), node_unit.Name(), QNN_OP_RELU_MIN_MAX_PARAM_MIN_VALUE,
+                                  min_qnn_scalar);
+  param_tensor_names.push_back(min_value_param.GetParamTensorName());
+  qnn_model_wrapper.AddParamWrapper(std::move(min_value_param));
+
   // Set the 'max' parameter.
-  {
-    std::vector<uint8_t> max_val_bytes;
-
-    if (num_inputs > 2 && !inputs[2].node_arg.Name().empty()) {
-      TensorInfo max_input_info = {};
-      ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[2], max_input_info));
-      ORT_RETURN_IF_NOT(max_input_info.qnn_data_type == qnn_data_type,
-                        "QNN EP: The 'max' input of the Clip operator must of type float32.");
-      assert(max_input_info.is_initializer);  // Checked by ExplicitOpCheck().
-      ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*max_input_info.initializer_tensor, max_val_bytes));
-    }
+  Qnn_Scalar_t max_qnn_scalar = QNN_SCALAR_INIT;
+  max_qnn_scalar.dataType = qnn_data_type;
 
-    Qnn_Scalar_t max_qnn_scalar = QNN_SCALAR_INIT;
-    max_qnn_scalar.dataType = qnn_data_type;
-    max_qnn_scalar.floatValue = get_f32_from_bytes(max_val_bytes, std::numeric_limits<float>::max());
-    QnnParamWrapper max_value_param(node_unit.Index(), node_unit.Name(), QNN_OP_RELU_MIN_MAX_PARAM_MAX_VALUE,
-                                    max_qnn_scalar);
-    param_tensor_names.push_back(max_value_param.GetParamTensorName());
-    qnn_model_wrapper.AddParamWrapper(std::move(max_value_param));
+  if (num_inputs > 2 && !inputs[2].node_arg.Name().empty()) {
+    ORT_RETURN_IF_ERROR(ProcessClipMinMax(qnn_model_wrapper, inputs[2], max_qnn_scalar.floatValue));
+  } else {
+    max_qnn_scalar.floatValue = std::numeric_limits<float>::max();
   }
 
+  QnnParamWrapper max_value_param(node_unit.Index(), node_unit.Name(), QNN_OP_RELU_MIN_MAX_PARAM_MAX_VALUE,
+                                  max_qnn_scalar);
+  param_tensor_names.push_back(max_value_param.GetParamTensorName());
+  qnn_model_wrapper.AddParamWrapper(std::move(max_value_param));
+
   ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit,
                                      std::move(input_names),
                                      std::move(param_tensor_names),
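
ProcessClipMinMax above folds the fp16 and fp32 'min'/'max' initializers into one helper: it unpacks the initializer's raw bytes and widens them to float. A tiny stand-alone sketch of the byte-decoding step (only the fp32 path, using std::memcpy; the fp16 branch in the real code goes through MLFloat16::ToFloat, and the hard-coded bytes here are an assumption for illustration):

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

int main() {
  // Pretend these bytes came from UnpackInitializerData for a float32 'min' input.
  const float stored_min = -6.0f;
  std::vector<uint8_t> val_bytes(sizeof(float));
  std::memcpy(val_bytes.data(), &stored_min, sizeof(float));

  // Decode: copy the raw bytes back into a float, the portable equivalent of
  // *reinterpret_cast<const float*>(val_bytes.data()) in the builder.
  float float_value = 0.0f;
  std::memcpy(&float_value, val_bytes.data(), sizeof(float));

  std::cout << "Clip min parameter: " << float_value << "\n";  // -6
  return 0;
}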
diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/convert_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/convert_op_builder.cc
new file mode 100644
index 0000000000000..977a9e0b3d9d0
--- /dev/null
+++ b/onnxruntime/core/providers/qnn/builder/opbuilder/convert_op_builder.cc
@@ -0,0 +1,103 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/graph/graph_utils.h"
+#include "core/optimizer/qdq_transformer/qdq_util.h"
+#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
+#include "core/providers/shared/utils/utils.h"
+#include "core/providers/qnn/builder/qnn_model_wrapper.h"
+#include "core/providers/qnn/builder/op_builder_factory.h"
+#include "core/common/safeint.h"
+#include "onnx/defs/data_type_utils.h"
+
+#include "QnnOpDef.h"  // From QNN SDK: contains QNN constants (e.g., op names, param values).
+
+namespace onnxruntime {
+namespace qnn {
+
+class ConvertOpBuilder : public BaseOpBuilder {
+ public:
+  ConvertOpBuilder() : BaseOpBuilder("ConvertOpBuilder") {}
+  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ConvertOpBuilder);
+
+  Status AddConvertToModelBuilder(QnnModelWrapper& qnn_model_wrapper,
+                                  const NodeUnit& dq_node_unit,
+                                  const NodeUnit& q_node_unit,
+                                  const logging::Logger& logger,
+                                  bool do_op_validation) const ORT_MUST_USE_RESULT;
+};
+
+Status ConvertOpBuilder::AddConvertToModelBuilder(QnnModelWrapper& qnn_model_wrapper,
+                                                  const NodeUnit& dq_node_unit,
+                                                  const NodeUnit& q_node_unit,
+                                                  const logging::Logger& logger,
+                                                  bool do_op_validation) const {
+  std::vector<std::string> input_names;
+
+  // Process the input from the DQ node
+  ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, dq_node_unit.Inputs()[0], logger, input_names));
+
+  // Process the output from the Q node. Override the QNN operator type to "Convert".
+  ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, q_node_unit, std::move(input_names), {},
+                                     logger, do_op_validation, QNN_OP_CONVERT));
+  return Status::OK();
+}
+
+HandleConvertResult TryHandleConvertSequence(QnnModelWrapper& qnn_model_wrapper,
+                                             const NodeUnit& maybe_dq_node_unit,
+                                             const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map,
+                                             const logging::Logger& logger,
+                                             bool do_op_validation) {
+  const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer();
+
+  // Looking for a standalone DQ to start the sequence.
+  if (maybe_dq_node_unit.OpType() != QDQ::DQOpName || maybe_dq_node_unit.UnitType() != NodeUnit::Type::SingleNode) {
+    return {};
+  }
+
+  const Node& dq_node = maybe_dq_node_unit.GetNode();
+
+  // DQ must have a single Q child. DQ must not produce a graph output.
+  auto children = graph_utils::FindChildrenByType(dq_node, QDQ::QOpName);
+  if (children.size() != 1 || dq_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(dq_node)) {
+    return {};
+  }
+
+  const Node& q_node = *children[0];
+  const auto q_node_unit_it = node_unit_map.find(&q_node);
+
+  if (q_node_unit_it == node_unit_map.end()) {
+    return {ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node does not have a corresponding NodeUnit"), nullptr};
+  }
+
+  const NodeUnit* q_node_unit = q_node_unit_it->second;
+
+  // Q child must not already be part of a QDQ NodeUnit (i.e., be standalone).
+  if (q_node_unit->UnitType() != NodeUnit::Type::SingleNode) {
+    return {};
+  }
+
+  auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) {
+    return graph_viewer.GetConstantInitializer(initializer_name, true);
+  };
+
+  // DQ and Q must have equal scale type and different zp type.
+  if (!QDQ::IsDQQConversion(dq_node, q_node, get_const_initializer, graph_viewer.ModelPath())) {
+    return {};
+  }
+
+  ConvertOpBuilder op_builder;
+
+  LOGS(logger, VERBOSE) << " Adding QNN Convert. dq_node name: [" << dq_node.Name()
+                        << "] dq_node optype: [" << dq_node.OpType()
+                        << "] q_node name: [" << q_node_unit->Name()
+                        << "] q_node optype: [" << q_node_unit->OpType()
+                        << "]";
+
+  auto status = op_builder.AddConvertToModelBuilder(qnn_model_wrapper, maybe_dq_node_unit, *q_node_unit, logger,
+                                                    do_op_validation);
+  return status.IsOK() ? HandleConvertResult{status, q_node_unit} : HandleConvertResult{status, nullptr};
+}
+
+}  // namespace qnn
+}  // namespace onnxruntime
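
TryHandleConvertSequence only fuses a very specific shape: a standalone DequantizeLinear whose single consumer is a standalone QuantizeLinear, where the pair does not feed a graph output and re-quantizes to a different zero-point type with the same scale type. A compact sketch of those guard conditions over a toy node struct (an assumption for illustration; the real checks go through Node/NodeUnit and QDQ::IsDQQConversion):

#include <iostream>
#include <string>

struct ToyNode {
  std::string op_type;
  int output_edge_count = 0;
  bool produces_graph_output = false;
  std::string scale_type;  // quantization scale element type
  std::string zp_type;     // zero-point element type
};

bool LooksLikeDqQConversion(const ToyNode& dq, const ToyNode& q) {
  if (dq.op_type != "DequantizeLinear" || q.op_type != "QuantizeLinear") return false;
  if (dq.output_edge_count != 1 || dq.produces_graph_output) return false;
  // Same scale type but a different zero-point type, i.e. a re-quantization
  // from one integer type to another (e.g. uint8 -> uint16).
  return dq.scale_type == q.scale_type && dq.zp_type != q.zp_type;
}

int main() {
  const ToyNode dq{"DequantizeLinear", 1, false, "float", "uint8"};
  const ToyNode q{"QuantizeLinear", 1, false, "float", "uint16"};
  std::cout << std::boolalpha << LooksLikeDqQConversion(dq, q) << "\n";  // true
  return 0;
}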
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
index 5f0b87c7cb9d7..6bb57b6a3e56c 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
@@ -573,11 +573,16 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
 
   // More work to support multiple partition, how to map the graph name in compile to qnn graph name
   // Need the lower level framework to understand EPContext op and pass in the partition_name in fused_node during Compile
-  for (uint32_t i = 0; i < graph_count; ++i) {
-    std::string graph_name(graphs_info[i].graphInfoV1.graphName);
-    auto qnn_model_pos = qnn_models.find(graph_name);
-    ORT_RETURN_IF(qnn_model_pos == qnn_models.end(), graph_name + " does not match any EPContext node names.");
-    ORT_RETURN_IF_ERROR(qnn_model_pos->second->DeserializeGraphInfoFromBinaryInfo(graphs_info[i]));
+  if (1 == graph_count) {
+    auto qnn_model_pos = qnn_models.begin();
+    ORT_RETURN_IF_ERROR(qnn_model_pos->second->DeserializeGraphInfoFromBinaryInfo(graphs_info[0]));
+  } else {
+    for (uint32_t i = 0; i < graph_count; ++i) {
+      std::string graph_name(graphs_info[i].graphInfoV1.graphName);
+      auto qnn_model_pos = qnn_models.find(graph_name);
+      ORT_RETURN_IF(qnn_model_pos == qnn_models.end(), graph_name + " does not match any EPContext node names.");
+      ORT_RETURN_IF_ERROR(qnn_model_pos->second->DeserializeGraphInfoFromBinaryInfo(graphs_info[i]));
+    }
   }
 
   qnn_sys_interface_.systemContextFree(sys_ctx_handle);
@@ -629,11 +634,6 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_
     LOGS(logger, VERBOSE) << "CreateContext succeed.";
   }
 
-  if (htp_performance_mode_ != HtpPerformanceMode::kHtpDefault) {
-    ORT_RETURN_IF_ERROR(SetHtpPowerConfig());
-    LOGS(logger, VERBOSE) << "SetHtpPowerConfig succeed.";
-  }
-
   LOGS(logger, VERBOSE) << "QNN SetupBackend succeed";
 
   backend_setup_completed_ = true;
@@ -641,7 +641,7 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_
   return Status::OK();
 }
 
-Status QnnBackendManager::SetHtpPowerConfig() {
+Status QnnBackendManager::CreateHtpPowerCfgId(uint32_t device_id, uint32_t core_id, uint32_t& htp_power_config_id) {
   QnnDevice_Infrastructure_t qnn_device_infra = nullptr;
   auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra);
   ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed.");
@@ -651,26 +651,40 @@ Status QnnBackendManager::SetHtpPowerConfig() {
                 "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type.");
   QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra;
   // Get power client id
-  status = htp_perf_infra.createPowerConfigId(/*device_id=*/0, /*core_id=*/0, &htp_power_config_client_id_);
+  status = htp_perf_infra.createPowerConfigId(device_id, core_id, &htp_power_config_id);
   ORT_RETURN_IF(QNN_SUCCESS != status, "createPowerConfigId failed.");
 
+  return Status::OK();
+}
+
+Status QnnBackendManager::SetHtpPowerConfig(uint32_t htp_power_config_client_id,
+                                            HtpPerformanceMode htp_performance_mode) {
+  QnnDevice_Infrastructure_t qnn_device_infra = nullptr;
+  auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra);
+  ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed.");
+
+  auto* htp_infra = static_cast<QnnHtpDevice_Infrastructure_t*>(qnn_device_infra);
+  ORT_RETURN_IF(QNN_HTP_DEVICE_INFRASTRUCTURE_TYPE_PERF != htp_infra->infraType,
+                "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type.");
+  QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra;
+
   constexpr const int kNumConfigs = 1;
   std::vector<QnnHtpPerfInfrastructure_PowerConfig_t> power_configs(
       kNumConfigs);
   QnnHtpPerfInfrastructure_PowerConfig_t& dcvs_config = power_configs[0];
   dcvs_config.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_DCVS_V3;
   QnnHtpPerfInfrastructure_DcvsV3_t& dcvs_v3 = dcvs_config.dcvsV3Config;
-  dcvs_v3.contextId = htp_power_config_client_id_;
+  dcvs_v3.contextId = htp_power_config_client_id;
   dcvs_v3.setSleepDisable = 0;
   dcvs_v3.sleepDisable = 0;
   dcvs_v3.setDcvsEnable = 1;
-  dcvs_v3.dcvsEnable = kDcvsDisable;
   dcvs_v3.powerMode = QNN_HTP_PERF_INFRASTRUCTURE_POWERMODE_PERFORMANCE_MODE;
   // choose performance mode
-  switch (htp_performance_mode_) {
+  switch (htp_performance_mode) {
     case HtpPerformanceMode::kHtpBurst:
       dcvs_v3.setSleepLatency = 1;  // true
       dcvs_v3.sleepLatency = kSleepMinLatency;
+      dcvs_v3.dcvsEnable = kDcvsDisable;
       dcvs_v3.setBusParams = 1;
       dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_MAX_VOLTAGE_CORNER;
       dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_MAX_VOLTAGE_CORNER;
@@ -684,6 +698,7 @@ Status QnnBackendManager::SetHtpPowerConfig() {
     case HtpPerformanceMode::kHtpHighPerformance:
       dcvs_v3.setSleepLatency = 1;  // true
       dcvs_v3.sleepLatency = kSleepLowLatency;
+      dcvs_v3.dcvsEnable = kDcvsDisable;
       dcvs_v3.setBusParams = 1;
       dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_TURBO;
       dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_TURBO;
@@ -693,33 +708,36 @@ Status QnnBackendManager::SetHtpPowerConfig() {
       dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_TURBO;
       dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_TURBO;
       break;
-    case HtpPerformanceMode::kHtpPowerSaver:
+    case HtpPerformanceMode::kHtpBalanced:
       dcvs_v3.setSleepLatency = 1;  // true
       dcvs_v3.sleepLatency = kSleepMediumLatency;
+      dcvs_v3.dcvsEnable = kDcvsEnable;
       dcvs_v3.setBusParams = 1;
-      dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS;
-      dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS;
-      dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS;
+      dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM_PLUS;
+      dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM_PLUS;
+      dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM_PLUS;
       dcvs_v3.setCoreParams = 1;
-      dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS;
-      dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS;
-      dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS;
+      dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM_PLUS;
+      dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM_PLUS;
+      dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM_PLUS;
       break;
-    case HtpPerformanceMode::kHtpLowPowerSaver:
+    case HtpPerformanceMode::kHtpLowBalanced:
       dcvs_v3.setSleepLatency = 1;  // true
       dcvs_v3.sleepLatency = kSleepMediumLatency;
+      dcvs_v3.dcvsEnable = kDcvsEnable;
       dcvs_v3.setBusParams = 1;
-      dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS2;
-      dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS2;
-      dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS2;
+      dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM;
+      dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM;
+      dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM;
       dcvs_v3.setCoreParams = 1;
-      dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS2;
-      dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS2;
-      dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS2;
+      dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM;
+      dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM;
+      dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM;
       break;
     case HtpPerformanceMode::kHtpHighPowerSaver:
       dcvs_v3.setSleepLatency = 1;  // true
       dcvs_v3.sleepLatency = kSleepMediumLatency;
+      dcvs_v3.dcvsEnable = kDcvsEnable;
       dcvs_v3.setBusParams = 1;
       dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS_PLUS;
       dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS_PLUS;
@@ -729,62 +747,81 @@ Status QnnBackendManager::SetHtpPowerConfig() {
       dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS_PLUS;
       dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS_PLUS;
       break;
-    case HtpPerformanceMode::kHtpExtremePowerSaver:
+    case HtpPerformanceMode::kHtpPowerSaver:
       dcvs_v3.setSleepLatency = 1;  // true
       dcvs_v3.sleepLatency = kSleepMediumLatency;
+      dcvs_v3.dcvsEnable = kDcvsEnable;
       dcvs_v3.setBusParams = 1;
-      dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_CORNER_DISABLE;
-      dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_CORNER_DISABLE;
-      dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_CORNER_DISABLE;
+      dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS;
+      dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS;
+      dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS;
       dcvs_v3.setCoreParams = 1;
-      dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_CORNER_DISABLE;
-      dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_CORNER_DISABLE;
-      dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_CORNER_DISABLE;
+      dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS;
+      dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS;
+      dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS;
       break;
-    case HtpPerformanceMode::kHtpLowBalanced:
+    case HtpPerformanceMode::kHtpLowPowerSaver:
       dcvs_v3.setSleepLatency = 1;  // true
       dcvs_v3.sleepLatency = kSleepMediumLatency;
+      dcvs_v3.dcvsEnable = kDcvsEnable;
       dcvs_v3.setBusParams = 1;
-      dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM;
-      dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM;
-      dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM;
+      dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS2;
+      dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS2;
+      dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS2;
       dcvs_v3.setCoreParams = 1;
-      dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM;
-      dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM;
-      dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM;
+      dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS2;
+      dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS2;
+      dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS2;
       break;
-    case HtpPerformanceMode::kHtpBalanced:
+    case HtpPerformanceMode::kHtpExtremePowerSaver:
+      dcvs_v3.powerMode = QNN_HTP_PERF_INFRASTRUCTURE_POWERMODE_POWER_SAVER_MODE;
       dcvs_v3.setSleepLatency = 1;  // true
       dcvs_v3.sleepLatency = kSleepMediumLatency;
+      dcvs_v3.dcvsEnable = kDcvsEnable;
       dcvs_v3.setBusParams = 1;
-      dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM_PLUS;
-      dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM_PLUS;
-      dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM_PLUS;
+      dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_CORNER_DISABLE;
+      dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_CORNER_DISABLE;
+      dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_CORNER_DISABLE;
       dcvs_v3.setCoreParams = 1;
-      dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM_PLUS;
-      dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM_PLUS;
-      dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM_PLUS;
+      dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_CORNER_DISABLE;
+      dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_CORNER_DISABLE;
+      dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_CORNER_DISABLE;
       break;
     default:
-      ORT_THROW("Invalid performance profile %d", static_cast<int>(htp_performance_mode_));
+      ORT_THROW("Invalid performance profile ", static_cast<int>(htp_performance_mode));
       break;
   }
   std::vector<const QnnHtpPerfInfrastructure_PowerConfig_t*> perf_power_configs_ptr = ObtainNullTermPtrVector(power_configs);
-  status = htp_perf_infra.setPowerConfig(htp_power_config_client_id_, perf_power_configs_ptr.data());
+  status = htp_perf_infra.setPowerConfig(htp_power_config_client_id, perf_power_configs_ptr.data());
   ORT_RETURN_IF(QNN_SUCCESS != status, "setPowerConfig failed for HTP performance mode.");
 
-  // Set rpc control latency here, but note that v68 doesn't support rpc polling mode.
-  if (rpc_control_latency_ != 0) {
+  return Status::OK();
+}
+
+Status QnnBackendManager::SetRpcControlLatency(uint32_t htp_power_config_client_id,
+                                               uint32_t rpc_control_latency) {
+  if (rpc_control_latency != 0) {
+    QnnDevice_Infrastructure_t qnn_device_infra = nullptr;
+    auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra);
+    ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed.");
+
+    auto* htp_infra = static_cast<QnnHtpDevice_Infrastructure_t*>(qnn_device_infra);
+    ORT_RETURN_IF(QNN_HTP_DEVICE_INFRASTRUCTURE_TYPE_PERF != htp_infra->infraType,
+                  "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type.");
+    QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra;
+
+    // Set rpc control latency here, but note that v68 doesn't support rpc polling mode.
     constexpr int kNumRpcPollingPowerConfigs = 2;
     std::vector<QnnHtpPerfInfrastructure_PowerConfig_t> rpc_power_configs(kNumRpcPollingPowerConfigs);
-    QnnHtpPerfInfrastructure_PowerConfig_t& rpc_control_latency = rpc_power_configs[0];
+    QnnHtpPerfInfrastructure_PowerConfig_t& rpc_control_latency_cfg = rpc_power_configs[0];
     // v68 doesn't support this.
     QnnHtpPerfInfrastructure_PowerConfig_t& rpc_polling_time = rpc_power_configs[1];
-    rpc_control_latency.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_CONTROL_LATENCY;
+    rpc_control_latency_cfg.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_CONTROL_LATENCY;
     rpc_polling_time.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_POLLING_TIME;
-    rpc_control_latency.rpcControlLatencyConfig = rpc_control_latency_;
-    perf_power_configs_ptr = ObtainNullTermPtrVector(rpc_power_configs);
-    status = htp_perf_infra.setPowerConfig(htp_power_config_client_id_, perf_power_configs_ptr.data());
+    rpc_control_latency_cfg.rpcControlLatencyConfig = rpc_control_latency;
+    std::vector<const QnnHtpPerfInfrastructure_PowerConfig_t*> perf_power_configs_ptr =
+        ObtainNullTermPtrVector(rpc_power_configs);
+    status = htp_perf_infra.setPowerConfig(htp_power_config_client_id, perf_power_configs_ptr.data());
     ORT_RETURN_IF(QNN_SUCCESS != status, "setPowerConfig failed for RPC control latency.");
   }
 
@@ -805,11 +842,7 @@ void QnnBackendManager::Split(std::vector<std::string>& split_string,
   }
 }
 
-Status QnnBackendManager::DestroyHTPPowerConfigID() {
-  if (htp_performance_mode_ == HtpPerformanceMode::kHtpDefault) {
-    return Status::OK();
-  }
-
+Status QnnBackendManager::DestroyHTPPowerConfigID(uint32_t htp_power_config_id) {
   QnnDevice_Infrastructure_t qnn_device_infra = nullptr;
   auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra);
   ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed.");
@@ -819,7 +852,7 @@ Status QnnBackendManager::DestroyHTPPowerConfigID() {
                 "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type.");
   QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra;
 
-  Qnn_ErrorHandle_t destroy_ret = htp_perf_infra.destroyPowerConfigId(htp_power_config_client_id_);
+  Qnn_ErrorHandle_t destroy_ret = htp_perf_infra.destroyPowerConfigId(htp_power_config_id);
   ORT_RETURN_IF(QNN_SUCCESS != destroy_ret, "destroyPowerConfigId failed.");
   return Status::OK();
 }
@@ -829,12 +862,7 @@ void QnnBackendManager::ReleaseResources() {
     return;
   }
 
-  auto result = DestroyHTPPowerConfigID();
-  if (Status::OK() != result) {
-    ORT_THROW("Failed to DestroyHTPPowerConfigID.");
-  }
-
-  result = ReleaseContext();
+  auto result = ReleaseContext();
   if (Status::OK() != result) {
     ORT_THROW("Failed to ReleaseContext.");
   }
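
The refactor above moves the HTP power configuration from construction-time state inside QnnBackendManager to an explicit per-client lifecycle: create a power config id, apply performance-mode and RPC-latency settings against that id, and destroy it when done. A stand-alone sketch of that lifecycle with stub functions standing in for the QNN HTP perf-infrastructure calls (the stub names and bodies are assumptions, not the QNN API):

#include <cstdint>
#include <iostream>

static uint32_t next_power_config_id = 1;

bool CreatePowerConfigId(uint32_t& id) { id = next_power_config_id++; return true; }

bool SetPowerConfig(uint32_t id, const char* setting) {
  std::cout << "power config " << id << ": " << setting << "\n";
  return true;
}

bool DestroyPowerConfigId(uint32_t id) {
  std::cout << "destroyed power config " << id << "\n";
  return true;
}

int main() {
  uint32_t power_config_id = 0;
  if (!CreatePowerConfigId(power_config_id)) return 1;  // cf. CreateHtpPowerCfgId

  // Per-run settings, mirroring SetHtpPowerConfig and SetRpcControlLatency.
  SetPowerConfig(power_config_id, "htp_performance_mode=burst");
  SetPowerConfig(power_config_id, "rpc_control_latency=10");

  // Explicit teardown, mirroring DestroyHTPPowerConfigID(htp_power_config_id).
  return DestroyPowerConfigId(power_config_id) ? 0 : 1;
}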
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
index 36375522b5a0a..ff97c4c3a991c 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
@@ -33,8 +33,6 @@ class QnnBackendManager {
  public:
   QnnBackendManager(std::string&& backend_path,
                     ProfilingLevel profiling_level,
-                    uint32_t rpc_control_latency,
-                    HtpPerformanceMode htp_performance_mode,
                     ContextPriority context_priority,
                     std::string&& qnn_saver_path,
                     uint32_t device_id,
@@ -42,8 +40,6 @@ class QnnBackendManager {
                     uint32_t soc_model)
       : backend_path_(backend_path),
         profiling_level_(profiling_level),
-        rpc_control_latency_(rpc_control_latency),
-        htp_performance_mode_(htp_performance_mode),
         context_priority_(context_priority),
         qnn_saver_path_(qnn_saver_path),
         device_id_(device_id),
@@ -92,7 +88,13 @@ class QnnBackendManager {
 
   Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context);
 
-  Status SetHtpPowerConfig();
+  Status CreateHtpPowerCfgId(uint32_t device_id, uint32_t core_id, uint32_t& htp_power_config_id);
+
+  Status SetHtpPowerConfig(uint32_t htp_power_config_client_id,
+                           HtpPerformanceMode htp_performance_mode);
+
+  Status SetRpcControlLatency(uint32_t htp_power_config_client_id,
+                              uint32_t rpc_control_latency);
 
   const QNN_INTERFACE_VER_TYPE& GetQnnInterface() { return qnn_interface_; }
 
@@ -141,6 +143,8 @@ class QnnBackendManager {
 
   const std::string& GetSdkVersion() { return sdk_build_version_; }
 
+  Status DestroyHTPPowerConfigID(uint32_t htp_power_config_id);
+
  private:
   void* LoadLib(const char* file_name, int flags, std::string& error_msg);
 
@@ -150,8 +154,6 @@ class QnnBackendManager {
 
   Status UnloadLib(void* handle);
 
-  Status DestroyHTPPowerConfigID();
-
   void* LibFunction(void* handle, const char* symbol, std::string& error_msg);
 
   template <class T>
@@ -232,15 +234,12 @@ class QnnBackendManager {
   QnnBackendType qnn_backend_type_ = QnnBackendType::CPU;
   Qnn_ProfileHandle_t profile_backend_handle_ = nullptr;
   std::vector<std::string> op_package_paths_;
-  uint32_t rpc_control_latency_ = 0;
-  HtpPerformanceMode htp_performance_mode_;
   ContextPriority context_priority_;
   std::string sdk_build_version_ = "";
 #ifdef _WIN32
   std::set<HMODULE> mod_handles_;
 #endif
   const std::string qnn_saver_path_;
-  uint32_t htp_power_config_client_id_ = 0;
   uint32_t device_id_ = 0;
   QnnHtpDevice_Arch_t htp_arch_ = QNN_HTP_DEVICE_ARCH_NONE;
   uint32_t soc_model_ = QNN_SOC_MODEL_UNKNOWN;
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc
index 314cab4a36ca9..b3501dfec1ba8 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc
@@ -9,6 +9,8 @@
 #include "core/providers/qnn/builder/op_builder_factory.h"
 #include "core/providers/shared/utils/utils.h"
 #include "core/framework/utils.h"
+#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
+#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
 #include "core/providers/qnn/builder/qnn_utils.h"
 
 namespace onnxruntime {
@@ -95,7 +97,7 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer,
   // valid throughout the lifetime of the ModelBuilder
   std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
   std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
-  std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer);
+  std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
 
   // This name must be same with the EPContext node name
   const auto& graph_name = fused_node.Name();
@@ -114,6 +116,8 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer,
     return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to initialize qnn_model_wrapper.");
   }
 
+  std::unordered_set<const NodeUnit*> handled_node_units;
+
   // Op builer
   const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
   for (size_t i = 0; i < node_indices.size(); i++) {
@@ -122,20 +126,43 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer,
     // Check whether it's part of NodeUnit
     const NodeUnit& node_unit = GetNodeUnit(node, node_unit_map);
     // Q, DQ nodes in the node unit only carry the quantization parameters
-    // Add the QNN node when it is the target node (It's a normal node or a singel Q/DQ node)
+    // Add the QNN node when it is the target node (It's a normal node or a single Q/DQ node)
     const std::string& op_type = node_unit.OpType();
+
+    if (node != &node_unit.GetNode()) {
+      continue;
+    }
+
+    if (handled_node_units.count(&node_unit) != 0) {
+      continue;  // Already handled.
+    }
+
+    // Try to convert particular DQ -> Q sequences into QNN Convert op
+    auto convert_result = TryHandleConvertSequence(qnn_model_wrapper,
+                                                   node_unit,
+                                                   node_unit_map,
+                                                   logger_,
+                                                   false /*do_op_validation*/);
+    ORT_RETURN_IF_ERROR(convert_result.status);
+
+    if (convert_result.q_node_unit) {
+      // Successfully merged DQ -> Q sequence into a QNN Convert op.
+      // Mark both of these node units as handled.
+      handled_node_units.insert(&node_unit);
+      handled_node_units.insert(convert_result.q_node_unit);
+      continue;
+    }
+
     LOGS(logger_, VERBOSE) << " node name: [" << node->Name()
                            << "] node optype: [" << op_type
                            << "] as part of the NodeUnit type: [" << node_unit.OpType()
                            << "] name: [" << node_unit.Name()
                            << "]";
-    if (node != &node_unit.GetNode()) {
-      continue;
-    }
-
     if (const auto* op_builder = GetOpBuilder(op_type)) {
       ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(qnn_model_wrapper, node_unit, logger_));
     }
+
+    handled_node_units.insert(&node_unit);
   }
 
   ORT_RETURN_IF_NOT(qnn_model_wrapper.ComposeQnnGraph(), "Failed to compose Qnn graph.");
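
ComposeGraph now tracks which NodeUnits have already been consumed so that a DQ -> Q pair fused into a single QNN Convert is not visited a second time when the loop reaches the Q unit. A minimal sketch of that bookkeeping, with integer ids standing in for NodeUnit pointers (illustrative only):

#include <iostream>
#include <unordered_set>
#include <vector>

int main() {
  const std::vector<int> node_units = {1, 2, 3, 4};  // topological order
  std::unordered_set<int> handled_node_units;

  for (int unit : node_units) {
    if (handled_node_units.count(unit) != 0) continue;  // already handled

    const bool fused_with_next = (unit == 2);  // pretend units 2 and 3 fuse into a Convert
    if (fused_with_next) {
      handled_node_units.insert(unit);
      handled_node_units.insert(3);  // the Q unit consumed by the fusion
      std::cout << "fused units 2 and 3 into one QNN Convert\n";
      continue;
    }

    std::cout << "built unit " << unit << " with its regular op builder\n";
    handled_node_units.insert(unit);
  }
  return 0;
}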
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.h b/onnxruntime/core/providers/qnn/builder/qnn_model.h
index d0dd091cb1688..8fed2f364ba5a 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_model.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_model.h
@@ -6,13 +6,13 @@
 #include <vector>
 
 #include "core/common/status.h"
+#include "core/framework/node_unit.h"
 #include "core/graph/graph_viewer.h"
 #include "core/platform/ort_mutex.h"
 #include "core/providers/qnn/builder/qnn_def.h"
 #include "core/providers/qnn/builder/qnn_model_wrapper.h"
 #include "core/providers/qnn/builder/qnn_backend_manager.h"
 #include "core/session/onnxruntime_cxx_api.h"
-#include "core/providers/shared/node_unit/node_unit.h"
 
 namespace onnxruntime {
 namespace qnn {
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h
index 8ae489c749f31..1e2993f246ae4 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h
@@ -11,8 +11,8 @@
 #include "QnnInterface.h"
 #include "qnn_def.h"
 #include "core/common/logging/logging.h"
+#include "core/framework/node_unit.h"
 #include "core/graph/graph_viewer.h"
-#include "core/providers/shared/node_unit/node_unit.h"
 #include "core/providers/shared/utils/utils.h"
 
 namespace onnxruntime {
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
index b58f6e10df94c..ef90b1f629b26 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -7,8 +7,11 @@
 #include "core/framework/compute_capability.h"
 #include "core/graph/graph_viewer.h"
 #include "core/session/onnxruntime_session_options_config_keys.h"
+#include "core/session/onnxruntime_run_options_config_keys.h"
 #include "core/session/onnxruntime_cxx_api.h"
 #include "core/framework/kernel_registry.h"
+#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
+#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
 #include "core/platform/env.h"
 #include "core/providers/common.h"
 #include "core/providers/partitioning_utils.h"
@@ -18,11 +21,36 @@
 #include "core/providers/qnn/builder/op_builder_factory.h"
 #include "core/providers/qnn/builder/qnn_def.h"
 #include "core/providers/qnn/builder/onnx_ctx_model_helper.h"
+#include "core/framework/run_options.h"
 
 namespace onnxruntime {
 
 constexpr const char* QNN = "QNN";
 
+static std::unique_ptr<std::vector<std::function<void()>>> s_run_on_unload_;
+
+void RunOnUnload(std::function<void()> function) {
+  OrtMutex mutex;
+  static OrtMutex mutex;
+  if (!s_run_on_unload_) {
+    s_run_on_unload_ = std::make_unique<std::vector<std::function<void()>>>();
+  }
+  s_run_on_unload_->push_back(std::move(function));
+}
+
+struct OnUnload {
+  ~OnUnload() {
+    if (!s_run_on_unload_)
+      return;
+
+    for (auto& function : *s_run_on_unload_)
+      function();
+
+    s_run_on_unload_.reset();
+  }
+
+} g_on_unload;
+
 static void ParseProfilingLevel(std::string profiling_level_string,
                                 qnn::ProfilingLevel& profiling_level) {
   std::transform(profiling_level_string.begin(),
@@ -193,18 +221,18 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
   }
 
   static const std::string RPC_CONTROL_LANTENCY = "rpc_control_latency";
-  uint32_t rpc_control_latency = 0;
   auto latency_pos = provider_options_map.find(RPC_CONTROL_LANTENCY);
   if (latency_pos != provider_options_map.end()) {
-    rpc_control_latency = static_cast<uint32_t>(std::stoul(latency_pos->second));
-    LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << rpc_control_latency;
+    default_rpc_control_latency_ = static_cast<uint32_t>(std::stoul(latency_pos->second));
+    LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << default_rpc_control_latency_;
   }
 
-  qnn::HtpPerformanceMode htp_performance_mode = qnn::HtpPerformanceMode::kHtpDefault;
+  // default_htp_performance_mode_ comes from the QNN EP provider options.
+  // Set it once per thread as the default so the user does not need to set it for every session run.
   static const std::string HTP_PERFORMANCE_MODE = "htp_performance_mode";
   auto htp_performance_mode_pos = provider_options_map.find(HTP_PERFORMANCE_MODE);
   if (htp_performance_mode_pos != provider_options_map.end()) {
-    ParseHtpPerformanceMode(htp_performance_mode_pos->second, htp_performance_mode);
+    ParseHtpPerformanceMode(htp_performance_mode_pos->second, default_htp_performance_mode_);
   }
 
   htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault;
@@ -241,15 +269,14 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
   }
 
   static const std::string QNN_DEVICE_ID = "device_id";
-  uint32_t device_id = 0;
   auto dev_id_pos = provider_options_map.find(QNN_DEVICE_ID);
   if (dev_id_pos != provider_options_map.end()) {
     int value = std::stoi(dev_id_pos->second);
     if (value < 0) {
       LOGS_DEFAULT(WARNING) << "Invalid device ID '" << value
-                            << "', only >= 0 allowed. Set to " << device_id << ".";
+                            << "', only >= 0 allowed. Set to " << device_id_ << ".";
     } else {
-      device_id = static_cast<uint32_t>(value);
+      device_id_ = static_cast<uint32_t>(value);
     }
   }
 
@@ -273,46 +300,58 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
     }
   }
 
+  static const std::string QNN_HTP_FP16_MODE = "enable_htp_fp16_precision";
+  auto htp_fp16_mode_pos = provider_options_map.find(QNN_HTP_FP16_MODE);
+  if (htp_fp16_mode_pos != provider_options_map.end()) {
+    if ("1" == htp_fp16_mode_pos->second) {
+      enable_HTP_FP16_precision_ = true;
+    } else if ("0" == htp_fp16_mode_pos->second) {
+      enable_HTP_FP16_precision_ = false;
+    } else {
+      LOGS_DEFAULT(VERBOSE) << "Invalid enable_htp_fp16_precision: " << htp_fp16_mode_pos->second << ", only 0 or 1 allowed. Using default value.";
+    }
+    LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_fp16_precision: " << enable_HTP_FP16_precision_;
+  }
+
   qnn_backend_manager_ = std::make_unique<qnn::QnnBackendManager>(
       std::move(backend_path),
       profiling_level,
-      rpc_control_latency,
-      htp_performance_mode,
       context_priority,
       std::move(qnn_saver_path),
-      device_id,
+      device_id_,
       htp_arch,
       soc_model);
 }
 
+QNNExecutionProvider::~QNNExecutionProvider() {
+  // clean up thread local context caches
+  std::lock_guard<OrtMutex> lock(context_state_.mutex);
+  for (const auto& cache_weak : context_state_.caches_to_update_on_destruction) {
+    const auto cache = cache_weak.lock();
+    if (!cache) continue;
+    ORT_IGNORE_RETURN_VALUE(cache->erase(this));
+  }
+}
+
 bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit,
-                                           std::unordered_map<const NodeUnit*, bool>& node_unit_supported_result,
                                            const logging::Logger& logger) const {
-  // If we have visited one of the nodes in the node_unit, use the result directly
-  const auto it = node_unit_supported_result.find(&node_unit);
-  if (it != node_unit_supported_result.cend()) {
-    return it->second;
+  const std::string& op_type = node_unit.OpType();
+  bool supported = false;
+  const auto* op_builder = qnn::GetOpBuilder(op_type);
+  if (op_builder == nullptr) {
+    LOGS(logger, WARNING) << "Operators of type `" << node_unit.OpType() << "` are not supported by QNN EP. "
+                          << node_unit.OpType() << " node `" << node_unit.Name()
+                          << "` will not be assigned to QNN EP.";
   } else {
-    const std::string& op_type = node_unit.OpType();
-
-    bool supported = false;
-    const auto* op_builder = qnn::GetOpBuilder(op_type);
-    if (op_builder == nullptr) {
-      LOGS(logger, WARNING) << "Operators of type `" << node_unit.OpType() << "` are not supported by QNN EP."
-                            << node_unit.OpType() << " node `" << node_unit.Name()
-                            << "` will not be assigned to QNN EP.";
-    } else {
-      auto status = op_builder->IsOpSupported(qnn_model_wrapper,
-                                              node_unit, logger);
-      if (Status::OK() != status) {
-        LOGS(logger, WARNING) << node_unit.OpType() << " node `" << node_unit.Name()
-                              << "` is not supported: " << status.ErrorMessage();
-      }
-      supported = (Status::OK() == status);
+    auto status = op_builder->IsOpSupported(qnn_model_wrapper,
+                                            node_unit, logger);
+    if (Status::OK() != status) {
+      LOGS(logger, WARNING) << node_unit.OpType() << " node `" << node_unit.Name()
+                            << "` is not supported: " << status.ErrorMessage();
     }
-    node_unit_supported_result[&node_unit] = supported;
-    return supported;
+    supported = (Status::OK() == status);
   }
+  return supported;
 }
 
 std::unordered_set<const Node*>
@@ -391,24 +430,51 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer,
     if (node != &node_unit->GetNode()) {
       continue;
     }
-    const bool supported = IsNodeSupported(qnn_model_wrapper,
-                                           *node_unit,
-                                           node_unit_supported_result,
-                                           logger);
-    LOGS(logger, VERBOSE) << "Node supported: [" << supported
-                          << "] index: [" << node->Index()
-                          << "] name: [" << node->Name()
-                          << "] Operator type: [" << node->OpType()
-                          << "] as part of the NodeUnit type: [" << node_unit->OpType()
-                          << "] index: [" << node_unit->Index()
-                          << "] name: [" << node_unit->Name()
-                          << "]";
+
+    if (node_unit_supported_result.count(node_unit) != 0) {
+      continue;  // Already handled this node unit
+    }
+
+    // Try to convert certain standalone DQ -> Q sequences into QNN Convert op
+    auto convert_result = TryHandleConvertSequence(qnn_model_wrapper,
+                                                   *node_unit,
+                                                   node_unit_map,
+                                                   logger,
+                                                   true /*do_op_validation*/);
+    if (!convert_result.status.IsOK()) {
+      LOGS(logger, WARNING) << "Failed to convert DQ -> Q sequence to QNN Convert. "
+                            << "Type: " << node_unit->OpType() << ", Node name: " << node_unit->Name() << ", "
+                            << "Message: " << convert_result.status.ErrorMessage();
+    }
+
+    bool supported = false;
+
+    if (convert_result.status.IsOK() && convert_result.q_node_unit) {  // Merged DQ -> Q sequence into QNN Convert op
+      supported = true;
+
+      // Mark the Q node unit as handled and supported here so that we don't try to process it again.
+      node_unit_supported_result.insert({convert_result.q_node_unit, true});
+      supported_nodes.insert(&convert_result.q_node_unit->GetNode());
+    } else {
+      supported = IsNodeSupported(qnn_model_wrapper, *node_unit, logger);
+      LOGS(logger, VERBOSE) << "Node supported: [" << supported
+                            << "] index: [" << node->Index()
+                            << "] name: [" << node->Name()
+                            << "] Operator type: [" << node->OpType()
+                            << "] as part of the NodeUnit type: [" << node_unit->OpType()
+                            << "] index: [" << node_unit->Index()
+                            << "] name: [" << node_unit->Name()
+                            << "]";
+    }
+
     if (supported) {
       // If the node_unit is supported, add all of its nodes to the supported list.
       for (const auto* node_in_group : node_unit->GetAllNodesInGroup()) {
         supported_nodes.insert(node_in_group);
       }
     }
+
+    node_unit_supported_result.insert({node_unit, supported});
   }
 
   return supported_nodes;
@@ -443,7 +509,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
   std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
   std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
 
-  std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer);
+  std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
 
   const auto supported_nodes = GetSupportedNodes(graph_viewer, node_unit_map, node_unit_holder.size(),
                                                  is_qnn_ctx_model, logger);
@@ -483,44 +549,39 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
   size_t num_of_supported_nodes = 0;
 
   // Create partitions from supported nodes.
-  {
-    std::vector<std::unique_ptr<ComputeCapability>> partitions = utils::CreateSupportedPartitions(graph_viewer,
-                                                                                                  supported_nodes, {},
-                                                                                                  gen_metadef_name, QNN,
-                                                                                                  kQnnExecutionProvider,
-                                                                                                  true);
-
-    // Filter out partitions that consist of a single QuantizeLinear or DequantizeLinear node.
-    // We also count the number of supported nodes in all valid partitions.
-    for (auto& partition : partitions) {
-      bool is_valid_partition = true;
-      size_t nodes_in_partition = 0;
-
-      if (partition && partition->sub_graph) {
-        nodes_in_partition = partition->sub_graph->nodes.size();
-
-        if (nodes_in_partition == 1 && !is_qnn_ctx_model) {
-          const Node* node = graph_viewer.GetNode(partition->sub_graph->nodes[0]);
-
-          if (!node) {
-            LOGS(logger, ERROR) << "QNN EP: Invalid node in partition of one node.";
-            is_valid_partition = false;
-          } else if (node->OpType() == "QuantizeLinear" || node->OpType() == "DequantizeLinear") {
-            LOGS(logger, WARNING) << "QNN EP does not support a single Quantize/Dequantize node in a partition.";
-            is_valid_partition = false;
-          }
+  std::vector<std::unique_ptr<ComputeCapability>> partitions = utils::CreateSupportedPartitions(
+      graph_viewer, supported_nodes, {}, gen_metadef_name, QNN, kQnnExecutionProvider, &node_unit_map, true);
+
+  // Filter out partitions that consist of a single QuantizeLinear or DequantizeLinear node.
+  // We also count the number of supported nodes in all valid partitions.
+  for (auto& partition : partitions) {
+    bool is_valid_partition = true;
+    size_t nodes_in_partition = 0;
+
+    if (partition && partition->sub_graph) {
+      nodes_in_partition = partition->sub_graph->nodes.size();
+
+      if (nodes_in_partition == 1 && !is_qnn_ctx_model) {
+        const Node* node = graph_viewer.GetNode(partition->sub_graph->nodes[0]);
+
+        if (!node) {
+          LOGS(logger, ERROR) << "QNN EP: Invalid node in partition of one node.";
+          is_valid_partition = false;
+        } else if (node->OpType() == "QuantizeLinear" || node->OpType() == "DequantizeLinear") {
+          LOGS(logger, WARNING) << "QNN EP does not support a single Quantize/Dequantize node in a partition.";
+          is_valid_partition = false;
         }
-      } else {
-        LOGS(logger, ERROR) << "QNN EP: Invalid partition.";
-        is_valid_partition = false;
       }
+    } else {
+      LOGS(logger, ERROR) << "QNN EP: Invalid partition.";
+      is_valid_partition = false;
+    }
 
-      if (is_valid_partition) {
-        result.push_back(std::move(partition));
-        num_of_supported_nodes += nodes_in_partition;
-      }
-    }  // for
-  }
+    if (is_valid_partition) {
+      result.push_back(std::move(partition));
+      num_of_supported_nodes += nodes_in_partition;
+    }
+  }  // for
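
The filter above rejects any partition that consists of exactly one QuantizeLinear or DequantizeLinear node and tallies the nodes kept in valid partitions. A standalone illustration of the same predicate over a simplified partition representation (hypothetical types, not the ComputeCapability structures) is:

#include <cstddef>
#include <string>
#include <vector>

std::size_t FilterPartitions(std::vector<std::vector<std::string>>& partitions) {
  std::vector<std::vector<std::string>> kept;
  std::size_t num_supported_nodes = 0;
  for (auto& p : partitions) {
    const bool lone_qdq = p.size() == 1 &&
                          (p[0] == "QuantizeLinear" || p[0] == "DequantizeLinear");
    if (!p.empty() && !lone_qdq) {
      num_supported_nodes += p.size();
      kept.push_back(std::move(p));
    }
  }
  partitions = std::move(kept);
  return num_supported_nodes;
}
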
 
   const size_t num_of_partitions = result.size();
   const auto summary_msg = MakeString("Number of partitions supported by QNN EP: ", num_of_partitions,
@@ -589,6 +650,16 @@ void QNNExecutionProvider::InitQnnGraphConfigs(qnn::QnnConfigsBuilder<QnnGraph_C
       graph_opt_config_vtcm.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM;
       graph_opt_config_vtcm.customConfig = &htp_graph_opt_config_vtcm;
     }
+
+    if (enable_HTP_FP16_precision_) {
+      QnnHtpGraph_CustomConfig_t& htp_graph_precision_config = configs_builder.PushCustomConfig();
+      htp_graph_precision_config.option = QNN_HTP_GRAPH_CONFIG_OPTION_PRECISION;
+      htp_graph_precision_config.precision = QNN_PRECISION_FLOAT16;
+
+      QnnGraph_Config_t& graph_precision_config = configs_builder.PushConfig();
+      graph_precision_config.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM;
+      graph_precision_config.customConfig = &htp_graph_precision_config;
+    }
   }
 }
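
The FP16 precision addition follows the builder pattern used for the other graph configs: push a custom config, fill it in, then push a graph config whose customConfig pointer refers back to it. That only works if the pushed elements keep stable addresses. The sketch below uses hypothetical simplified structs and a std::deque (whose elements are never relocated by push/emplace at the ends) to show the idea; it is not the QNN SDK or the actual QnnConfigsBuilder.

#include <deque>

struct CustomConfig { int option = 0; int precision = 0; };
struct GraphConfig  { int option = 0; const CustomConfig* custom = nullptr; };

class ConfigsBuilder {
 public:
  CustomConfig& PushCustomConfig() { return custom_configs_.emplace_back(); }
  GraphConfig& PushConfig() { return graph_configs_.emplace_back(); }
  const std::deque<GraphConfig>& GraphConfigs() const { return graph_configs_; }

 private:
  std::deque<CustomConfig> custom_configs_;  // deque keeps element addresses stable
  std::deque<GraphConfig> graph_configs_;
};

void AddFp16Precision(ConfigsBuilder& b) {
  CustomConfig& custom = b.PushCustomConfig();
  custom.option = 1;       // stands in for QNN_HTP_GRAPH_CONFIG_OPTION_PRECISION
  custom.precision = 2;    // stands in for QNN_PRECISION_FLOAT16
  GraphConfig& graph = b.PushConfig();
  graph.option = 3;        // stands in for QNN_GRAPH_CONFIG_OPTION_CUSTOM
  graph.custom = &custom;  // safe: the custom config's address will not change
}
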
 
@@ -652,7 +723,8 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
     ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxGraph(main_ctx_graph_viewer,
                                                      context_cache_path,
                                                      qnn_backend_manager_.get(),
-                                                     qnn_models));
+                                                     qnn_models,
+                                                     logger));
 
     for (auto fused_node_and_graph : fused_nodes_and_graphs) {
       const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);
@@ -706,4 +778,147 @@ const InlinedVector<const Node*> QNNExecutionProvider::GetEpContextNodes() const
 
   return ep_context_nodes;
 }
+
+QNNExecutionProvider::PerThreadContext::PerThreadContext(qnn::QnnBackendManager* qnn_backend_manager,
+                                                         uint32_t device_id,
+                                                         uint32_t core_id,
+                                                         qnn::HtpPerformanceMode default_htp_performance_mode,
+                                                         uint32_t default_rpc_control_latency)
+    : qnn_backend_manager_(qnn_backend_manager) {
+  Status rt = qnn_backend_manager_->CreateHtpPowerCfgId(device_id, core_id, htp_power_config_id_);
+  is_htp_power_config_id_valid_ = rt.IsOK();
+  // default_htp_performance_mode and default_rpc_control_latency come from the QNN EP options.
+  // Set them once per thread as the default so users don't need to set them for every session run.
+  if (is_htp_power_config_id_valid_) {
+    if (qnn::HtpPerformanceMode::kHtpDefault != default_htp_performance_mode) {
+      ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetHtpPowerConfig(htp_power_config_id_,
+                                                                      default_htp_performance_mode));
+    }
+    if (default_rpc_control_latency > 0) {
+      ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetRpcControlLatency(htp_power_config_id_,
+                                                                         default_rpc_control_latency));
+    }
+  }
+}
+
+QNNExecutionProvider::PerThreadContext::~PerThreadContext() {
+  if (is_htp_power_config_id_valid_) {
+    ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->DestroyHTPPowerConfigID(htp_power_config_id_));
+  }
+}
+
+QNNExecutionProvider::PerThreadContext& QNNExecutionProvider::GetPerThreadContext() const {
+  const auto& per_thread_context_cache = PerThreadContextCache();
+
+  // try to use cached context
+  auto cached_context_it = per_thread_context_cache->find(this);
+  if (cached_context_it != per_thread_context_cache->end()) {
+    auto cached_context = cached_context_it->second.lock();
+    ORT_ENFORCE(cached_context);
+    return *cached_context;
+  }
+
+  // get context and update cache
+  std::shared_ptr<PerThreadContext> context;
+  {
+    std::lock_guard<OrtMutex> lock(context_state_.mutex);
+
+    // get or create a context
+    if (context_state_.retired_context_pool.empty()) {
+      uint32_t core_id = 0;
+      context = std::make_shared<PerThreadContext>(qnn_backend_manager_.get(), device_id_, core_id,
+                                                   default_htp_performance_mode_, default_rpc_control_latency_);
+    } else {
+      context = context_state_.retired_context_pool.back();
+      context_state_.retired_context_pool.pop_back();
+    }
+
+    // insert into active_contexts, should not already be present
+    const auto active_contexts_insert_result = context_state_.active_contexts.insert(context);
+    ORT_ENFORCE(active_contexts_insert_result.second);
+
+    // insert into caches_to_update_on_destruction, may already be present
+    ORT_IGNORE_RETURN_VALUE(context_state_.caches_to_update_on_destruction.insert(per_thread_context_cache));
+  }
+
+  per_thread_context_cache->insert(std::make_pair(this, context));
+
+  return *context;
+}
+
+void QNNExecutionProvider::ReleasePerThreadContext() const {
+  const auto& per_thread_context_cache = PerThreadContextCache();
+
+  auto cached_context_it = per_thread_context_cache->find(this);
+  ORT_ENFORCE(cached_context_it != per_thread_context_cache->end());
+  auto cached_context = cached_context_it->second.lock();
+  ORT_ENFORCE(cached_context);
+
+  {
+    std::lock_guard<OrtMutex> lock(context_state_.mutex);
+    context_state_.active_contexts.erase(cached_context);
+    context_state_.retired_context_pool.push_back(cached_context);
+  }
+
+  per_thread_context_cache->erase(cached_context_it);
+}
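
GetPerThreadContext/ReleasePerThreadContext above implement a per-thread cache keyed by the provider instance: each thread holds weak references, while the provider owns the contexts and recycles retired ones. A standalone, standard-library-only sketch of the pattern (simplified names, no HTP power config handling) is:

#include <memory>
#include <mutex>
#include <set>
#include <unordered_map>
#include <vector>

struct ThreadContext { int power_config_id = 0; };

class Provider {
 public:
  ThreadContext& GetPerThreadContext() {
    auto& cache = Cache();
    if (auto it = cache.find(this); it != cache.end()) {
      if (auto ctx = it->second.lock()) return *ctx;
    }
    std::shared_ptr<ThreadContext> ctx;
    {
      std::lock_guard<std::mutex> lock(mutex_);
      if (retired_.empty()) {
        ctx = std::make_shared<ThreadContext>();
      } else {
        ctx = retired_.back();     // reuse a context released by another thread
        retired_.pop_back();
      }
      active_.insert(ctx);
    }
    cache[this] = ctx;             // thread-local weak reference
    return *ctx;
  }

  void ReleasePerThreadContext() {
    auto& cache = Cache();
    auto it = cache.find(this);
    if (it == cache.end()) return;
    if (auto ctx = it->second.lock()) {
      std::lock_guard<std::mutex> lock(mutex_);
      active_.erase(ctx);
      retired_.push_back(ctx);
    }
    cache.erase(it);
  }

 private:
  using CacheMap = std::unordered_map<const Provider*, std::weak_ptr<ThreadContext>>;
  static CacheMap& Cache() {
    thread_local CacheMap cache;   // one cache per thread
    return cache;
  }

  std::mutex mutex_;  // guards active_/retired_; an individual ThreadContext is used by one thread at a time
  std::set<std::shared_ptr<ThreadContext>> active_;
  std::vector<std::shared_ptr<ThreadContext>> retired_;
};

The mutex only protects the shared bookkeeping; access to a single context needs no locking because a context is handed to exactly one thread at a time, mirroring the comment in the header below.
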
+
+Status QNNExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) {
+  auto backend_type = qnn_backend_manager_->GetQnnBackendType();
+  if (qnn::QnnBackendType::HTP != backend_type && qnn::QnnBackendType::DSP != backend_type) {
+    return Status::OK();
+  }
+
+  std::string htp_perf_mode = "";
+  qnn::HtpPerformanceMode htp_performance_mode = qnn::HtpPerformanceMode::kHtpDefault;
+  if (run_options.config_options.TryGetConfigEntry(kOrtRunOptionsConfigQnnPerfMode, htp_perf_mode)) {
+    // set power mode
+    ParseHtpPerformanceMode(htp_perf_mode, htp_performance_mode);
+  }
+
+  std::string rpc_latency = "";
+  uint32_t rpc_control_latency = 0;
+  if (run_options.config_options.TryGetConfigEntry(kOrtRunOptionsConfigQnnRpcControlLatency, rpc_latency)) {
+    rpc_control_latency = static_cast<uint32_t>(std::stoul(rpc_latency));
+    LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << rpc_control_latency;
+  }
+
+  if (GetPerThreadContext().IsHtpPowerConfigIdValid()) {
+    if (qnn::HtpPerformanceMode::kHtpDefault != htp_performance_mode) {
+      ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetHtpPowerConfig(GetPerThreadContext().GetHtpPowerConfigId(),
+                                                                  htp_performance_mode));
+    }
+
+    if (rpc_control_latency > 0) {
+      ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetRpcControlLatency(GetPerThreadContext().GetHtpPowerConfigId(),
+                                                                     rpc_control_latency));
+    }
+  }
+
+  return Status::OK();
+}
+
+Status QNNExecutionProvider::OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& run_options) {
+  auto backend_type = qnn_backend_manager_->GetQnnBackendType();
+  if (qnn::QnnBackendType::HTP != backend_type && qnn::QnnBackendType::DSP != backend_type) {
+    return Status::OK();
+  }
+
+  std::string htp_perf_mode = "";
+  qnn::HtpPerformanceMode htp_performance_mode = qnn::HtpPerformanceMode::kHtpDefault;
+  if (run_options.config_options.TryGetConfigEntry(kOrtRunOptionsConfigQnnPerfModePostRun, htp_perf_mode)) {
+    // set power mode
+    ParseHtpPerformanceMode(htp_perf_mode, htp_performance_mode);
+  }
+
+  if (qnn::HtpPerformanceMode::kHtpDefault != htp_performance_mode) {
+    if (!GetPerThreadContext().IsHtpPowerConfigIdValid()) {
+      return Status::OK();
+    }
+    ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetHtpPowerConfig(GetPerThreadContext().GetHtpPowerConfigId(),
+                                                                htp_performance_mode));
+  }
+
+  return Status::OK();
+}
 }  // namespace onnxruntime
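
OnRunStart above converts the rpc-control-latency run option with std::stoul, which throws on malformed input. Purely as an illustration (not part of this change), a non-throwing conversion can be written with std::from_chars:

#include <charconv>
#include <cstdint>
#include <string>
#include <system_error>

bool TryParseLatency(const std::string& value, uint32_t& latency_out) {
  uint32_t parsed = 0;
  const char* first = value.data();
  const char* last = value.data() + value.size();
  auto [ptr, ec] = std::from_chars(first, last, parsed);
  if (ec != std::errc() || ptr != last) {
    return false;  // not a valid unsigned integer
  }
  latency_out = parsed;
  return true;
}
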
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h
index 09bcb24db4dc2..82dceb8ae3973 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h
@@ -12,14 +12,19 @@
 #include "core/providers/qnn/builder/qnn_model.h"
 #include "core/providers/qnn/builder/qnn_configs_helper.h"
 #include "HTP/QnnHtpGraph.h"
+#include <vector>
+#include <set>
+#include <unordered_map>
 
 namespace onnxruntime {
 
+void RunOnUnload(std::function<void()> function);
+
 // Logical device representation.
 class QNNExecutionProvider : public IExecutionProvider {
  public:
   explicit QNNExecutionProvider(const ProviderOptions& provider_options_map, const SessionOptions* session_options);
-  virtual ~QNNExecutionProvider() = default;
+  virtual ~QNNExecutionProvider();
   ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QNNExecutionProvider);
 
   // we implement the Compile that takes FusedNodeAndGraph instances
@@ -40,9 +45,12 @@ class QNNExecutionProvider : public IExecutionProvider {
 
   const InlinedVector<const Node*> GetEpContextNodes() const override;
 
+  Status OnRunStart(const onnxruntime::RunOptions& run_options) override;
+
+  Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override;
+
  private:
   bool IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit,
-                       std::unordered_map<const NodeUnit*, bool>& node_unit_supported_result,
                        const logging::Logger& logger) const;
 
   std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewer,
@@ -73,6 +81,69 @@ class QNNExecutionProvider : public IExecutionProvider {
   int32_t vtcm_size_in_mb_ = 0;
   std::unique_ptr<onnxruntime::Model> qnn_ep_context_model_;
   ModelMetadefIdGenerator metadef_id_generator_;
+  uint32_t device_id_ = 0;
+  qnn::HtpPerformanceMode default_htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault;
+  uint32_t default_rpc_control_latency_ = 0;
+  bool enable_HTP_FP16_precision_ = false;
+
+  class PerThreadContext final {
+   public:
+    PerThreadContext(qnn::QnnBackendManager* qnn_backend_manager,
+                     uint32_t device_id, uint32_t core_id,
+                     qnn::HtpPerformanceMode default_htp_performance_mode,
+                     uint32_t default_rpc_control_latency);
+    ~PerThreadContext();
+    ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PerThreadContext);
+
+    bool IsHtpPowerConfigIdValid() { return is_htp_power_config_id_valid_; }
+
+    uint32_t GetHtpPowerConfigId() { return htp_power_config_id_; }
+
+   private:
+    bool is_htp_power_config_id_valid_ = false;
+    uint32_t htp_power_config_id_ = 0;
+    qnn::QnnBackendManager* qnn_backend_manager_;
+  };
+
+  using PerThreadContextMap = std::unordered_map<const QNNExecutionProvider*, std::weak_ptr<PerThreadContext>>;
+
+  struct ContextCacheHolder {
+    ContextCacheHolder() {
+      RunOnUnload([&, weak_p_ = std::weak_ptr<PerThreadContextMap>(p)] {
+        if (auto lock = weak_p_.lock())
+          p.reset();
+      });
+    }
+
+    std::shared_ptr<PerThreadContextMap> p = std::make_shared<PerThreadContextMap>();
+  };
+
+  static const std::shared_ptr<PerThreadContextMap>& PerThreadContextCache() {
+    thread_local const ContextCacheHolder per_thread_context_cache;
+    return per_thread_context_cache.p;
+  }
+
+  struct PerThreadContextState {
+    // contexts that are currently active
+    std::set<std::shared_ptr<PerThreadContext>, std::owner_less<std::shared_ptr<PerThreadContext>>> active_contexts;
+    // contexts available for reuse
+    std::vector<std::shared_ptr<PerThreadContext>> retired_context_pool;
+    // weak references to thread local caches from which this QNNExecutionProvider instance's entry should be removed
+    // upon destruction
+    std::set<std::weak_ptr<PerThreadContextMap>, std::owner_less<std::weak_ptr<PerThreadContextMap>>>
+        caches_to_update_on_destruction;
+    // synchronizes access to PerThreadContextState members
+    OrtMutex mutex;
+  };
+
+  // The execution provider maintains the PerThreadContexts in this structure.
+  // Synchronization is required to update the contained structures.
+  // On the other hand, access to an individual PerThreadContext is assumed to be from a single thread at a time,
+  // so synchronization is not required for that.
+  mutable PerThreadContextState context_state_;
+
+  PerThreadContext& GetPerThreadContext() const;
+  void ReleasePerThreadContext() const;
 };
 
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/rocm/cu_inc/common.cuh b/onnxruntime/core/providers/rocm/cu_inc/common.cuh
index 5f966ac746fcb..1698e5ca8478c 100644
--- a/onnxruntime/core/providers/rocm/cu_inc/common.cuh
+++ b/onnxruntime/core/providers/rocm/cu_inc/common.cuh
@@ -335,6 +335,157 @@ __device__ __inline__ BFloat16 _Fmod(BFloat16 a, BFloat16 b) {
   return fmodf((float)a, (float)b);
 }
 
+namespace isinf_details {
+template <typename T>
+struct IsInfTyped {
+  static __device__ __inline__ bool IsInf(T a) {
+    // A cast is needed because on non-MSVC compilers isinf() returns int,
+    // and we want to avoid implicit-conversion warnings.
+    return static_cast<bool>(isinf(a));
+  }
+  static __device__ __inline__ bool IsInfPos(T a) {
+    return a == std::numeric_limits<T>::infinity();
+  }
+  static __device__ __inline__ bool IsInfNeg(T a) {
+    return a == -std::numeric_limits<T>::infinity();
+  }
+};
+
+template <>
+struct IsInfTyped<half> {
+  static __device__ __inline__ bool IsInf(half a) {
+    return MLFloat16::kPositiveInfinityBits ==
+           static_cast<uint16_t>(*reinterpret_cast<uint16_t*>(&a) & ~MLFloat16::kSignMask);
+  }
+  static __device__ __inline__ bool IsInfPos(half a) {
+    return MLFloat16::kPositiveInfinityBits == *reinterpret_cast<uint16_t*>(&a);
+  }
+  static __device__ __inline__ bool IsInfNeg(half a) {
+    return MLFloat16::kNegativeInfinityBits == *reinterpret_cast<uint16_t*>(&a);
+  }
+};
+
+template <>
+struct IsInfTyped<BFloat16> {
+  static __device__ __inline__ bool IsInf(BFloat16 a) {
+    return BFloat16::kPositiveInfinityBits ==
+           static_cast<uint16_t>(*reinterpret_cast<uint16_t*>(&a) & ~BFloat16::kSignMask);
+  }
+  static __device__ __inline__ bool IsInfPos(BFloat16 a) {
+    return BFloat16::kPositiveInfinityBits == *reinterpret_cast<uint16_t*>(&a);
+  }
+  static __device__ __inline__ bool IsInfNeg(BFloat16 a) {
+    return BFloat16::kNegativeInfinityBits == *reinterpret_cast<uint16_t*>(&a);
+  }
+};
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+
+template <typename T>
+struct ReturnFalse {
+  constexpr static bool __device__ __inline__ IsInf(T) { return false; }
+  constexpr static bool __device__ __inline__ IsInfPos(T) { return false; }
+  constexpr static bool __device__ __inline__ IsInfNeg(T) { return false; }
+};
+
+template <>
+struct IsInfTyped<Float8E4M3FN> : ReturnFalse<Float8E4M3FN> {};
+
+template <>
+struct IsInfTyped<Float8E4M3FNUZ> : ReturnFalse<Float8E4M3FNUZ> {};
+
+template <>
+struct IsInfTyped<Float8E5M2> {
+  static __device__ __inline__ bool IsInf(Float8E5M2 a) {
+    return a.val == 0b01111100 || a.val == 0b11111100;
+  }
+  static __device__ __inline__ bool IsInfPos(Float8E5M2 a) {
+    return a.val == 0b01111100;
+  }
+  static __device__ __inline__ bool IsInfNeg(Float8E5M2 a) {
+    return a.val == 0b11111100;
+  }
+};
+
+template <>
+struct IsInfTyped<Float8E5M2FNUZ> : ReturnFalse<Float8E5M2FNUZ> {};
+
+#endif
+}  // namespace isinf_details
+
+template <typename T, bool detect_positive, bool detect_negative>
+struct _IsInf {
+  __device__ __inline__ bool operator()(T a) const {
+    if constexpr (detect_positive && detect_negative) {
+      return isinf_details::IsInfTyped<T>::IsInf(a);
+    } else if constexpr (detect_positive) {
+      return isinf_details::IsInfTyped<T>::IsInfPos(a);
+    } else if constexpr (detect_negative) {
+      return isinf_details::IsInfTyped<T>::IsInfNeg(a);
+    } else {
+      return false;
+    }
+  }
+};
+
+// float and double
+template <typename T>
+struct _IsNan {
+  __device__ __inline__ bool operator()(T a) const {
+    return isnan(a);
+  }
+};
+
+template <>
+struct _IsNan<half> {
+  __device__ __inline__ bool operator()(half a) const {
+    return static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&a) & ~MLFloat16::kSignMask) >
+           MLFloat16::kPositiveInfinityBits;
+  }
+};
+
+template <>
+struct _IsNan<BFloat16> {
+  __device__ __inline__ bool operator()(BFloat16 a) const {
+    return static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&a) & ~BFloat16::kSignMask) >
+           BFloat16::kPositiveInfinityBits;
+  }
+};
+
+#if !defined(DISABLE_FLOAT8_TYPES)
+
+template <>
+struct _IsNan<Float8E4M3FN> {
+  __device__ __inline__ bool operator()(Float8E4M3FN a) const {
+    return (*reinterpret_cast<const uint8_t*>(&a) & 0x7f) == 0x7f;
+  }
+};
+
+template <>
+struct _IsNan<Float8E4M3FNUZ> {
+  __device__ __inline__ bool operator()(Float8E4M3FNUZ a) const {
+    return *reinterpret_cast<const uint8_t*>(&a) == 0x80;
+  }
+};
+
+template <>
+struct _IsNan<Float8E5M2> {
+  __device__ __inline__ bool operator()(Float8E5M2 a) const {
+    uint8_t c = *reinterpret_cast<const uint8_t*>(&a);
+    return ((c & 0x7c) == 0x7c) && ((c & 0x03) != 0x00);
+  }
+};
+
+template <>
+struct _IsNan<Float8E5M2FNUZ> {
+  __device__ __inline__ bool operator()(Float8E5M2FNUZ a) const {
+    return *reinterpret_cast<const uint8_t*>(&a) == 0x80;
+  }
+};
+
+#endif
+
 // We would like to use 64-bit integer to support large matrices. However, ROCM seems to support only 32-bit integer
 // For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type.
 #ifndef HIP_LONG
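
The IsInf/IsNaN specializations above classify half and BFloat16 values purely from their bit patterns: clear the sign bit, then compare against the positive-infinity pattern (equal means infinity, greater means NaN). A host-side C++ sketch of the same test for IEEE binary16 is shown below; the constants are the standard binary16 bit patterns, and std::memcpy replaces the reinterpret_cast that is customary in kernel code.

#include <cstdint>
#include <cstring>

constexpr uint16_t kFp16SignMask = 0x8000;
constexpr uint16_t kFp16PositiveInfinityBits = 0x7C00;

inline uint16_t HalfBits(const void* half_ptr) {
  uint16_t bits;
  std::memcpy(&bits, half_ptr, sizeof(bits));  // avoids aliasing issues on the host
  return bits;
}

inline bool HalfIsInf(const void* half_ptr) {
  return static_cast<uint16_t>(HalfBits(half_ptr) & ~kFp16SignMask) == kFp16PositiveInfinityBits;
}

inline bool HalfIsNaN(const void* half_ptr) {
  // NaN: exponent all ones with a non-zero mantissa, i.e. magnitude above +inf's bit pattern.
  return static_cast<uint16_t>(HalfBits(half_ptr) & ~kFp16SignMask) > kFp16PositiveInfinityBits;
}
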
diff --git a/onnxruntime/core/providers/rocm/nn/pool.cc b/onnxruntime/core/providers/rocm/nn/pool.cc
index 045c8b55c0b0d..3a82ab598004b 100644
--- a/onnxruntime/core/providers/rocm/nn/pool.cc
+++ b/onnxruntime/core/providers/rocm/nn/pool.cc
@@ -257,7 +257,7 @@ Status Pool<T, MaxPool<8>>::ComputeInternal(OpKernelContext* context) const {
   Tensor* I = context->Output(1, TensorShape(y_dims));
   if (nullptr != I || !this->pool_attrs_.default_dilations) {
     auto i_data = nullptr == I ? nullptr : I->MutableData<int64_t>();
-    MaxPoolWithIndex<HipT>(
+    MaxPoolWithIndex<HipT, false>(
         this->Stream(context),
         x_shape,
         TensorShape(y_dims),
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
index ee3578326ac6d..87daaeea969ac 100644
--- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
+++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
@@ -183,23 +183,24 @@ bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const {
   return regular_run_count_before_graph_capture_ >= min_num_runs_before_hip_graph_capture_;
 }
 
-void ROCMExecutionProvider::PerThreadContext::CaptureBegin() {
+void ROCMExecutionProvider::PerThreadContext::CaptureBegin(int) {
   hip_graph_.Reset();
-  hip_graph_.CaptureBegin();
+  hip_graph_.CaptureBegin(0);
 }
 
-void ROCMExecutionProvider::PerThreadContext::CaptureEnd() {
-  hip_graph_.CaptureEnd();
+void ROCMExecutionProvider::PerThreadContext::CaptureEnd(int) {
+  hip_graph_.CaptureEnd(0);
   is_graph_captured_ = true;
 }
 
-bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptured() const {
+bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptured(int) const {
   return is_graph_captured_;
 }
 
-Status ROCMExecutionProvider::PerThreadContext::ReplayGraph() {
-  ORT_ENFORCE(IsGraphCaptured());
-  return hip_graph_.Replay();
+Status ROCMExecutionProvider::PerThreadContext::ReplayGraph(int graph_annotation_id) {
+  ORT_ENFORCE(IsGraphCaptured(graph_annotation_id));
+
+  return hip_graph_.Replay(graph_annotation_id);
 }
 
 void ROCMExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() {
@@ -353,23 +354,23 @@ Status ROCMExecutionProvider::Sync() const {
   return Status::OK();
 }
 
-Status ROCMExecutionProvider::OnRunStart() {
+Status ROCMExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) {
   // always set ROCM device when session::Run() in case it runs in a worker thread
   HIP_RETURN_IF_ERROR(hipSetDevice(GetDeviceId()));
-  if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) {
+  if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured(0)) {
     LOGS_DEFAULT(INFO) << "Capturing the hip graph for this model";
-    GetPerThreadContext().CaptureBegin();
+    GetPerThreadContext().CaptureBegin(0);
   }
   return Status::OK();
 }
 
-Status ROCMExecutionProvider::OnRunEnd(bool sync_stream) {
-  if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) {
+Status ROCMExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) {
+  if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured(0)) {
     if (GetPerThreadContext().IsGraphCaptureAllowed()) {
-      GetPerThreadContext().CaptureEnd();
+      GetPerThreadContext().CaptureEnd(0);
       // HIP work issued to a capturing stream doesn’t actually run on the GPU,
       // so run the captured graph here to actually execute the work.
-      ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph());
+      ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(0));
     } else {
       GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture();
     }
@@ -400,12 +401,12 @@ bool ROCMExecutionProvider::IsGraphCaptureEnabled() const {
   return info_.enable_hip_graph;
 }
 
-bool ROCMExecutionProvider::IsGraphCaptured() const {
-  return GetPerThreadContext().IsGraphCaptured();
+bool ROCMExecutionProvider::IsGraphCaptured(int) const {
+  return GetPerThreadContext().IsGraphCaptured(0);
 }
 
-Status ROCMExecutionProvider::ReplayGraph() {
-  return GetPerThreadContext().ReplayGraph();
+Status ROCMExecutionProvider::ReplayGraph(int /*graph_annotation_id*/) {
+  return GetPerThreadContext().ReplayGraph(0);
 }
 
 namespace rocm {
@@ -733,6 +734,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, float, Shrink);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, double, Shrink);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, MLFloat16, Shrink);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, IsNaN);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, Less);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, Less);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Less);
@@ -793,6 +795,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, MLFloat16, ThresholdedRelu);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, TopK);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, Mod);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 19, IsInf);
 
 // opset 11
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, float, ArgMax);
@@ -1065,6 +1068,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint32_t, Cast);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint64_t, Cast);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, bool, Cast);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 19, IsNaN);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, Reshape);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Shape);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Size);
@@ -1145,11 +1149,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceSumSquare);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, GatherND);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Dropout);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Resize);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Resize);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Resize);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Resize);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Resize);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, float, Resize);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, double, Resize);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Resize);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, int32_t, Resize);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, uint8_t, Resize);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, If);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, Loop);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Flatten);
@@ -1304,6 +1308,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterElements);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, Resize);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Resize);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Resize);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int32_t, Resize);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, uint8_t, Resize);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split);
 
 // Opset 19
@@ -1337,6 +1346,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, R
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Scan);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Shape);
 
+// Opset 20
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, IsInf);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, IsNaN);
+
 template <>
 KernelCreateInfo BuildKernelCreateInfo<void>() {
   return {};
@@ -1521,6 +1534,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, float, Erf)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, double, Erf)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, Erf)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, IsNaN)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, bool, Not)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, BatchNormalization)>,
     // BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, BatchNormalization)>,
@@ -1733,6 +1747,8 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, int8_t, DequantizeLinear)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, uint8_t, DequantizeLinear)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, Mod)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
+                                                                    10, 19, IsInf)>,
 
     // opset 11
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, float, ArgMax)>,
@@ -1929,6 +1945,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Abs)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Abs)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Abs)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 19, IsNaN)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Neg)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int16_t, Neg)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Neg)>,
@@ -2081,11 +2098,16 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceSumSquare)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, GatherND)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Dropout)>,
-    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Resize)>,
-    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Resize)>,
-    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Resize)>,
-    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Resize)>,
-    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Resize)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17,
+                                                                          float, Resize)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17,
+                                                                          double, Resize)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17,
+                                                                          MLFloat16, Resize)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17,
+                                                                          int32_t, Resize)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17,
+                                                                          uint8_t, Resize)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, If)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, Loop)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Flatten)>,
@@ -2240,6 +2262,16 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterElements)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18,
+                                                                float, Resize)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18,
+                                                                double, Resize)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18,
+                                                                MLFloat16, Resize)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18,
+                                                                int32_t, Resize)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18,
+                                                                uint8_t, Resize)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split)>,
 
     // Opset 19
@@ -2274,6 +2306,10 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
     BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Reshape)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Scan)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Shape)>,
+
+    // opset 20
+    BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, IsInf)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, IsNaN)>,
   };
 
   for (auto& function_table_entry : function_table) {
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h
index 37d5f7b42210f..6d6c05027e7bd 100644
--- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h
+++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h
@@ -28,9 +28,9 @@ class ROCMExecutionProvider : public IExecutionProvider {
 
   Status Sync() const override;
 
-  Status OnRunStart() override;
+  Status OnRunStart(const onnxruntime::RunOptions& run_options) override;
 
-  Status OnRunEnd(bool sync_stream) override;
+  Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override;
 
   const void* GetExecutionHandle() const noexcept override {
     // The ROCM interface does not return anything interesting.
@@ -75,8 +75,8 @@ class ROCMExecutionProvider : public IExecutionProvider {
   std::unique_ptr<profiling::EpProfiler> GetProfiler() override;
 
   bool IsGraphCaptureEnabled() const override;
-  bool IsGraphCaptured() const override;
-  Status ReplayGraph() override;
+  bool IsGraphCaptured(int graph_annotation_id) const override;
+  Status ReplayGraph(int graph_annotation_id) override;
   void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override;
   OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;
   std::vector<AllocatorPtr> CreatePreferredAllocators() override;
@@ -139,10 +139,10 @@ class ROCMExecutionProvider : public IExecutionProvider {
     }
 
     bool IsGraphCaptureAllowed() const;
-    void CaptureBegin();
-    void CaptureEnd();
-    bool IsGraphCaptured() const;
-    Status ReplayGraph();
+    void CaptureBegin(int graph_annotation_id);
+    void CaptureEnd(int graph_annotation_id);
+    bool IsGraphCaptured(int graph_annotation_id) const;
+    Status ReplayGraph(int graph_annotation_id);
     void IncrementRegularRunCountBeforeGraphCapture();
 
    private:
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc
index b557f92287f2b..3cb826437a54f 100644
--- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc
+++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc
@@ -13,6 +13,8 @@ namespace onnxruntime {
 namespace rocm {
 namespace provider_option_names {
 constexpr const char* kDeviceId = "device_id";
+constexpr const char* kHasUserComputeStream = "has_user_compute_stream";
+constexpr const char* kUserComputeStream = "user_compute_stream";
 constexpr const char* kMemLimit = "gpu_mem_limit";
 constexpr const char* kArenaExtendStrategy = "arena_extend_strategy";
 constexpr const char* kMiopenConvExhaustiveSearch = "miopen_conv_exhaustive_search";
@@ -38,6 +40,7 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P
   void* alloc = nullptr;
   void* free = nullptr;
   void* empty_cache = nullptr;
+  void* user_compute_stream = nullptr;
   ORT_THROW_IF_ERROR(
       ProviderOptionsParser{}
           .AddValueParser(
@@ -52,6 +55,15 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P
                     ", must be between 0 (inclusive) and ", num_devices, " (exclusive).");
                 return Status::OK();
               })
+          .AddAssignmentToReference(rocm::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream)
+          .AddValueParser(
+              rocm::provider_option_names::kUserComputeStream,
+              [&user_compute_stream](const std::string& value_str) -> Status {
+                size_t address;
+                ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
+                user_compute_stream = reinterpret_cast<void*>(address);
+                return Status::OK();
+              })
           .AddValueParser(
               rocm::provider_option_names::kGpuExternalAlloc,
               [&alloc](const std::string& value_str) -> Status {
@@ -108,12 +120,18 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P
 
   ROCMExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache};
   info.external_allocator_info = alloc_info;
+
+  info.user_compute_stream = user_compute_stream;
+  info.has_user_compute_stream = (user_compute_stream != nullptr);
+
   return info;
 }
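
The user_compute_stream option added above arrives as a string holding the numeric address of the stream and is converted back to a pointer. A simplified, standard-library-only sketch of that conversion (not the ProviderOptionsParser itself) is:

#include <cstddef>
#include <sstream>
#include <string>

bool TryParsePointerOption(const std::string& value, void*& pointer_out) {
  std::istringstream stream(value);
  std::size_t address = 0;
  if (!(stream >> address)) {
    return false;  // not a parsable address
  }
  pointer_out = reinterpret_cast<void*>(address);
  return true;
}
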
 
 ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecutionProviderInfo& info) {
   const ProviderOptions options{
       {rocm::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
+      {rocm::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)},
+      {rocm::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.user_compute_stream))},
       {rocm::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)},
       {rocm::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.alloc))},
       {rocm::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.free))},
@@ -135,6 +153,8 @@ ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecution
 ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const OrtROCMProviderOptions& info) {
   const ProviderOptions options{
       {rocm::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
+      {rocm::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)},
+      {rocm::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.user_compute_stream))},
       {rocm::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)},
       {rocm::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, static_cast<onnxruntime::ArenaExtendStrategy>(info.arena_extend_strategy))},
       {rocm::provider_option_names::kMiopenConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)},
diff --git a/onnxruntime/core/providers/rocm/rocm_stream_handle.h b/onnxruntime/core/providers/rocm/rocm_stream_handle.h
index 1f3e5b75548e7..30983ce03568f 100644
--- a/onnxruntime/core/providers/rocm/rocm_stream_handle.h
+++ b/onnxruntime/core/providers/rocm/rocm_stream_handle.h
@@ -8,6 +8,7 @@
 #include "core/framework/stream_handles.h"
 
 namespace onnxruntime {
+void WaitRocmNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
 
 struct RocmStream : Stream {
   RocmStream(hipStream_t stream,
@@ -36,6 +37,8 @@ struct RocmStream : Stream {
 
   void* GetResource(int version, int id) const override;
 
+  WaitNotificationFn GetWaitNotificationFn() const override { return WaitRocmNotificationOnDevice; }
+
  private:
   std::vector<void*> deferred_cpu_buffers_;
   AllocatorPtr cpu_allocator_;
@@ -50,5 +53,4 @@ void RegisterRocmStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis
                                bool use_existing_stream,
                                miopenHandle_t external_miopen_handle,
                                rocblas_handle external_rocblas_handle);
-void WaitRocmNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/shared/node_unit/node_unit.cc b/onnxruntime/core/providers/shared/node_unit/node_unit.cc
deleted file mode 100644
index 10dd58ba28375..0000000000000
--- a/onnxruntime/core/providers/shared/node_unit/node_unit.cc
+++ /dev/null
@@ -1,319 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "node_unit.h"
-#include "core/graph/graph_viewer.h"
-#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
-#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
-
-namespace onnxruntime {
-
-namespace {
-
-enum class QLinearOpType : uint8_t {
-  Unknown,  // Unknown or not a linear quantized op
-  DequantizeLinear,
-  QuantizeLinear,
-  QLinearConv,
-  QLinearMatMul,
-  QLinearAdd,
-  QLinearSigmoid,
-  QLinearAveragePool,
-  QLinearMul,
-  QLinearReduceMean,
-  QLinearConcat,
-  QLinearGlobalAveragePool,
-  QLinearLeakyRelu,
-};
-
-QLinearOpType GetQLinearOpType(const onnxruntime::Node& node) {
-  const auto& op_type = node.OpType();
-  if (op_type == "DequantizeLinear")
-    return QLinearOpType::DequantizeLinear;
-  else if (op_type == "QuantizeLinear")
-    return QLinearOpType::QuantizeLinear;
-  else if (op_type == "QLinearConv")
-    return QLinearOpType::QLinearConv;
-  else if (op_type == "QLinearMatMul")
-    return QLinearOpType::QLinearMatMul;
-  else if (op_type == "QLinearAdd")
-    return QLinearOpType::QLinearAdd;
-  else if (op_type == "QLinearSigmoid")
-    return QLinearOpType::QLinearSigmoid;
-  else if (op_type == "QLinearAveragePool")
-    return QLinearOpType::QLinearAveragePool;
-  else if (op_type == "QLinearMul")
-    return QLinearOpType::QLinearMul;
-  else if (op_type == "QLinearReduceMean")
-    return QLinearOpType::QLinearReduceMean;
-  else if (op_type == "QLinearConcat")
-    return QLinearOpType::QLinearConcat;
-  else if (op_type == "QLinearGlobalAveragePool")
-    return QLinearOpType::QLinearGlobalAveragePool;
-  else if (op_type == "QLinearLeakyRelu")
-    return QLinearOpType::QLinearLeakyRelu;
-
-  return QLinearOpType::Unknown;
-}
-
-// Ops have 1 input
-bool IsUnaryQLinearOp(QLinearOpType type) {
-  return type == QLinearOpType::QLinearSigmoid ||
-         type == QLinearOpType::QLinearAveragePool ||
-         type == QLinearOpType::QLinearGlobalAveragePool ||
-         type == QLinearOpType::QLinearLeakyRelu ||
-         type == QLinearOpType::QLinearReduceMean;
-}
-
-// Ops have 2 inputs
-bool IsBinaryQLinearOp(QLinearOpType type) {
-  return type == QLinearOpType::QLinearConv ||
-         type == QLinearOpType::QLinearMatMul ||
-         type == QLinearOpType::QLinearAdd ||
-         type == QLinearOpType::QLinearMul;
-}
-
-// Ops have 1 or more inputs
-bool IsVariadicQLinearOp(QLinearOpType type) {
-  return type == QLinearOpType::QLinearConcat;
-}
-
-const std::vector<const Node*> GetQDQIONodes(const GraphViewer& graph_viewer,
-                                             const QDQ::NodeGroup& node_group, bool is_input) {
-  std::vector<const Node*> io_nodes;
-  const auto& src_nodes = is_input ? node_group.dq_nodes : node_group.q_nodes;
-  io_nodes.reserve(src_nodes.size());
-  for (const auto& node_idx : src_nodes) {
-    io_nodes.push_back(graph_viewer.GetNode(node_idx));
-  }
-  return io_nodes;
-}
-
-// Get the input or output NodeUnitIODef(s) for the given QDQ NodeGroup
-std::vector<NodeUnitIODef> GetQDQIODefs(const Node& target_node, const QDQ::NodeGroup& node_group,
-                                        bool is_input) {
-  const auto& dq_or_q_nodes = is_input ? node_group.dq_nodes : node_group.q_nodes;
-  const auto target_node_io_defs = is_input ? target_node.InputDefs() : target_node.OutputDefs();
-  const size_t target_node_io_defs_size = target_node_io_defs.size();
-
-  // Find all the quantized IO defs and indices (for the input to the target node)
-  std::unordered_map<size_t, NodeUnitIODef> quantized_io_defs;
-  quantized_io_defs.reserve(target_node_io_defs_size);
-
-  auto cur = is_input ? target_node.InputEdgesBegin() : target_node.OutputEdgesBegin();
-  auto end = is_input ? target_node.InputEdgesEnd() : target_node.OutputEdgesEnd();
-  for (; cur != end; ++cur) {
-    const Node& node = cur->GetNode();
-
-    // If we can find the node index in the dq or q nodes, then this is a quantize node (can be DQ or Q depends on is_input)
-    if (std::find(dq_or_q_nodes.cbegin(), dq_or_q_nodes.cend(), node.Index()) != dq_or_q_nodes.cend()) {
-      const auto node_inputs = node.InputDefs();
-      // quantization scale and zp are always the input[1, 2]
-      NodeUnitIODef::QuantParam quant_param{
-          *node_inputs[1],
-          node_inputs.size() == 3 ? node_inputs[2] : nullptr};
-      if (is_input) {
-        // DQ is input to the target node, use the DstArgIndex
-        auto idx = cur->GetDstArgIndex();
-        // This is a DQ node, we are using x, x_scale, x_zp (input[0, 1, 2])
-        quantized_io_defs.insert({idx, NodeUnitIODef{*node_inputs[0], quant_param}});
-      } else {
-        // Q is output of the target node, use the SrcArgIndex
-        auto idx = cur->GetSrcArgIndex();
-        // This is a Q node, we are using y (output[0]), y_scale, y_zp (input[1, 2])
-        const auto node_outputs = node.OutputDefs();
-        quantized_io_defs.insert({idx, NodeUnitIODef{*node_outputs[0], quant_param}});
-      }
-    }
-  }
-
-  // Construct the IODefs for this QDQ NodeGroup
-  std::vector<NodeUnitIODef> io_defs;
-  io_defs.reserve(target_node_io_defs_size);
-  for (size_t i = 0; i < target_node_io_defs_size; i++) {
-    // If we can find the NodeUnitIODef for this index, this is a quantized input
-    if (quantized_io_defs.find(i) != quantized_io_defs.cend()) {
-      io_defs.push_back(std::move(quantized_io_defs.at(i)));
-    } else {
-      // This is a regular input
-      io_defs.push_back({*target_node_io_defs[i], std::nullopt});
-    }
-  }
-
-  return io_defs;
-}
-
-}  // namespace
-
-NodeUnit::NodeUnit(const Node& node)
-    : target_node_(node),
-      type_(Type::SingleNode) {
-  InitForSingleNode();
-}
-
-NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group)
-    : q_nodes_{GetQDQIONodes(graph_viewer, node_group, false /* is_input */)},
-      dq_nodes_{GetQDQIONodes(graph_viewer, node_group, true /* is_input */)},
-      target_node_(*graph_viewer.GetNode(node_group.target_node)),
-      type_(Type::QDQGroup),
-      inputs_{GetQDQIODefs(target_node_, node_group, true /* is_input */)},
-      outputs_{GetQDQIODefs(target_node_, node_group, false /* is_input */)} {
-  ORT_THROW_IF_ERROR(QDQ::ValidateNodeGroupDQNodes(graph_viewer, target_node_, dq_nodes_));
-}
-
-const std::string& NodeUnit::Domain() const noexcept { return target_node_.Domain(); }
-const std::string& NodeUnit::OpType() const noexcept { return target_node_.OpType(); }
-const std::string& NodeUnit::Name() const noexcept { return target_node_.Name(); }
-int NodeUnit::SinceVersion() const noexcept { return target_node_.SinceVersion(); }
-NodeIndex NodeUnit::Index() const noexcept { return target_node_.Index(); }
-const Path& NodeUnit::ModelPath() const noexcept { return target_node_.ModelPath(); }
-ProviderType NodeUnit::GetExecutionProviderType() const noexcept { return target_node_.GetExecutionProviderType(); }
-
-void NodeUnit::InitForSingleNode() {
-  const auto& input_defs = target_node_.InputDefs();
-  const auto& output_defs = target_node_.OutputDefs();
-  auto qlinear_type = GetQLinearOpType(target_node_);
-  if (qlinear_type == QLinearOpType::Unknown ||
-      IsVariadicQLinearOp(qlinear_type)) {  // TODO, add variadic support
-    // Not a Qlinear op, add all inputs / outputs
-    auto add_all_io = [](std::vector<NodeUnitIODef>& defs,
-                         const ConstPointerContainer<std::vector<NodeArg*>>& node_defs) {
-      defs.reserve(node_defs.size());
-
-      for (const auto def : node_defs) {
-        defs.push_back(NodeUnitIODef{*def, std::nullopt});
-      }
-    };
-    add_all_io(inputs_, input_defs);
-    add_all_io(outputs_, output_defs);
-  } else if (IsUnaryQLinearOp(qlinear_type)) {
-    // Unary QLinear Op has 5 inputs
-    // x, x_scale, x_zp, y_scale, y_zp (optional)
-    inputs_.push_back(NodeUnitIODef{
-        *input_defs[0],
-        NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}});
-
-    outputs_.push_back(NodeUnitIODef{
-        *output_defs[0],
-        NodeUnitIODef::QuantParam{*input_defs[3],
-                                  input_defs.size() > 4
-                                      ? input_defs[4]
-                                      : nullptr}});
-  } else if (IsBinaryQLinearOp(qlinear_type)) {
-    // Binary QLinear Op has 9 inputs
-    // x1, x1_scale, x1_zp, x2/w, x2_scale, x2_zp, y_scale , y_zp, B
-    inputs_.push_back(NodeUnitIODef{
-        *input_defs[0],
-        NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}});
-    inputs_.push_back(NodeUnitIODef{
-        *input_defs[3],
-        NodeUnitIODef::QuantParam{*input_defs[4], input_defs[5]}});
-
-    if (input_defs.size() == 9) {  // has Bias
-      inputs_.push_back(NodeUnitIODef{
-          *input_defs[8],
-          std::nullopt});  // for Bias the scale and zp are optional
-    }
-
-    outputs_.push_back(NodeUnitIODef{
-        *output_defs[0],
-        NodeUnitIODef::QuantParam{*input_defs[6], input_defs[7]}});
-  } else if (qlinear_type == QLinearOpType::DequantizeLinear) {
-    // DequantizeLinear has 3 inputs
-    // x, x_scale, x_zp
-    // output is not quantized
-    inputs_.push_back(NodeUnitIODef{
-        *input_defs[0],
-        NodeUnitIODef::QuantParam{*input_defs[1],
-                                  input_defs.size() == 3
-                                      ? input_defs[2]
-                                      : nullptr}});
-    outputs_.push_back(NodeUnitIODef{*output_defs[0], std::nullopt});
-  } else if (qlinear_type == QLinearOpType::QuantizeLinear) {
-    // QuantizeLinear the input is not quantized and has 3 inputs
-    // x, y_scale, y_zp (optional)
-    // The output is quantized
-    inputs_.push_back(NodeUnitIODef{*input_defs[0], std::nullopt});
-    outputs_.push_back(NodeUnitIODef{
-        *output_defs[0],
-        NodeUnitIODef::QuantParam{*input_defs[1],
-                                  input_defs.size() == 3
-                                      ? input_defs[2]
-                                      : nullptr}});
-  } else {
-    ORT_THROW("The QLinear op [", static_cast<uint8_t>(qlinear_type), "] is not supported");
-  }
-}
-
-Node::EdgeConstIterator NodeUnit::OutputEdgesBegin(size_t index) const {
-  if (type_ == Type::SingleNode) {
-    ORT_ENFORCE(index == 0, "invalid output node index");
-    return target_node_.OutputEdgesBegin();
-  } else {
-    ORT_ENFORCE(index < q_nodes_.size(), "invalid output node index");
-    return q_nodes_[index]->OutputEdgesBegin();
-  }
-}
-
-Node::EdgeConstIterator NodeUnit::OutputEdgesEnd(size_t index) const {
-  if (type_ == Type::SingleNode) {
-    ORT_ENFORCE(index == 0, "invalid output node index");
-    return target_node_.OutputEdgesEnd();
-  } else {
-    ORT_ENFORCE(index < q_nodes_.size(), "invalid output node index");
-    return q_nodes_[index]->OutputEdgesEnd();
-  }
-}
-
-std::vector<const Node*> NodeUnit::GetAllNodesInGroup() const noexcept {
-  std::vector<const Node*> all_nodes = dq_nodes_;
-  all_nodes.push_back(&target_node_);
-  all_nodes.insert(all_nodes.end(), q_nodes_.begin(), q_nodes_.end());
-  return all_nodes;
-}
-
-std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
-GetAllNodeUnits(const GraphViewer& graph_viewer) {
-  std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
-  std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
-
-  const auto add_node_unit_to_map = [&](const std::vector<NodeIndex>& node_indices, const NodeUnit* node_unit) {
-    for (const auto& node_idx : node_indices) {
-      const auto* node = graph_viewer.GetNode(node_idx);
-      node_unit_map.insert({node, node_unit});
-    }
-  };
-
-  // Get QDQ NodeUnits first
-  QDQ::SelectorManager selector_mgr;
-  const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer);
-
-  for (const auto& qdq_selection : qdq_selections) {
-    auto qdq_unit = std::make_unique<NodeUnit>(graph_viewer, qdq_selection);
-
-    // Fill the node to node_unit map for all nodes in the QDQ Group
-    add_node_unit_to_map(qdq_selection.dq_nodes, qdq_unit.get());
-    add_node_unit_to_map(qdq_selection.q_nodes, qdq_unit.get());
-    add_node_unit_to_map({qdq_selection.target_node}, qdq_unit.get());
-
-    node_unit_holder.push_back(std::move(qdq_unit));
-  }
-
-  // Get the left over SingleNode NodeUnits
-  const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
-  for (const auto node_idx : node_indices) {
-    const auto* node(graph_viewer.GetNode(node_idx));
-
-    // This is already part of a QDQ NodeUnit
-    if (node_unit_map.find(node) != node_unit_map.cend())
-      continue;
-
-    auto node_unit = std::make_unique<NodeUnit>(*node);
-    node_unit_map[node] = node_unit.get();
-    node_unit_holder.push_back(std::move(node_unit));
-  }
-
-  return std::make_pair(std::move(node_unit_holder), std::move(node_unit_map));
-}
-
-}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/shared/utils/utils.cc b/onnxruntime/core/providers/shared/utils/utils.cc
index 37ad14ac2e9b1..2088618538de5 100644
--- a/onnxruntime/core/providers/shared/utils/utils.cc
+++ b/onnxruntime/core/providers/shared/utils/utils.cc
@@ -4,12 +4,12 @@
 
 #include "utils.h"
 
-#include <core/common/safeint.h>
-#include <core/framework/tensorprotoutils.h>
-#include <core/graph/graph.h>
-#include <core/providers/common.h>
-#include "core/providers/shared/node_unit/node_unit.h"
+#include "core/common/safeint.h"
+#include "core/framework/node_unit.h"
+#include "core/framework/tensorprotoutils.h"
+#include "core/graph/graph.h"
 #include "core/optimizer/initializer.h"
+#include "core/providers/common.h"
 
 namespace onnxruntime {
 
@@ -118,84 +118,134 @@ NodeAttrHelper::NodeAttrHelper(const NodeUnit& node_unit)
     : node_attributes_(node_unit.GetNode().GetAttributes()) {}
 
 float NodeAttrHelper::Get(const std::string& key, float def_val) const {
-  if (!HasAttr(key))
-    return def_val;
+  if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) {
+    return entry->second.f();
+  }
 
-  return node_attributes_.at(key).f();
+  return def_val;
 }
 
 int32_t NodeAttrHelper::Get(const std::string& key, int32_t def_val) const {
-  if (!HasAttr(key))
-    return def_val;
+  if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) {
+    return narrow<int32_t>(entry->second.i());
+  }
 
-  return SafeInt<int32_t>(node_attributes_.at(key).i());
+  return def_val;
 }
 
 uint32_t NodeAttrHelper::Get(const std::string& key, uint32_t def_val) const {
-  if (!HasAttr(key))
-    return def_val;
+  if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) {
+    return narrow<uint32_t>(entry->second.i());
+  }
 
-  return SafeInt<uint32_t>(node_attributes_.at(key).i());
+  return def_val;
 }
 
 int64_t NodeAttrHelper::Get(const std::string& key, int64_t def_val) const {
-  if (!HasAttr(key))
-    return def_val;
+  if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) {
+    return entry->second.i();
+  }
 
-  return node_attributes_.at(key).i();
+  return def_val;
 }
 
 const std::string& NodeAttrHelper::Get(const std::string& key, const std::string& def_val) const {
-  if (!HasAttr(key))
-    return def_val;
+  if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) {
+    return entry->second.s();
+  }
 
-  return node_attributes_.at(key).s();
+  return def_val;
 }
 
 std::vector<int32_t> NodeAttrHelper::Get(const std::string& key, const std::vector<int32_t>& def_val) const {
-  if (!HasAttr(key))
-    return def_val;
+  if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) {
+    const auto& attr = entry->second;
+    std::vector<int32_t> v;
+    v.reserve(static_cast<size_t>(attr.ints_size()));
+    std::transform(attr.ints().cbegin(), attr.ints().cend(), std::back_inserter(v),
+                   [](int64_t val) -> int32_t { return narrow<int32_t>(val); });
+    return v;
+  }
 
-  const auto& attr(node_attributes_.at(key));
-  std::vector<int32_t> v;
-  v.reserve(static_cast<size_t>(attr.ints_size()));
-  std::transform(attr.ints().cbegin(), attr.ints().cend(), std::back_inserter(v),
-                 [](int64_t val) -> int32_t { return SafeInt<int32_t>(val); });
-  return v;
+  return def_val;
 }
 
 std::vector<uint32_t> NodeAttrHelper::Get(const std::string& key, const std::vector<uint32_t>& def_val) const {
-  if (!HasAttr(key))
-    return def_val;
+  if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) {
+    const auto& attr = entry->second;
+    std::vector<uint32_t> v;
+    v.reserve(static_cast<size_t>(attr.ints_size()));
+    std::transform(attr.ints().cbegin(), attr.ints().cend(), std::back_inserter(v),
+                   [](int64_t val) -> uint32_t { return narrow<uint32_t>(val); });
+    return v;
+  }
 
-  const auto& attr(node_attributes_.at(key));
-  std::vector<uint32_t> v;
-  v.reserve(static_cast<size_t>(attr.ints_size()));
-  std::transform(attr.ints().cbegin(), attr.ints().cend(), std::back_inserter(v),
-                 [](int64_t val) -> uint32_t { return SafeInt<uint32_t>(val); });
-  return v;
+  return def_val;
 }
 
 std::vector<int64_t> NodeAttrHelper::Get(const std::string& key, const std::vector<int64_t>& def_val) const {
-  if (!HasAttr(key))
-    return def_val;
+  if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) {
+    const auto& values = entry->second.ints();
+    return std::vector<int64_t>{values.cbegin(), values.cend()};
+  }
 
-  const auto& source(node_attributes_.at(key).ints());
-  return std::vector<int64_t>{source.cbegin(), source.cend()};
+  return def_val;
 }
 
 std::vector<float> NodeAttrHelper::Get(const std::string& key, const std::vector<float>& def_val) const {
-  if (!HasAttr(key))
-    return def_val;
+  if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) {
+    const auto& values = entry->second.floats();
+    return std::vector<float>{values.cbegin(), values.cend()};
+  }
 
-  const auto& source(node_attributes_.at(key).floats());
-  return std::vector<float>{source.cbegin(), source.cend()};
+  return def_val;
 }
 
-std::optional<int64_t> NodeAttrHelper::GetInt(const std::string& key) const {
-  if (!HasAttr(key))
-    return std::nullopt;
-  return node_attributes_.at(key).i();
+std::optional<float> NodeAttrHelper::GetFloat(const std::string& key) const {
+  std::optional<float> result;
+  if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) {
+    result = entry->second.f();
+  }
+
+  return result;
+}
+
+std::optional<int64_t> NodeAttrHelper::GetInt64(const std::string& key) const {
+  std::optional<int64_t> result;
+  if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) {
+    result = entry->second.i();
+  }
+
+  return result;
+}
+
+std::optional<std::vector<float>> NodeAttrHelper::GetFloats(const std::string& key) const {
+  std::optional<std::vector<float>> result;
+  if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) {
+    const auto& values = entry->second.floats();
+    result = std::vector<float>(values.begin(), values.end());
+  }
+
+  return result;
+}
+
+std::optional<std::vector<int64_t>> NodeAttrHelper::GetInt64s(const std::string& key) const {
+  std::optional<std::vector<int64_t>> result;
+  if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) {
+    const auto& values = entry->second.ints();
+    result = std::vector<int64_t>(values.begin(), values.end());
+  }
+
+  return result;
+}
+
+std::optional<std::string> NodeAttrHelper::GetString(const std::string& key) const {
+  std::optional<std::string> result;
+  if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) {
+    result = entry->second.s();
+  }
+
+  return result;
 }
 
 bool NodeAttrHelper::HasAttr(const std::string& key) const {
diff --git a/onnxruntime/core/providers/shared/utils/utils.h b/onnxruntime/core/providers/shared/utils/utils.h
index 31b1aba2e1a63..5813dcc48d72b 100644
--- a/onnxruntime/core/providers/shared/utils/utils.h
+++ b/onnxruntime/core/providers/shared/utils/utils.h
@@ -47,15 +47,17 @@ class NodeAttrHelper {
   // Get the attributes from the target node of the node_unit
   explicit NodeAttrHelper(const NodeUnit& node_unit);
 
+  /*
+   * Get with default.
+   */
   float Get(const std::string& key, float def_val) const;
+  std::vector<float> Get(const std::string& key, const std::vector<float>& def_val) const;
 
   int64_t Get(const std::string& key, int64_t def_val) const;
+  std::vector<int64_t> Get(const std::string& key, const std::vector<int64_t>& def_val) const;
 
   const std::string& Get(const std::string& key, const std::string& def_val) const;
 
-  std::vector<int64_t> Get(const std::string& key, const std::vector<int64_t>& def_val) const;
-  std::vector<float> Get(const std::string& key, const std::vector<float>& def_val) const;
-
   // Convert the i() or ints() of the attribute from int64_t to int32_t
   int32_t Get(const std::string& key, int32_t def_val) const;
   std::vector<int32_t> Get(const std::string& key, const std::vector<int32_t>& def_val) const;
@@ -64,7 +66,16 @@ class NodeAttrHelper {
   uint32_t Get(const std::string& key, uint32_t def_val) const;
   std::vector<uint32_t> Get(const std::string& key, const std::vector<uint32_t>& def_val) const;
 
-  std::optional<int64_t> GetInt(const std::string& key) const;
+  /*
+   * Get without default.
+   */
+  std::optional<float> GetFloat(const std::string& key) const;
+  std::optional<std::vector<float>> GetFloats(const std::string& key) const;
+
+  std::optional<int64_t> GetInt64(const std::string& key) const;
+  std::optional<std::vector<int64_t>> GetInt64s(const std::string& key) const;
+
+  std::optional<std::string> GetString(const std::string& key) const;
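+
+  // Example usage (hypothetical attribute names, for illustration only):
+  //   NodeAttrHelper helper(node_unit);
+  //   int64_t axis = helper.Get("axis", int64_t{0});             // with default
+  //   std::optional<int64_t> group = helper.GetInt64("group");   // without default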
 
   bool HasAttr(const std::string& key) const;
 
diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h
index b78279040acb6..1cebe4a256fd4 100644
--- a/onnxruntime/core/providers/shared_library/provider_api.h
+++ b/onnxruntime/core/providers/shared_library/provider_api.h
@@ -159,6 +159,7 @@ class OpKernel;
 struct OpKernelContext;
 struct OpKernelInfo;
 struct PrimitiveDataTypeBase;
+struct OrtRunOptions;
 struct Tensor;
 struct SparseTensor;
 class TensorSeq;
diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
index da17135878fe5..7b73ab36b3742 100644
--- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
+++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
@@ -24,6 +24,7 @@
 #include "core/providers/cpu/tensor/size.h"
 #include "core/providers/cpu/tensor/scatter_nd.h"
 #include "core/providers/cpu/tensor/unsqueeze.h"
+#include "core/providers/cpu/tensor/upsamplebase.h"
 #include "core/providers/cpu/tensor/tile.h"
 
 #ifndef DISABLE_CONTRIB_OPS
@@ -572,6 +573,11 @@ std::unique_ptr<EinsumTypedComputeProcessor<double>> EinsumTypedComputeProcessor
 template <>
 std::unique_ptr<EinsumTypedComputeProcessor<MLFloat16>> EinsumTypedComputeProcessor<MLFloat16>::Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) { return g_host_cpu.EinsumTypedComputeProcessor_MLFloat16__Create(context, allocator, tp, einsum_compute_preprocessor, einsum_cuda_assets); }
 
+void UpsampleBase::AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span<const int64_t> input_dims,
+                                            InlinedVector<float>& scales) const {
+  g_host_cpu.UpsampleBase__AdjustOutputSizeAsPolicy(this, output_dims, input_dims, scales);
+}
+
 #ifndef DISABLE_CONTRIB_OPS
 namespace contrib {
 Status embed_layer_norm::CheckInputs(const OpKernelContext* context, bool quantizedVersion) {
@@ -648,7 +654,6 @@ Status Sampling::SetupSubgraphExecutionInfo(const SessionState& session_state, c
                                             const SessionState& subgraph_session_state) {
   return g_host_cpu.Sampling__SetupSubgraphExecutionInfo(this, session_state, attribute_name, subgraph_session_state);
 }
-
 }  // namespace transformers
 
 #ifdef ENABLE_ATEN
diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h
index f5a8327443864..8c8d5b1fd460a 100644
--- a/onnxruntime/core/providers/shared_library/provider_interfaces.h
+++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
 
 #include <optional>
+#include <list>
 
 // Public wrappers around internal ort interfaces (currently)
 #include "core/providers/shared_library/provider_host_api.h"
@@ -34,6 +35,7 @@ struct ProviderHostCPU;
 class PhiloxGenerator;
 using ProviderType = const std::string&;
 class RandomGenerator;
+class IOnnxRuntimeOpSchemaCollection;
 
 #ifdef ENABLE_TRAINING_TORCH_INTEROP
 namespace contrib {
@@ -93,6 +95,8 @@ using NodeIndex = size_t;
 // using NodeAttributes = std::unordered_map<std::string, ONNX_NAMESPACE::AttributeProto_Copyable>;
 using ModelMetaData = std::unordered_map<std::string, std::string>;
 
+using IOnnxRuntimeOpSchemaCollectionPtr = std::shared_ptr<IOnnxRuntimeOpSchemaCollection>;
+using IOnnxRuntimeOpSchemaRegistryList = std::list<IOnnxRuntimeOpSchemaCollectionPtr>;
 using InitializedTensorSet = std::unordered_map<std::string, const ONNX_NAMESPACE::TensorProto*>;
 
 struct Node__NodeIterator {
@@ -435,6 +439,7 @@ struct ProviderHost {
   virtual void TensorProto__clear_int64_data(ONNX_NAMESPACE::TensorProto* p) = 0;
   virtual void TensorProto__clear_double_data(ONNX_NAMESPACE::TensorProto* p) = 0;
   virtual void TensorProto__clear_uint64_data(ONNX_NAMESPACE::TensorProto* p) = 0;
+  virtual void TensorProto__set_data_location(ONNX_NAMESPACE::TensorProto* p, ONNX_NAMESPACE::TensorProto_DataLocation data_location) = 0;
 
   virtual bool TensorProto_DataType_IsValid(int value) = 0;
 
@@ -481,6 +486,9 @@ struct ProviderHost {
   // ConfigOptions
   virtual std::optional<std::string> ConfigOptions__GetConfigEntry(const ConfigOptions* p, const std::string& config_key) = 0;
 
+  // OrtRunOptions
+  virtual const ConfigOptions& RunOptions__GetConfigOptions(const RunOptions* p) = 0;
+
   // ComputeCapability
   virtual std::unique_ptr<ComputeCapability> ComputeCapability__construct(std::unique_ptr<IndexedSubGraph> t_sub_graph) = 0;
   virtual void ComputeCapability__operator_delete(ComputeCapability* p) = 0;
@@ -752,8 +760,9 @@ struct ProviderHost {
   virtual void NodeAttributes__reserve(NodeAttributes* p, size_t size) = 0;
 
   // Model
-  virtual std::unique_ptr<Model> Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto,
-                                                  const PathString& model_path, const logging::Logger& logger) = 0;
+  virtual std::unique_ptr<Model> Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path,
+                                                  const IOnnxRuntimeOpSchemaRegistryList* local_registries,
+                                                  const logging::Logger& logger) = 0;
   virtual void Model__operator_delete(Model* p) = 0;
   virtual Graph& Model__MainGraph(Model* p) = 0;
   virtual std::unique_ptr<ONNX_NAMESPACE::ModelProto> Model__ToProto(Model* p) = 0;
@@ -811,6 +820,7 @@ struct ProviderHost {
   virtual Node* Graph__GetNode(Graph* p, NodeIndex node_index) noexcept = 0;
   virtual const Node* Graph__GetNode(const Graph* p, NodeIndex node_index) const = 0;
   virtual const NodeArg* Graph__GetNodeArg(const Graph* p, const std::string& name) const = 0;
+  virtual IOnnxRuntimeOpSchemaCollectionPtr Graph__GetSchemaRegistry(const Graph* p) const = 0;
 
   // GraphViewer
   virtual void GraphViewer__operator_delete(GraphViewer* p) = 0;
diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
index dde4005c80b9d..bdad18c7edec0 100644
--- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
+++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
@@ -205,6 +205,7 @@ struct TensorProto final {
 
   bool has_data_location() const { return g_host->TensorProto__has_data_location(this); }
   TensorProto_DataLocation data_location() const { return TensorProto_DataLocation(g_host->TensorProto__data_location(this)); }
+  void set_data_location(TensorProto_DataLocation data_location) { return g_host->TensorProto__set_data_location(this, data_location); }
 
   bool has_raw_data() const { return g_host->TensorProto__has_raw_data(this); }
   const std::string& raw_data() const { return g_host->TensorProto__raw_data(this); }
@@ -393,6 +394,14 @@ struct ConfigOptions final {
   PROVIDER_DISALLOW_ALL(ConfigOptions)
 };
 
+struct OrtRunOptions final {
+  const ConfigOptions& GetConfigOptions() const {
+    return g_host->RunOptions__GetConfigOptions(this);
+  }
+
+  PROVIDER_DISALLOW_ALL(OrtRunOptions)
+};
+
 struct ComputeCapability final {
   static std::unique_ptr<ComputeCapability> Create(std::unique_ptr<IndexedSubGraph> t_sub_graph) { return g_host->ComputeCapability__construct(std::move(t_sub_graph)); }
   static void operator delete(void* p) { g_host->ComputeCapability__operator_delete(reinterpret_cast<ComputeCapability*>(p)); }
@@ -770,8 +779,8 @@ struct NodeAttributes final {
 
 struct Model final {
   static std::unique_ptr<Model> Create(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path,
-                                       const logging::Logger& logger) {
-    return g_host->Model__construct(std::move(model_proto), model_path, logger);
+                                       const IOnnxRuntimeOpSchemaRegistryList* local_registries, const logging::Logger& logger) {
+    return g_host->Model__construct(std::move(model_proto), model_path, local_registries, logger);
   }
   static void operator delete(void* p) { g_host->Model__operator_delete(reinterpret_cast<Model*>(p)); }
   static Status Load(const PathString& file_path, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) { return g_host->Model__Load(file_path, model_proto); }
@@ -849,6 +858,7 @@ struct Graph final {
   const Node* GetNode(NodeIndex node_index) const noexcept { return g_host->Graph__GetNode(this, node_index); }
   Node* GetNode(NodeIndex node_index) noexcept { return g_host->Graph__GetNode(this, node_index); }
   const NodeArg* GetNodeArg(const std::string& name) const { return g_host->Graph__GetNodeArg(this, name); }
+  IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const { return g_host->Graph__GetSchemaRegistry(this); }
 
   PROVIDER_DISALLOW_ALL(Graph)
 };
diff --git a/onnxruntime/core/providers/tensorrt/nv_includes.h b/onnxruntime/core/providers/tensorrt/nv_includes.h
new file mode 100644
index 0000000000000..c3e9f7a3a2a77
--- /dev/null
+++ b/onnxruntime/core/providers/tensorrt/nv_includes.h
@@ -0,0 +1,20 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#pragma once
+
+// File to include the required TRT headers with workarounds for warnings we can't fix.
+
+// Ignore warning C4100: unreferenced formal parameter
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4100)
+#endif
+
+#include <NvInfer.h>
+#include <NvInferPlugin.h>
+#include <NvInferRuntime.h>
+#include <NvOnnxParser.h>
+
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
index bf3bf9e3495d7..9f1e5178428e7 100644
--- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
+++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
@@ -6,7 +6,7 @@
 #include <string>
 #include <filesystem>
 
-#include "NvInfer.h"
+#include "core/providers/tensorrt/nv_includes.h"
 #include "core/providers/shared_library/provider_api.h"
 
 namespace onnxruntime {
diff --git a/onnxruntime/core/providers/tensorrt/ort_trt_int8_cal_table.fbs.h b/onnxruntime/core/providers/tensorrt/ort_trt_int8_cal_table.fbs.h
index 9e4324fb9f516..a2e027f56fbd9 100644
--- a/onnxruntime/core/providers/tensorrt/ort_trt_int8_cal_table.fbs.h
+++ b/onnxruntime/core/providers/tensorrt/ort_trt_int8_cal_table.fbs.h
@@ -3,7 +3,7 @@
 #ifndef FLATBUFFERS_GENERATED_ORTTRTINT8CALTABLE_CALTABLEFLATBUFFERS_H_
 #define FLATBUFFERS_GENERATED_ORTTRTINT8CALTABLE_CALTABLEFLATBUFFERS_H_
 
-#include "flatbuffers/flatbuffers.h"
+#include "core/common/flatbuffers.h"
 
 namespace CalTableFlatBuffers {
 
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index c0bf29e486c88..632d521dc21a8 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -7,6 +7,7 @@
 #define ORT_API_MANUAL_INIT
 #include "core/session/onnxruntime_cxx_api.h"
 #include "core/common/common.h"
+#include "core/common/narrow.h"
 #include "core/common/safeint.h"
 #include "tensorrt_execution_provider.h"
 #include "tensorrt_execution_provider_utils.h"
@@ -137,10 +138,10 @@ std::vector<std::string> SplitToStringVec(std::string const& s, char separator)
   return splitted;
 }
 
-nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_sting) {
+nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_string) {
   nvinfer1::TacticSources disabledTactics = 0;
   nvinfer1::TacticSources enabledTactics = 0;
-  std::vector<std::string> tacticList = SplitToStringVec(tactic_sting, ',');
+  std::vector<std::string> tacticList = SplitToStringVec(tactic_string, ',');
   for (auto& t : tacticList) {
     bool enable{false};
     if (t.front() == '+') {
@@ -151,8 +152,8 @@ nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_sting) {
     t.erase(0, 1);
 
     const auto toUpper = [](std::string& sourceName) {
-      std::transform(
-          sourceName.begin(), sourceName.end(), sourceName.begin(), [](char c) { return std::toupper(c); });
+      std::transform(sourceName.begin(), sourceName.end(), sourceName.begin(),
+                     [](char c) { return onnxruntime::narrow<char>(std::toupper(c)); });
       return sourceName;
     };
 
@@ -288,7 +289,8 @@ void CudaCall<cudnnStatus_t, true>(cudnnStatus_t retCode, const char* exprString
   return g_host->CudaCall_true(retCode, exprString, libName, successCode, msg, file, line);
 }
 
-void* OutputAllocator::reallocateOutput(char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment) noexcept {
+void* OutputAllocator::reallocateOutput(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size,
+                                        uint64_t /*alignment*/) noexcept {
   // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr
   // even for empty tensors, so allocate a dummy byte.
   size = std::max(size, static_cast<uint64_t>(1));
@@ -304,7 +306,7 @@ void* OutputAllocator::reallocateOutput(char const* tensorName, void* currentMem
   return outputPtr;
 }
 
-void OutputAllocator::notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept {
+void OutputAllocator::notifyShape(char const* /*tensorName*/, nvinfer1::Dims const& dims) noexcept {
   output_shapes.clear();
   output_shapes.reserve(dims.nbDims);
   for (int i = 0; i < dims.nbDims; i++) {
@@ -613,20 +615,22 @@ Status ApplyProfileShapesFromInputTensorValue(std::vector<nvinfer1::IOptimizatio
       tensor_shape_values[input_name].resize(shape_size);
       switch (tensor_type) {
         case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
-          auto input = std::make_unique<int32_t[]>(shape_size);
-          CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input.get(), input_tensor.GetTensorData<int32_t>(), shape_size * sizeof(int32_t), cudaMemcpyDeviceToHost, stream));
+          auto input_shape = std::make_unique<int32_t[]>(shape_size);
+          CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_shape.get(), input_tensor.GetTensorData<int32_t>(),
+                                               shape_size * sizeof(int32_t), cudaMemcpyDeviceToHost, stream));
           CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
           for (int j = 0; j < shape_size; ++j) {
-            tensor_shape_values[input_name][j] = input[j];
+            tensor_shape_values[input_name][j] = input_shape[j];
           }
           break;
         }
         case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
-          auto input = std::make_unique<int64_t[]>(shape_size);
-          CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input.get(), input_tensor.GetTensorData<int64_t>(), shape_size * sizeof(int64_t), cudaMemcpyDeviceToHost, stream));
+          auto input_shape = std::make_unique<int64_t[]>(shape_size);
+          CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_shape.get(), input_tensor.GetTensorData<int64_t>(),
+                                               shape_size * sizeof(int64_t), cudaMemcpyDeviceToHost, stream));
           CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
           for (int j = 0; j < shape_size; ++j) {
-            tensor_shape_values[input_name][j] = static_cast<int32_t>(input[j]);
+            tensor_shape_values[input_name][j] = static_cast<int32_t>(input_shape[j]);
           }
           break;
         }
@@ -717,6 +721,77 @@ Status ApplyProfileShapesFromInputTensorValue(std::vector<nvinfer1::IOptimizatio
   return Status::OK();
 }
 
+#define CASE_GET_INPUT_TENSOR(DATA_TYPE, SrcT)                                              \
+  case DATA_TYPE: {                                                                         \
+    auto input_tensor_ptr = input_tensor.GetTensorData<SrcT>();                             \
+    if (input_tensor_ptr != nullptr && elem_cnt > 0) {                                      \
+      data = const_cast<SrcT*>(input_tensor_ptr);                                           \
+    } else {                                                                                \
+      scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, 1)); \
+      data = scratch_buffers.back().get();                                                  \
+    }                                                                                       \
+    break;                                                                                  \
+  }
+
+#define CASE_GET_CAST_INPUT_TENSOR(DATA_TYPE, SrcT, DstT)                                                         \
+  case DATA_TYPE: {                                                                                               \
+    auto input_tensor_ptr = input_tensor.GetTensorData<SrcT>();                                                   \
+    if (input_tensor_ptr != nullptr && elem_cnt > 0) {                                                            \
+      scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, elem_cnt * sizeof(DstT))); \
+      data = scratch_buffers.back().get();                                                                        \
+      cuda::Impl_Cast<SrcT, DstT>(stream, input_tensor_ptr, reinterpret_cast<DstT*>(data), elem_cnt);             \
+    } else {                                                                                                      \
+      scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, 1));                       \
+      data = scratch_buffers.back().get();                                                                        \
+    }                                                                                                             \
+    break;                                                                                                        \
+  }
+
+#define CASE_GET_OUTPUT_TENSOR(DATA_TYPE, SrcT)                                             \
+  case DATA_TYPE: {                                                                         \
+    auto output_tensor_ptr = output_tensor.GetTensorMutableData<SrcT>();                    \
+    if (output_tensor_ptr != nullptr && elem_cnt > 0) {                                     \
+      buffers[output_name] = output_tensor_ptr;                                             \
+    } else {                                                                                \
+      scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, 1)); \
+      buffers[output_name] = scratch_buffers.back().get();                                  \
+    }                                                                                       \
+    break;                                                                                  \
+  }
+
+#define CASE_GET_CAST_OUTPUT_TENSOR(DATA_TYPE, SrcT, DstT)                                                        \
+  case DATA_TYPE: {                                                                                               \
+    auto output_tensor_ptr = output_tensor.GetTensorMutableData<SrcT>();                                          \
+    if (output_tensor_ptr != nullptr && elem_cnt > 0) {                                                           \
+      scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, elem_cnt * sizeof(DstT))); \
+      buffers[output_name] = scratch_buffers.back().get();                                                        \
+      output_dim_sizes[i] = static_cast<int>(elem_cnt);                                                           \
+    } else {                                                                                                      \
+      scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, 1));                       \
+      buffers[output_name] = scratch_buffers.back().get();                                                        \
+      output_dim_sizes[i] = 1;                                                                                    \
+    }                                                                                                             \
+    break;                                                                                                        \
+  }
+
+#define CASE_COPY_TENSOR(DATA_TYPE, DstT)                                                                                                          \
+  case DATA_TYPE: {                                                                                                                                \
+    auto output_tensor_ptr = output_tensor.GetTensorMutableData<DstT>();                                                                           \
+    if (output_tensor_ptr != nullptr && elem_cnt > 0) {                                                                                            \
+      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(DstT), cudaMemcpyDeviceToDevice, stream)); \
+    }                                                                                                                                              \
+    break;                                                                                                                                         \
+  }
+
+#define CASE_CAST_TENSOR(DATA_TYPE, SrcT, DstT)                                                                                                   \
+  case DATA_TYPE: {                                                                                                                               \
+    auto output_tensor_ptr = output_tensor.GetTensorMutableData<DstT>();                                                                          \
+    if (output_tensor_ptr != nullptr && elem_cnt > 0) {                                                                                           \
+      cuda::Impl_Cast<SrcT, DstT>(stream, reinterpret_cast<SrcT*>(allocator->getBuffer()), reinterpret_cast<DstT*>(output_tensor_ptr), elem_cnt); \
+    }                                                                                                                                             \
+    break;                                                                                                                                        \
+  }
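+
+// For reference, CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) above expands to roughly:
+//
+//   case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: {
+//     auto input_tensor_ptr = input_tensor.GetTensorData<float>();
+//     if (input_tensor_ptr != nullptr && elem_cnt > 0) {
+//       data = const_cast<float*>(input_tensor_ptr);
+//     } else {
+//       // Empty tensor: bind a 1-byte scratch buffer so TRT still gets a non-null, unique address.
+//       scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, 1));
+//       data = scratch_buffers.back().get();
+//     }
+//     break;
+//   }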
+
 /*
  * Set TensorRT execution context input.
  *
@@ -737,6 +812,17 @@ Status BindContextInput(Ort::KernelContext& ctx,
   auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
   const auto tensor_shapes = tensor_info.GetShape();
   const auto tensor_type = tensor_info.GetElementType();
+  /*
+   * GetElementCount() returns the number of elements specified by the tensor shape (the product of all dimensions).
+   * A 0-dimensional (scalar) shape yields 1. If any dimension is less than 0, the result is -1.
+   *
+   * Examples:
+   *   [] = 1
+   *   [1,3,4] = 12
+   *   [2,0,4] = 0
+   *   [-1,3,4] = -1
+   */
+  const auto elem_cnt = tensor_info.GetElementCount();
 
   if (trt_engine->isShapeInferenceIO(input_name)) {
     // Get the shape value of "shape tensor"
@@ -765,113 +851,24 @@ Status BindContextInput(Ort::KernelContext& ctx,
       ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                                          "TensorRT EP failed to call nvinfer1::IExecutionContext::setInputShape() for input '" + error_input_name + "'"));
     }
-    // Bind "execution tensor" input buffers
+
+    // Bind "execution tensor" input buffer
+    //
+    // Note: If an engine binding is an empty tensor, it still needs a non-null memory address, and different tensors should have different addresses.
+    //       Therefore, in the case of empty tensor, TRT EP always allocates a dummy byte.
+    //       https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#empty-tensors
     void* data = nullptr;
     switch (tensor_type) {
-      case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: {
-        auto input_tensor_ptr = input_tensor.GetTensorData<float>();
-        if (input_tensor_ptr == nullptr) {
-          scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, sizeof(float)));
-          data = scratch_buffers.back().get();
-        } else {
-          data = const_cast<float*>(input_tensor_ptr);
-        }
-        break;
-      }
-      case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: {
-        auto input_tensor_ptr = input_tensor.GetTensorData<uint16_t>();
-        if (input_tensor_ptr == nullptr) {
-          scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, sizeof(uint16_t)));
-          data = scratch_buffers.back().get();
-        } else {
-          data = const_cast<uint16_t*>(input_tensor_ptr);
-        }
-        break;
-      }
-      case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: {
-        auto input_tensor_ptr = input_tensor.GetTensorData<bool>();
-        if (input_tensor_ptr == nullptr) {
-          scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, sizeof(bool)));
-          data = scratch_buffers.back().get();
-        } else {
-          data = const_cast<bool*>(input_tensor_ptr);
-        }
-        break;
-      }
-      case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: {
-        auto input_tensor_ptr = input_tensor.GetTensorData<int8_t>();
-        if (input_tensor_ptr == nullptr) {
-          scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, sizeof(int8_t)));
-          data = scratch_buffers.back().get();
-        } else {
-          data = const_cast<int8_t*>(input_tensor_ptr);
-        }
-        break;
-      }
-      case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: {
-        auto input_tensor_ptr = input_tensor.GetTensorData<uint8_t>();
-        if (input_tensor_ptr == nullptr) {
-          scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, sizeof(uint8_t)));
-          data = scratch_buffers.back().get();
-        } else {
-          data = const_cast<uint8_t*>(input_tensor_ptr);
-        }
-        break;
-      }
-      case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
-        auto input_tensor_ptr = input_tensor.GetTensorData<int32_t>();
-        if (input_tensor_ptr == nullptr) {
-          scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, sizeof(int32_t)));
-          data = scratch_buffers.back().get();
-        } else {
-          data = const_cast<int32_t*>(input_tensor_ptr);
-        }
-        break;
-      }
-      case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
-        // Cast INT64 input to INT32 because TensorRT doesn't fully support INT64
-        auto input_tensor_ptr = input_tensor.GetTensorData<int64_t>();
-        if (input_tensor_ptr == nullptr) {
-          scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, sizeof(int32_t)));
-          data = scratch_buffers.back().get();
-        } else {
-          SafeInt<int> input_dim_size = 1;
-          for (int j = 0, end = nb_dims; j < end; ++j) {
-            if (tensor_shapes[j] == 0) {
-              input_dim_size = 1;
-              break;
-            } else {
-              input_dim_size *= tensor_shapes[j];
-            }
-          }
-          scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, input_dim_size * sizeof(int32_t)));
-          data = scratch_buffers.back().get();
-          cuda::Impl_Cast<int64_t, int32_t>(stream, input_tensor_ptr, reinterpret_cast<int32_t*>(data), input_dim_size);
-        }
-        break;
-      }
-      case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: {
-        // Cast DOUBLE input to FLOAT because TensorRT doesn't fully support INT64
-        auto input_tensor_ptr = input_tensor.GetTensorData<double>();
-        if (input_tensor_ptr == nullptr) {
-          scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, sizeof(float)));
-          data = scratch_buffers.back().get();
-        } else {
-          SafeInt<int> input_dim_size = 1;
-          for (int j = 0, end = nb_dims; j < end; ++j) {
-            if (tensor_shapes[j] == 0) {
-              input_dim_size = 1;
-              break;
-            } else {
-              input_dim_size *= tensor_shapes[j];
-            }
-          }
-          scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, input_dim_size * sizeof(float)));
-          data = scratch_buffers.back().get();
-          cuda::Impl_Cast<double, float>(stream, input_tensor_ptr, reinterpret_cast<float*>(data), input_dim_size);
-        }
-        break;
-      }
+      CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float)
+      CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t)
+      CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool)
+      CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
+      CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
+      CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
+      // Cast int64 input to int32 input because TensorRT doesn't support int64
+      CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, int32_t)
+      // Cast double input to float because TensorRT doesn't support double
+      CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float)
       default: {
         return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                                "TensorRT EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported.");
@@ -884,7 +881,7 @@ Status BindContextInput(Ort::KernelContext& ctx,
 }
 
 /*
- * Set TensorRT execution context output.
+ * Bind TensorRT execution context output.
  *
  * Please note that the "data-depedent shape" output needs corresponding allocator provided.
  *
@@ -912,7 +909,6 @@ Status BindContextOutput(Ort::KernelContext& ctx,
                          size_t i,
                          std::unordered_map<size_t, Ort::UnownedValue>& output_tensors,
                          std::unordered_map<size_t, int>& output_dim_sizes,
-                         std::unordered_set<char const*>& dds_output_set,
                          DDSOutputAllocatorMap& dds_output_allocator_map,
                          std::vector<IAllocatorUniquePtr<void>>& scratch_buffers,
                          OrtAllocator* alloc,
@@ -920,142 +916,47 @@ Status BindContextOutput(Ort::KernelContext& ctx,
   // Get output shape
   nvinfer1::Dims dims = trt_context->getTensorShape(output_name);
   int nb_dims = dims.nbDims;
-  bool is_dds_output = false;
+  bool is_DDS = false;
   std::vector<int64_t> output_shapes(nb_dims);
   for (int j = 0, end = nb_dims; j < end; ++j) {
     // data-dependent shape
     if (dims.d[j] == -1) {
-      is_dds_output = true;
-      dds_output_set.emplace(output_name);
+      is_DDS = true;
       break;
     }
     output_shapes[j] = dims.d[j];
   }
 
+  auto known_DDS = dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end();
+
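+  // A previously created allocator in dds_output_allocator_map means this output is already known to be
+  // data-dependent (even if the current getTensorShape() call reports no -1 dimension), and the allocator
+  // that was registered with TRT via setOutputAllocator() earlier is reused as-is.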
   // If the output tensor has data-dependent shape, TRT EP will provide an IOutputAllocator for enqueueV3 to dynamically allocate memory buffer.
   // Once enqueueV3 returns, TRT EP will then bind the output allocation to ORT kernel context output.
   // (Please note that we take strategy A mentioned in https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#dynamic-shaped-output,
   //  which we defer allocation until the size is known and don't call IExecution::setTensorAddress)
   //
   // Otherwise, if the shape of the output tensor is known prior to the runtime, ORT will pre-allocate memory buffer for the output tensor for enqueueV3.
-  if (is_dds_output) {
-    if (dds_output_allocator_map.find(output_name) == dds_output_allocator_map.end()) {
+  if (is_DDS || known_DDS) {
+    if (!known_DDS) {
       auto allocatorPtr = std::make_unique<OutputAllocator>();
       trt_context->setOutputAllocator(output_name, allocatorPtr.get());
       dds_output_allocator_map[output_name] = std::move(allocatorPtr);
-    } else {
-      trt_context->setOutputAllocator(output_name, dds_output_allocator_map[output_name].get());
     }
   } else {
     output_tensors[i] = ctx.GetOutput(output_index, output_shapes);
     auto& output_tensor = output_tensors[i];
+    const auto elem_cnt = output_tensor.GetTensorTypeAndShapeInfo().GetElementCount();
+
     switch (output_type) {
-      case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: {
-        auto output_tensor_ptr = output_tensor.GetTensorMutableData<float>();
-        if (output_tensor_ptr == nullptr) {
-          scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, sizeof(float)));
-          buffers[output_name] = scratch_buffers.back().get();
-        } else {
-          buffers[output_name] = output_tensor_ptr;
-        }
-        break;
-      }
-      case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: {
-        auto output_tensor_ptr = output_tensor.GetTensorMutableData<uint16_t>();
-        if (output_tensor_ptr == nullptr) {
-          scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, sizeof(uint16_t)));
-          buffers[output_name] = scratch_buffers.back().get();
-        } else {
-          buffers[output_name] = output_tensor_ptr;
-        }
-        break;
-      }
-      case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: {
-        auto output_tensor_ptr = output_tensor.GetTensorMutableData<bool>();
-        if (output_tensor_ptr == nullptr) {
-          scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, sizeof(bool)));
-          buffers[output_name] = scratch_buffers.back().get();
-        } else {
-          buffers[output_name] = output_tensor_ptr;
-        }
-        break;
-      }
-      case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: {
-        auto output_tensor_ptr = output_tensor.GetTensorMutableData<int8_t>();
-        if (output_tensor_ptr == nullptr) {
-          scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, sizeof(int8_t)));
-          buffers[output_name] = scratch_buffers.back().get();
-        } else {
-          buffers[output_name] = output_tensor_ptr;
-        }
-        break;
-      }
-      case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: {
-        auto output_tensor_ptr = output_tensor.GetTensorMutableData<uint8_t>();
-        if (output_tensor_ptr == nullptr) {
-          scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, sizeof(uint8_t)));
-          buffers[output_name] = scratch_buffers.back().get();
-        } else {
-          buffers[output_name] = output_tensor_ptr;
-        }
-        break;
-      }
-      case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
-        auto output_tensor_ptr = output_tensor.GetTensorMutableData<int32_t>();
-        if (output_tensor_ptr == nullptr) {
-          scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, sizeof(int32_t)));
-          buffers[output_name] = scratch_buffers.back().get();
-        } else {
-          buffers[output_name] = output_tensor_ptr;
-        }
-        break;
-      }
-      case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
-        // Allocate INT32 CUDA memory for INT64 output type because TensorRT doesn't fully support INT64
-        auto output_tensor_ptr = output_tensor.GetTensorMutableData<int64_t>();
-        if (output_tensor_ptr == nullptr) {
-          scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, sizeof(int32_t)));
-          buffers[output_name] = scratch_buffers.back().get();
-          output_dim_sizes[i] = 1;
-        } else {
-          SafeInt<int> output_dim_size(1);
-          for (int j = 0, end = nb_dims; j < end; ++j) {
-            if (dims.d[j] == 0) {
-              output_dim_size = 1;
-              break;
-            } else {
-              output_dim_size *= dims.d[j];
-            }
-          }
-          scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, output_dim_size * sizeof(int32_t)));
-          buffers[output_name] = scratch_buffers.back().get();
-          output_dim_sizes[i] = output_dim_size;
-        }
-        break;
-      }
-      case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: {
-        // Allocate FLOAT CUDA memory for DOUBLE output type because TensorRT doesn't fully support DOUBLE
-        auto output_tensor_ptr = output_tensor.GetTensorMutableData<double>();
-        if (output_tensor_ptr == nullptr) {
-          scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, sizeof(float)));
-          buffers[output_name] = scratch_buffers.back().get();
-          output_dim_sizes[i] = 1;
-        } else {
-          SafeInt<int> output_dim_size(1);
-          for (int j = 0, end = nb_dims; j < end; ++j) {
-            if (dims.d[j] == 0) {
-              output_dim_size = 1;
-              break;
-            } else {
-              output_dim_size *= dims.d[j];
-            }
-          }
-          scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, output_dim_size * sizeof(float)));
-          buffers[output_name] = scratch_buffers.back().get();
-          output_dim_sizes[i] = output_dim_size;
-        }
-        break;
-      }
+      CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float)
+      CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t)
+      CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool)
+      CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
+      CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
+      CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
+      // Allocate int32 CUDA memory for int64 output type because TensorRT doesn't support int64
+      CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, int32_t)
+      // Allocate float CUDA memory for double output type because TensorRT doesn't support double
+      CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float)
       default: {
         return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                                "TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported.");
@@ -1068,13 +969,16 @@ Status BindContextOutput(Ort::KernelContext& ctx,
 }
 
 /*
- * Set ORT kernel context Output.
+ * Bind ORT kernel context output.
  *
- * Note: In the case of DDS (data-dependent shape) output, TRT requires a provided allocator to allocate memory during runtime.
+ * In the case of DDS (data-dependent shape) output, TRT requires a provided allocator to allocate memory during runtime.
  * Once the output has been put in the allocation buffer, ORT calls this function to bind the allocation to ORT kernel context output.
+ *
+ * Note: The current approach of setting the ORT kernel context output is to copy the output data from the allocation buffer to the ORT context output address, which is not optimal.
+ * We are waiting for ORT core to support assigning a memory address directly to the ORT context output; some work is needed in the ORT memory planner to be aware of this kind of memory reuse.
  */
 Status BindKernelOutput(Ort::KernelContext& ctx,
-                        OrtMemoryInfo* mem_info,
+                        OrtMemoryInfo* /*mem_info*/,
                         DDSOutputAllocatorMap& allocator_map,
                         char const* output_name,
                         size_t output_index,
@@ -1083,93 +987,46 @@ Status BindKernelOutput(Ort::KernelContext& ctx,
   auto allocator = allocator_map[output_name].get();
   auto& shape = allocator->getOutputShape();
   auto output_tensor = ctx.GetOutput(output_index, shape);
+
+  /*
+   * GetElementCount() returns the number of elements specified by the tensor shape (the product of all dimensions).
+   * A 0-dimensional (scalar) shape yields 1. If any dimension is less than 0, the result is -1.
+   *
+   * Examples:
+   *   [] = 1
+   *   [1,3,4] = 12
+   *   [2,0,4] = 0
+   *   [-1,3,4] = -1
+   */
   auto elem_cnt = output_tensor.GetTensorTypeAndShapeInfo().GetElementCount();
 
+  /*
+   * Copy the output data from the allocation buffer to the ORT kernel context output location, or
+   * cast it (int32 -> int64, float -> double) while writing to that location.
+   *
+   * Note:
+   * 1. If the output tensor is an empty tensor (i.e. any dimension is 0), its element count is 0, and
+   *    TRT EP performs neither the CUDA memory copy nor the CUDA cast, to avoid overwriting locations that might belong to other tensors.
+   * 2. cudaMemcpyAsync() and cuda::Impl_Cast() (implemented as _UnaryElementWise() in the CUDA EP) are asynchronous, but we
+   *    don't need to explicitly call cudaStreamSynchronize() after them because the CUDA EP and TRT EP use the same stream,
+   *    and within the same stream, operations are guaranteed to execute in order.
+   */
   switch (output_type) {
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: {
-      auto output_tensor_ptr = output_tensor.GetTensorMutableData<float>();
-      if (output_tensor_ptr != nullptr) {
-        CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(float), cudaMemcpyDeviceToDevice, stream));
-      }
-      break;
-    }
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: {
-      auto output_tensor_ptr = output_tensor.GetTensorMutableData<uint16_t>();
-      if (output_tensor_ptr != nullptr) {
-        CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(uint16_t), cudaMemcpyDeviceToDevice, stream));
-      }
-      break;
-    }
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: {
-      auto output_tensor_ptr = output_tensor.GetTensorMutableData<bool>();
-      if (output_tensor_ptr != nullptr) {
-        CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(bool), cudaMemcpyDeviceToDevice, stream));
-      }
-      break;
-    }
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: {
-      auto output_tensor_ptr = output_tensor.GetTensorMutableData<int8_t>();
-      if (output_tensor_ptr != nullptr) {
-        CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(int8_t), cudaMemcpyDeviceToDevice, stream));
-      }
-      break;
-    }
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: {
-      auto output_tensor_ptr = output_tensor.GetTensorMutableData<uint8_t>();
-      if (output_tensor_ptr != nullptr) {
-        CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(uint8_t), cudaMemcpyDeviceToDevice, stream));
-      }
-      break;
-    }
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
-      auto output_tensor_ptr = output_tensor.GetTensorMutableData<int32_t>();
-      if (output_tensor_ptr != nullptr) {
-        CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(int32_t), cudaMemcpyDeviceToDevice, stream));
-      }
-      break;
-    }
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
-      // The allocation buffer holds the INT32 output data since TRT doesn't support INT64 but INT32.
-      // So, we need to cast the data from INT32 to INT64 and then set INT64 output data to kernel context.
-      SafeInt<int> output_dim_size(1);
-      for (size_t i = 0; i < shape.size(); ++i) {
-        if (shape[i] == 0) {
-          output_dim_size = 1;
-          break;
-        } else {
-          output_dim_size *= shape[i];
-        }
-      }
-      auto output_tensor_ptr = output_tensor.GetTensorMutableData<int64_t>();
-      if (output_tensor_ptr != nullptr) {
-        cuda::Impl_Cast<int32_t, int64_t>(stream, reinterpret_cast<int32_t*>(allocator->getBuffer()), reinterpret_cast<int64_t*>(output_tensor_ptr), output_dim_size);
-      }
-      break;
-    }
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: {
-      // The allocation buffer holds the FLOAT output data since TRT doesn't support DOUBLE but FLOAT.
-      // So, we need to cast the data from FLOAT to DOUBEL and then set DOUBLE output data to kernel context.
-      SafeInt<int> output_dim_size(1);
-      for (size_t i = 0; i < shape.size(); ++i) {
-        if (shape[i] == 0) {
-          output_dim_size = 1;
-          break;
-        } else {
-          output_dim_size *= shape[i];
-        }
-      }
-      auto output_tensor_ptr = output_tensor.GetTensorMutableData<double>();
-      if (output_tensor_ptr != nullptr) {
-        cuda::Impl_Cast<float, double>(stream, reinterpret_cast<float*>(allocator->getBuffer()), reinterpret_cast<double*>(output_tensor_ptr), output_dim_size);
-      }
-      break;
-    }
+    CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float)
+    CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t)
+    CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool)
+    CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
+    CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
+    CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
+    // The allocation buffer holds the int32 output data since TRT doesn't support int64. So, we need to cast the data (int32 -> int64) for ORT kernel output.
+    CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int32_t, int64_t)
+    // The allocation buffer holds the float output data since TRT doesn't support double. So, we need to cast the data (float -> double) for ORT kernel output.
+    CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, float, double)
     default: {
       return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                              "TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported.");
     }
   }
-  CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
   return Status::OK();
 }
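The per-type cases removed above are folded into two helper macros whose definitions are not part of this hunk. A minimal sketch of what they presumably expand to, reconstructed from the removed code; whether the cast variant receives elem_cnt or a separately computed dimension product is an assumption.

// Sketch only, not part of the diff: likely shape of the helpers used in the switch above.
#define CASE_COPY_TENSOR(DATA_TYPE, SrcT)                                                               \
  case DATA_TYPE: {                                                                                     \
    auto output_tensor_ptr = output_tensor.GetTensorMutableData<SrcT>();                                \
    if (output_tensor_ptr != nullptr) {                                                                 \
      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(),                   \
                                           elem_cnt * sizeof(SrcT), cudaMemcpyDeviceToDevice, stream)); \
    }                                                                                                   \
    break;                                                                                              \
  }

#define CASE_CAST_TENSOR(DATA_TYPE, SrcT, DstT)                                                         \
  case DATA_TYPE: {                                                                                     \
    auto output_tensor_ptr = output_tensor.GetTensorMutableData<DstT>();                                \
    if (output_tensor_ptr != nullptr) {                                                                 \
      cuda::Impl_Cast<SrcT, DstT>(stream, reinterpret_cast<SrcT*>(allocator->getBuffer()),              \
                                  reinterpret_cast<DstT*>(output_tensor_ptr), elem_cnt);                \
    }                                                                                                   \
    break;                                                                                              \
  }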
 
@@ -1290,7 +1147,8 @@ TensorrtExecutionProvider::PerThreadContext& TensorrtExecutionProvider::GetPerTh
 
     // get or create a context
     if (context_state_.retired_context_pool.empty()) {
-      context = std::make_shared<PerThreadContext>(info_.device_id, info_.has_user_compute_stream, stream_);
+      context = std::make_shared<PerThreadContext>(narrow<OrtDevice::DeviceId>(info_.device_id),
+                                                   info_.has_user_compute_stream, stream_);
     } else {
       context = context_state_.retired_context_pool.back();
       context_state_.retired_context_pool.pop_back();
@@ -1310,7 +1168,11 @@ TensorrtExecutionProvider::PerThreadContext& TensorrtExecutionProvider::GetPerTh
 }
 
 TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProviderInfo& info)
-    : IExecutionProvider{onnxruntime::kTensorrtExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, info_(info), device_id_(info.device_id) {
+    : IExecutionProvider{onnxruntime::kTensorrtExecutionProvider,
+                         OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT,
+                                   narrow<OrtDevice::DeviceId>(info.device_id))},
+      info_(info),
+      device_id_(info.device_id) {
   InitProviderOrtApi();
 
   CUDA_CALL_THROW(cudaSetDevice(device_id_));
@@ -1771,26 +1633,26 @@ bool TensorrtExecutionProvider::IsGraphCaptureAllowed() const {
   return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_;
 }
 
-void TensorrtExecutionProvider::CaptureBegin() {
+void TensorrtExecutionProvider::CaptureBegin(int) {
   cuda_graph_.Reset();
-  cuda_graph_.CaptureBegin();
+  cuda_graph_.CaptureBegin(0);
 }
 
-void TensorrtExecutionProvider::CaptureEnd() {
-  cuda_graph_.CaptureEnd();
+void TensorrtExecutionProvider::CaptureEnd(int) {
+  cuda_graph_.CaptureEnd(0);
   is_graph_captured_ = true;
 }
 
-bool TensorrtExecutionProvider::IsGraphCaptured() const {
+bool TensorrtExecutionProvider::IsGraphCaptured(int) const {
   return is_graph_captured_;
 }
 
-Status TensorrtExecutionProvider::ReplayGraph() {
-  ORT_ENFORCE(IsGraphCaptured());
+Status TensorrtExecutionProvider::ReplayGraph(int) {
+  ORT_ENFORCE(IsGraphCaptured(0));
   // Please note that CUDAGraph::Replay() is not thread safe.
-  // ORT TRT calls ReplayGraph() in compute_func() where synchromization is enforced due to lock_guard(),
+  // ORT TRT calls ReplayGraph() in compute_func() where synchronization is enforced due to lock_guard(),
   // therefore calling CUDAGraph::Replay() here is guaranteed to be thread safe.
-  return cuda_graph_.Replay();
+  return cuda_graph_.Replay(0);
 }
 
 void TensorrtExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() {
@@ -1802,7 +1664,8 @@ void TensorrtExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() {
 
 std::vector<AllocatorPtr> TensorrtExecutionProvider::CreatePreferredAllocators() {
   AllocatorCreationInfo default_memory_info(
-      [](OrtDevice::DeviceId device_id) { return CreateCUDAAllocator(device_id, onnxruntime::CUDA); }, device_id_);
+      [](OrtDevice::DeviceId device_id) { return CreateCUDAAllocator(device_id, onnxruntime::CUDA); },
+      narrow<OrtDevice::DeviceId>(device_id_));
 
   AllocatorCreationInfo pinned_allocator_info(
       [](OrtDevice::DeviceId device_id) {
@@ -1818,11 +1681,11 @@ std::unique_ptr<IDataTransfer> TensorrtExecutionProvider::GetDataTransfer() cons
   return onnxruntime::CreateGPUDataTransfer();
 }
 
-Status TensorrtExecutionProvider::OnRunStart() {
+Status TensorrtExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) {
   return Status::OK();
 }
 
-Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) {
+Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) {
   if (sync_stream && external_stream_) {
     CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_));
   }
@@ -3183,7 +3046,8 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
     std::unordered_set<std::string> input_names;
     std::unordered_map<std::string, std::vector<int32_t>> tensor_shape_values;
 
-    OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id_), device_id_);
+    OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, narrow<OrtDevice::DeviceId>(device_id_));
+    OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device, device_id_);
     if (alloc_ == nullptr) {
       Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_));
     }
@@ -3513,7 +3377,6 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
     output_tensors.reserve(num_outputs);
     std::unordered_map<size_t, int> output_dim_sizes;
     output_dim_sizes.reserve(num_outputs);
-    std::unordered_set<char const*> dds_output_set;
 
     for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
       char const* output_name = output_binding_names[i];
@@ -3531,7 +3394,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
       }
 
       Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes,
-                                        dds_output_set, dds_output_allocator_map, scratch_buffers, alloc, buffers);
+                                        dds_output_allocator_map, scratch_buffers, alloc, buffers);
       if (status != Status::OK()) {
         return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
       }
@@ -3549,10 +3412,10 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
     // Start CUDA graph capture.
     // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because
     // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream.
-    if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) {
+    if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) {
       LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model";
       cuda_graph_.SetStream(stream);
-      CaptureBegin();
+      CaptureBegin(0);
     }
 
     // Run TRT inference
@@ -3590,7 +3453,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
         output_type = iter->second;
       }
 
-      if (dds_output_set.find(output_name) != dds_output_set.end()) {
+      if (dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end()) {
         size_t output_index = 0;
         const auto& index_iter = output_indexes.find(output_name);
         if (index_iter != output_indexes.end()) {
@@ -3620,12 +3483,12 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
     // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture
     // above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc.
     // It's safe to start/end CUDA graph capture in compute_func() here since cuda graph object is maintained by a per thread basis.
-    if (cuda_graph_enable_ && !IsGraphCaptured()) {
+    if (cuda_graph_enable_ && !IsGraphCaptured(0)) {
       if (IsGraphCaptureAllowed()) {
-        CaptureEnd();
+        CaptureEnd(0);
         // CUDA work issued to a capturing stream doesn’t actually run on the GPU,
         // so run the captured graph here to actually execute the work.
-        ORT_RETURN_IF_ERROR(ReplayGraph());
+        ORT_RETURN_IF_ERROR(ReplayGraph(0));
       } else {
         IncrementRegularRunCountBeforeGraphCapture();
       }
@@ -3751,7 +3614,8 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con
     // int num_inputs = static_cast<int>(input_indexes.size());
     int num_outputs = static_cast<int>(output_indexes.size());
 
-    OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id_), device_id_);
+    OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, narrow<OrtDevice::DeviceId>(device_id_));
+    OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device, device_id_);
     if (alloc_ == nullptr) {
       Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_));
     }
@@ -3806,7 +3670,6 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con
     output_tensors.reserve(num_outputs);
     std::unordered_map<size_t, int> output_dim_sizes;
     output_dim_sizes.reserve(num_outputs);
-    std::unordered_set<char const*> dds_output_set;
 
     for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
       char const* output_name = output_binding_names[i];
@@ -3824,7 +3687,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con
       }
 
       Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes,
-                                        dds_output_set, dds_output_allocator_map, scratch_buffers, alloc, buffers);
+                                        dds_output_allocator_map, scratch_buffers, alloc, buffers);
       if (status != Status::OK()) {
         return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
       }
@@ -3842,10 +3705,10 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con
     // Start CUDA graph capture.
     // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because
     // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream.
-    if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) {
+    if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) {
       LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model";
       cuda_graph_.SetStream(stream);
-      CaptureBegin();
+      CaptureBegin(0);
     }
 
     // Run TRT inference
@@ -3883,7 +3746,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con
         output_type = iter->second;
       }
 
-      if (dds_output_set.find(output_name) != dds_output_set.end()) {
+      if (dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end()) {
         size_t output_index = 0;
         const auto& index_iter = output_indexes.find(output_name);
         if (index_iter != output_indexes.end()) {
@@ -3913,12 +3776,12 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con
     // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture
     // above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc.
     // It's safe to start/end CUDA graph capture in compute_func() here since cuda graph object is maintained by a per thread basis.
-    if (cuda_graph_enable_ && !IsGraphCaptured()) {
+    if (cuda_graph_enable_ && !IsGraphCaptured(0)) {
       if (IsGraphCaptureAllowed()) {
-        CaptureEnd();
+        CaptureEnd(0);
         // CUDA work issued to a capturing stream doesn’t actually run on the GPU,
         // so run the captured graph here to actually execute the work.
-        ORT_RETURN_IF_ERROR(ReplayGraph());
+        ORT_RETURN_IF_ERROR(ReplayGraph(0));
       } else {
         IncrementRegularRunCountBeforeGraphCapture();
       }
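The capture, end, and replay calls above always pass annotation id 0, so the TRT EP keeps a single graph per thread for now. A hedged sketch of turning this path on from application code, assuming the "trt_cuda_graph_enable" provider option and an illustrative model path:

#include <onnxruntime_cxx_api.h>

// Sketch only: create a session with the TRT EP CUDA-graph path enabled.
Ort::Session MakeTrtCudaGraphSession(Ort::Env& env) {
  const OrtApi& api = Ort::GetApi();

  OrtTensorRTProviderOptionsV2* trt_options = nullptr;
  Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&trt_options));
  const char* keys[] = {"trt_cuda_graph_enable"};
  const char* values[] = {"1"};
  Ort::ThrowOnError(api.UpdateTensorRTProviderOptions(trt_options, keys, values, 1));

  Ort::SessionOptions so;
  so.AppendExecutionProvider_TensorRT_V2(*trt_options);  // options are copied into the session options
  api.ReleaseTensorRTProviderOptions(trt_options);

  return Ort::Session(env, ORT_TSTR("model.onnx"), so);
}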
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
index e86f997b6597a..f73031eaefceb 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -5,8 +5,9 @@
 #include <ctime>
 #include <cudnn.h>
 #include <cublas_v2.h>
-#include "NvInfer.h"
-#include "NvOnnxParser.h"
+
+#include "core/providers/tensorrt/nv_includes.h"
+
 #include "core/platform/ort_mutex.h"
 #include "core/providers/cuda/cuda_graph.h"
 #include "tensorrt_execution_provider_info.h"
@@ -233,8 +234,8 @@ class TensorrtExecutionProvider : public IExecutionProvider {
   common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
                          std::vector<NodeComputeInfo>& node_compute_funcs) override;
 
-  Status OnRunStart() override;
-  Status OnRunEnd(bool sync_stream) override;
+  Status OnRunStart(const onnxruntime::RunOptions& run_options) override;
+  Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override;
 
   ProviderOptions GetProviderOptions() const override {
     return TensorrtExecutionProviderInfo::ToProviderOptions(info_);
@@ -249,8 +250,8 @@ class TensorrtExecutionProvider : public IExecutionProvider {
   std::vector<AllocatorPtr> CreatePreferredAllocators() override;
 
   bool IsGraphCaptureEnabled() const override;
-  bool IsGraphCaptured() const override;
-  Status ReplayGraph() override;
+  bool IsGraphCaptured(int graph_annotation_id) const override;
+  Status ReplayGraph(int graph_annotation_id) override;
 
  private:
   mutable TensorrtExecutionProviderInfo info_;
@@ -372,10 +373,10 @@ class TensorrtExecutionProvider : public IExecutionProvider {
     void InitCUDAGraph();
     void SetGraphStream(cudaStream_t stream);
     bool IsGraphCaptureAllowed() const;
-    void CaptureBegin();
-    void CaptureEnd();
-    bool IsGraphCaptured() const;
-    Status ReplayGraph();
+    void CaptureBegin(int graph_annotation_id);
+    void CaptureEnd(int graph_annotation_id);
+    bool IsGraphCaptured(int graph_annotation_id) const;
+    Status ReplayGraph(int graph_annotation_id);
     void IncrementRegularRunCountBeforeGraphCapture();
 
    private:
@@ -539,8 +540,8 @@ class TensorrtExecutionProvider : public IExecutionProvider {
                                         std::vector<NodeComputeInfo>& node_compute_funcs);
 
   bool IsGraphCaptureAllowed() const;
-  void CaptureBegin();
-  void CaptureEnd();
+  void CaptureBegin(int graph_annotation_id);
+  void CaptureEnd(int graph_annotation_id);
   void IncrementRegularRunCountBeforeGraphCapture();
 
   /**
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc
index eb340ba1e64b6..b4f348159440f 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc
@@ -1,12 +1,11 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#include <unordered_set>
+
 #include "core/framework/provider_options.h"
 #include "tensorrt_execution_provider_custom_ops.h"
 #include "tensorrt_execution_provider.h"
-#include <NvInferRuntime.h>
-#include <NvInferPlugin.h>
-#include <unordered_set>
 
 namespace onnxruntime {
 extern TensorrtLogger& GetTensorrtLogger();
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h
index b19d9ab0f66d0..54212d34aa2ce 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h
@@ -13,7 +13,8 @@ using namespace onnxruntime;
 namespace onnxruntime {
 
 common::Status LoadDynamicLibrary(onnxruntime::PathString library_name);
-common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list, const std::string extra_plugin_lib_paths);
+common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list,
+                                                const std::string extra_plugin_lib_paths);
 common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info);
 void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain);
 void ReleaseTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list);
@@ -23,16 +24,22 @@ struct TensorRTCustomKernel {
       : compute_stream_(compute_stream) {
   }
 
-  void Compute(OrtKernelContext* context){};  // The implementation is in TensorRT plugin. No need to implement it here.
+  void Compute(OrtKernelContext* /*context*/){
+      // The implementation is in the TensorRT plugin. No need to implement it here.
+  };
 
  private:
   void* compute_stream_;
 };
 
 struct TensorRTCustomOp : Ort::CustomOpBase<TensorRTCustomOp, TensorRTCustomKernel> {
-  explicit TensorRTCustomOp(const char* provider, void* compute_stream) : provider_(provider), compute_stream_(compute_stream) {}
+  explicit TensorRTCustomOp(const char* provider, void* compute_stream) : provider_(provider),
+                                                                          compute_stream_(compute_stream) {
+  }
 
-  void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* info) const { return new TensorRTCustomKernel(info, compute_stream_); };
+  void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* info) const {
+    return new TensorRTCustomKernel(info, compute_stream_);
+  };
 
   const char* GetName() const { return name_; };
 
@@ -46,7 +53,9 @@ struct TensorRTCustomOp : Ort::CustomOpBase<TensorRTCustomOp, TensorRTCustomKern
 
   ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; };
 
-  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const { return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC; };
+  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const {
+    return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC;
+  };
 
   size_t GetOutputTypeCount() const { return num_outputs_; };
 
@@ -54,7 +63,9 @@ struct TensorRTCustomOp : Ort::CustomOpBase<TensorRTCustomOp, TensorRTCustomKern
 
   ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; };
 
-  OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t) const { return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC; };
+  OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t) const {
+    return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC;
+  };
 
   bool GetVariadicInputHomogeneity() const {
     return false;  // heterogenous
diff --git a/onnxruntime/core/providers/utils.cc b/onnxruntime/core/providers/utils.cc
index ca3fc4fc1972b..b2f9d265ca053 100644
--- a/onnxruntime/core/providers/utils.cc
+++ b/onnxruntime/core/providers/utils.cc
@@ -2,7 +2,7 @@
 // Licensed under the MIT License.
 
 #include "core/framework/tensorprotoutils.h"
-#include "utils.h"
+#include "core/providers/utils.h"
 
 namespace onnxruntime {
 namespace utils {
@@ -23,6 +23,5 @@ common::Status OutputOptionalWithoutDataHelper(const ONNX_NAMESPACE::TypeProto&
   return Status::OK();
 }
 #endif
-
 }  // namespace utils
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc
index f609d40f459b7..eba3230d283cf 100644
--- a/onnxruntime/core/providers/vitisai/imp/global_api.cc
+++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc
@@ -188,7 +188,7 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
     auto file_path = ToPathString(filename);
     auto status = Model::Load(file_path, *model_proto);
     vai_assert(status.IsOK(), "load model proto error");
-    auto model = Model::Create(std::move(*model_proto), file_path, logger);
+    auto model = Model::Create(std::move(*model_proto), file_path, nullptr, logger);
     return model.release();
   };
   the_global_api.model_delete = [](Model* model) { delete model; };
@@ -198,7 +198,8 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
     auto& model = const_cast<onnxruntime::Model&>(const_model);
     auto model_proto = model.ToProto();
     auto file_path = model.MainGraph().ModelPath().ToPathString();
-    auto ret = Model::Create(std::move(*model_proto), file_path, logger);
+    auto local_registries = IOnnxRuntimeOpSchemaRegistryList{model.MainGraph().GetSchemaRegistry()};
+    auto ret = Model::Create(std::move(*model_proto), file_path, &local_registries, logger);
     auto status = ret->MainGraph().Resolve();
     vai_assert(status.IsOK(), status.ErrorMessage());
     return ret.release();
diff --git a/onnxruntime/core/providers/vitisai/imp/node.cc b/onnxruntime/core/providers/vitisai/imp/node.cc
index 0565171fb7f40..432d7f7daead2 100644
--- a/onnxruntime/core/providers/vitisai/imp/node.cc
+++ b/onnxruntime/core/providers/vitisai/imp/node.cc
@@ -34,9 +34,17 @@ vaip_core::DllSafe<std::vector<const NodeArg*>> node_get_output_node_args(const
   auto ret = std::vector<const NodeArg*>(size);
   for (auto i = 0u; i < size; ++i) {
     auto output = outputs[i];
-    ret[i] = output;
     assert(output != nullptr);
-    vai_assert(output->Exists(), std::string("output must exists. name=" + output->Name()));
+    // Optional outputs:
+    // Some operators have optional outputs. When an actual output parameter of an operator is not specified, the operator implementation MAY forgo computing values for such outputs.
+    // There are two ways to leave an optional input or output unspecified: the first, available only for trailing inputs and outputs, is to simply not provide that input; the second is to use an empty string in place of the input or output name.
+    // So for an unspecified optional output, output != nullptr while output->Exists() returns false.
+    // Our handling: nullptr means an unspecified optional output, and client code needs to handle nullptr.
+    if (output->Exists()) {
+      ret[i] = output;
+    } else {
+      ret[i] = nullptr;
+    }
   }
   return vaip_core::DllSafe(ret);
 }
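Given the new contract, callers of node_get_output_node_args need a null check before touching an entry; a brief illustrative fragment (the container access pattern is assumed, the null check is the point):

// Illustrative only: consumers must now tolerate nullptr entries in the returned vector.
for (const NodeArg* arg : output_args) {
  if (arg == nullptr) {
    continue;  // unspecified optional output
  }
  // ... use arg->Name(); arg->Exists() is guaranteed true for non-null entries ...
}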
diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc
index 48dcd220a150c..671d852abb0d6 100644
--- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc
+++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc
@@ -22,6 +22,7 @@ gsl::span<const char> tensor_proto_as_raw(const ONNX_NAMESPACE::TensorProto& ten
     mut_tensor.clear_double_data();
     mut_tensor.clear_uint64_data();
     memcpy(mut_tensor.mutable_raw_data()->data(), unpacked_tensor.data(), unpacked_tensor.size());
+    mut_tensor.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_DEFAULT);
   }
   return gsl::span<const char>(tensor.raw_data().data(), tensor.raw_data().size());
 }
diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h
index d94729e60d029..d7892fe02c1ba 100644
--- a/onnxruntime/core/providers/webnn/builders/helper.h
+++ b/onnxruntime/core/providers/webnn/builders/helper.h
@@ -195,7 +195,7 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
     {"LessOrEqual", {"lesserOrEqual", false}},
     {"Log", {"log", false}},
     {"LpPool", {"l2Pool2d", false}},
-    {"MatMul", {"matmul", false}},
+    {"MatMul", {"matmul", true}},
     {"MatMulInteger", {"matmulInteger", false}},
     {"Max", {"max", true}},
     {"MaxPool", {"maxPool2d", true}},
diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc
index 4bf991a1b0105..ed320132169e9 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc
@@ -29,7 +29,7 @@ class GemmOpBuilder : public BaseOpBuilder {
 
 // Add operator related.
 Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
-                                            const logging::Logger& /* logger */) const {
+                                            const logging::Logger& logger) const {
   const auto& op_type = node.OpType();
   const auto& input_defs = node.InputDefs();
   const size_t a_idx = 0, b_idx = 1, c_idx = 2;  // A*B+C
@@ -38,7 +38,58 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
   emscripten::val b = model_builder.GetOperand(node.InputDefs()[b_idx]->Name());
   emscripten::val output = emscripten::val::object();
   if (op_type == "MatMul") {
-    output = model_builder.GetBuilder().call<emscripten::val>("matmul", a, b);
+    std::vector<int64_t> a_shape;
+    if (!GetShape(*input_defs[a_idx], a_shape, logger)) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Can not get shape of A.");
+    }
+    std::vector<int64_t> b_shape;
+    if (!GetShape(*input_defs[b_idx], b_shape, logger)) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Can not get shape of B.");
+    }
+    // If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its dimensions.
+    bool extended_a_shape = false;
+    if (a_shape.size() == 1) {
+      extended_a_shape = true;
+      a_shape.insert(a_shape.begin(), 1);
+      a = model_builder.GetBuilder().call<emscripten::val>("reshape", a,
+                                                           emscripten::val::array(GetVecUint32FromVecInt64(a_shape)));
+    }
+    // If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions.
+    bool extended_b_shape = false;
+    if (b_shape.size() == 1) {
+      extended_b_shape = true;
+      b_shape.push_back(1);
+      b = model_builder.GetBuilder().call<emscripten::val>("reshape", b,
+                                                           emscripten::val::array(GetVecUint32FromVecInt64(b_shape)));
+    }
+    // The inputs of MatMul must be at least 3D for WebNN CPU backend. Use GEMM for 2D case.
+    // TODO: Remove this workaround when it is fixed in Chromium.
+    if (model_builder.GetWebnnDeviceType() == WebnnDeviceType::CPU && a_shape.size() == 2) {
+      output = model_builder.GetBuilder().call<emscripten::val>("gemm", a, b);
+    } else {
+      output = model_builder.GetBuilder().call<emscripten::val>("matmul", a, b);
+    }
+    // If the inputs are both 1D, reduce the output to a scalar.
+    if (extended_a_shape && extended_b_shape) {
+      output = model_builder.GetBuilder().call<emscripten::val>("reshape", output, emscripten::val::array());
+    }
+    // After matrix multiplication the prepended 1 is removed.
+    else if (extended_a_shape) {
+      std::vector<uint32_t> new_shape;
+      for (size_t i = 0; i < b_shape.size() - 2; i++) {
+        new_shape.push_back(narrow<uint32_t>(b_shape[i]));
+      }
+      new_shape.push_back(narrow<uint32_t>(b_shape.back()));
+      output = model_builder.GetBuilder().call<emscripten::val>("reshape", output, emscripten::val::array(new_shape));
+    }
+    // After matrix multiplication the appended 1 is removed.
+    else if (extended_b_shape) {
+      std::vector<uint32_t> new_shape;
+      for (size_t i = 0; i < a_shape.size() - 1; i++) {
+        new_shape.push_back(narrow<uint32_t>(a_shape[i]));
+      }
+      output = model_builder.GetBuilder().call<emscripten::val>("reshape", output, emscripten::val::array(new_shape));
+    }
   } else if (op_type == "MatMulInteger") {
     emscripten::val a_zero_point = emscripten::val::null();
     emscripten::val b_zero_point = emscripten::val::null();
@@ -81,44 +132,33 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
 
 bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
                                       const Node& node,
-                                      const WebnnDeviceType /* device_type */,
+                                      const WebnnDeviceType device_type,
                                       const logging::Logger& logger) const {
   (void)initializers;
   const auto& op_type = node.OpType();
   const auto& input_defs(node.InputDefs());
   const size_t a_idx = 0, b_idx = 1, c_idx = 2;  // A*B+C
 
-  if (op_type == "Gemm") {
-    std::vector<int64_t> a_shape;
-    {
-      if (!GetShape(*input_defs[a_idx], a_shape, logger))
-        return false;
-
-      if (a_shape.size() != 2) {
-        LOGS(logger, VERBOSE) << "A must be 2D";
-        return false;
-      }
-
-      if (Product(a_shape) == 0) {
-        LOGS(logger, VERBOSE) << "A must be non-empty";
-        return false;
-      }
-    }
-
-    std::vector<int64_t> b_shape;
-    {
-      if (!GetShape(*input_defs[b_idx], b_shape, logger))
-        return false;
+  std::vector<int64_t> a_shape;
+  if (!GetShape(*input_defs[a_idx], a_shape, logger))
+    return false;
+  if (Product(a_shape) == 0) {
+    LOGS(logger, VERBOSE) << "A must be non-empty";
+    return false;
+  }
 
-      if (b_shape.size() != 2) {
-        LOGS(logger, VERBOSE) << "B must be 2D";
-        return false;
-      }
+  std::vector<int64_t> b_shape;
+  if (!GetShape(*input_defs[b_idx], b_shape, logger))
+    return false;
+  if (Product(b_shape) == 0) {
+    LOGS(logger, VERBOSE) << "B must be non-empty";
+    return false;
+  }
 
-      if (Product(b_shape) == 0) {
-        LOGS(logger, VERBOSE) << "B must be non-empty";
-        return false;
-      }
+  if (op_type == "Gemm") {
+    if (a_shape.size() != 2 || b_shape.size() != 2) {
+      LOGS(logger, VERBOSE) << "A and B must be 2D for Gemm";
+      return false;
     }
 
     // C of Gemm.
@@ -152,6 +192,30 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
     }
   }
 
+  if (op_type == "MatMul") {
+    // If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its dimensions.
+    // If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions.
+    if (a_shape.size() == 1) a_shape.insert(a_shape.begin(), 1);
+    if (b_shape.size() == 1) b_shape.push_back(1);
+
+    // WebNN CPU backend has two more constraints.
+    // https://source.chromium.org/chromium/chromium/src/+/main:third_party/blink/renderer/modules/ml/webnn/ml_graph_xnnpack.cc;l=1177
+    // TODO: Remove this workaround when Chromium enables broadcast for MatMul on WebNN CPU backend.
+    if (device_type == WebnnDeviceType::CPU) {
+      if (a_shape.size() != b_shape.size()) {
+        LOGS(logger, VERBOSE) << "The rank of two inputs for WebNN CPU backend MatMul must be the same.";
+        return false;
+      }
+
+      for (size_t i = 0; i < a_shape.size() - 2; i++) {
+        if (a_shape[i] != b_shape[i]) {
+          LOGS(logger, VERBOSE) << "WebNN CPU backend can't support broadcasting for MatMul.";
+          return false;
+        }
+      }
+    }
+  }
+
   return true;
 }
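The promotion and reshape handling added above follows ONNX/NumPy MatMul semantics. A standalone helper, not taken from ORT, that computes the expected output shape under the same rules may make the intent easier to follow; batch broadcasting is deliberately omitted, matching the WebNN CPU constraint checked above.

#include <cstdint>
#include <vector>

// Standalone illustration of the rules above: a 1-D A gets a leading 1, a 1-D B gets a
// trailing 1, and the padded dimension is dropped from the result afterwards.
std::vector<int64_t> MatMulOutputShape(std::vector<int64_t> a, std::vector<int64_t> b) {
  const bool extended_a = a.size() == 1;
  const bool extended_b = b.size() == 1;
  if (extended_a) a.insert(a.begin(), 1);  // [K] -> [1, K]
  if (extended_b) b.push_back(1);          // [K] -> [K, 1]

  // Batch dims come from the higher-rank input; broadcasting is not handled here.
  const auto& batch_src = (a.size() >= b.size()) ? a : b;
  std::vector<int64_t> out(batch_src.begin(), batch_src.end() - 2);

  if (!extended_a) out.push_back(a[a.size() - 2]);  // M (dropped if A was 1-D)
  if (!extended_b) out.push_back(b[b.size() - 1]);  // N (dropped if B was 1-D)
  return out;                                       // both 1-D -> empty shape, i.e. a scalar
}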
 
diff --git a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc
index 52b5518857773..9852db0abc9d2 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc
@@ -88,15 +88,15 @@ Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
     const auto& pads_tensor = *initializers.at(input_defs[1]->Name());
     ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(pads_tensor, pads, logger), "Error while read pads tensor");
 
-    // Constant value and axes are optional.
-    if (input_defs.size() >= 3) {
+    // Constant value and axes are optional. Make sure they are not empty.
+    if (!GetTensorName(input_defs, 2).empty()) {
       const auto value_tensor = *initializers.at(input_defs[2]->Name());
       emscripten::val value = emscripten::val::object();
       ORT_RETURN_IF_NOT(ReadScalarTensorData(value_tensor, value, logger), "Cannot read constant value");
       options.set("value", value);
     }
 
-    if (input_defs.size() == 4) {
+    if (!GetTensorName(input_defs, 3).empty()) {
       const auto input_rank = input_shape.size();
       std::vector<int64_t> axes;
       const auto& axes_tensor = *initializers.at(input_defs[3]->Name());
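This and the following WebNN builders switch from input_defs.size() checks to GetTensorName(input_defs, i).empty(), so that a trailing optional input supplied as an empty name is also treated as absent. The helper's definition is outside this diff; a plausible sketch under that assumption:

// Assumed sketch of the helper used above: an out-of-range index or an omitted optional
// input (empty NodeArg name) both come back as an empty string.
inline std::string GetTensorName(const ConstPointerContainer<std::vector<NodeArg*>>& input_defs,
                                 size_t index) {
  return index < input_defs.size() ? input_defs[index]->Name() : std::string();
}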
diff --git a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc
index f446a7b81d1c0..c0954f7cf6fb1 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc
@@ -65,7 +65,7 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   if (opset >= 18 || (op_type == "ReduceSum" && opset >= 13)) {
     // 'axes' is an optional input.
     const auto noop_with_empty_axes = helper.Get("noop_with_empty_axes", 0);
-    if (input_defs.size() > 1) {
+    if (!GetTensorName(input_defs, 1).empty()) {
       // Optional input axes is provided, use axes initializer data.
       const auto& initializers(model_builder.GetInitializerTensors());
       const auto& axes_tensor = *initializers.at(input_defs[1]->Name());
diff --git a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc
index 91f21b196be54..9819e4ce7ac5b 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc
@@ -57,7 +57,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   axis = SafeInt<int32_t>(HandleNegativeAxis(axis, rank));
   options.set("axis", axis);
 
-  if (input_defs.size() == 2) {
+  if (!GetTensorName(input_defs, 1).empty()) {
     // Inputs contains optional 'split' input
     std::vector<int32_t> splits;
     const auto& initializers(model_builder.GetInitializerTensors());
diff --git a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc
index 15149bd8fe821..8e6feb62fa8c4 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc
@@ -58,7 +58,7 @@ Status SqueezeUnsqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil
   std::vector<int32_t> axes_data;
   auto rank = input_rank;
 
-  if (node.SinceVersion() >= 13 && input_defs.size() > 1) {
+  if (node.SinceVersion() >= 13 && !GetTensorName(input_defs, 1).empty()) {
     // Input axes is provided, use axes initializer data.
     const auto& initializers = model_builder.GetInitializerTensors();
     const auto& axes_tensor = *initializers.at(input_defs[1]->Name());
diff --git a/onnxruntime/core/providers/xnnpack/detail/node_support_checker.cc b/onnxruntime/core/providers/xnnpack/detail/node_support_checker.cc
index 8e7e228f974e6..e2d71cda68ec4 100644
--- a/onnxruntime/core/providers/xnnpack/detail/node_support_checker.cc
+++ b/onnxruntime/core/providers/xnnpack/detail/node_support_checker.cc
@@ -6,12 +6,12 @@
 #include <unordered_map>
 
 #include "core/common/common.h"
+#include "core/framework/node_unit.h"
 #include "core/framework/op_node_proto_helper.h"
 #include "core/graph/graph_utils.h"
 #include "core/graph/graph_viewer.h"
 #include "core/providers/common.h"
 #include "core/providers/cpu/nn/pool_attributes.h"
-#include "core/providers/shared/node_unit/node_unit.h"
 #include "core/providers/xnnpack/detail/utils.h"
 
 // each operator provides a helper to check if supported
diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.cc b/onnxruntime/core/providers/xnnpack/detail/utils.cc
index 1a32612981120..f9cb45ebc8abc 100644
--- a/onnxruntime/core/providers/xnnpack/detail/utils.cc
+++ b/onnxruntime/core/providers/xnnpack/detail/utils.cc
@@ -6,14 +6,14 @@
 #include <vector>
 
 #include "core/common/common.h"
+#include "core/common/safeint.h"
+#include "core/framework/node_unit.h"
 #include "core/framework/tensorprotoutils.h"
 #include "core/graph/indexed_sub_graph.h"
 #include "core/graph/node_attr_utils.h"
+#include "core/optimizer/initializer.h"
 
-#include "core/providers/shared/node_unit/node_unit.h"
 #include "onnx/defs/attr_proto_util.h"
-#include "core/common/safeint.h"
-#include "core/optimizer/initializer.h"
 
 namespace onnxruntime {
 namespace xnnpack {
diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.h b/onnxruntime/core/providers/xnnpack/detail/utils.h
index 2bbf3ac8c2cb5..d555ee2286b84 100644
--- a/onnxruntime/core/providers/xnnpack/detail/utils.h
+++ b/onnxruntime/core/providers/xnnpack/detail/utils.h
@@ -10,10 +10,10 @@
 #include <string>
 #include <utility>
 
+#include "core/framework/node_unit.h"
 #include "core/framework/op_kernel.h"
 #include "core/graph/indexed_sub_graph.h"
 #include "core/providers/common.h"
-#include "core/providers/shared/node_unit/node_unit.h"
 
 #include "xnnpack.h"
 
diff --git a/onnxruntime/core/providers/xnnpack/tensor/resize.cc b/onnxruntime/core/providers/xnnpack/tensor/resize.cc
index 0c9e2e9fc17a2..09666c8039402 100644
--- a/onnxruntime/core/providers/xnnpack/tensor/resize.cc
+++ b/onnxruntime/core/providers/xnnpack/tensor/resize.cc
@@ -288,7 +288,7 @@ Status Resize::Compute(OpKernelContext* ctx) const {
 
     // Get scales data
     const auto* scales = ctx->Input<Tensor>(scales_input_idx_);
-    std::vector<float> scales_array(X->Shape().GetDims().size());
+    InlinedVector<float> scales_array(X->Shape().GetDims().size());
 
     if (scales != nullptr && scales->Shape().Size() != 0) {
       ORT_RETURN_IF_ERROR(ParseScalesData(scales, scales_array, output_shape.size()));
diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc
index eafbfae6f01e1..12e567e7080b3 100644
--- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc
+++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc
@@ -6,17 +6,17 @@
 #include <unordered_set>
 #include <utility>
 
-#include "core/graph/function_utils.h"
-#include "xnnpack_execution_provider.h"
-#include "detail/utils.h"
-#include "detail/node_support_checker.h"
-
 #include "core/framework/compute_capability.h"
 #include "core/framework/kernel_registry.h"
-#include "core/providers/shared/node_unit/node_unit.h"
+#include "core/framework/node_unit.h"
+#include "core/graph/function_utils.h"
 #include "core/session/onnxruntime_session_options_config_keys.h"
-
-#include "xnnpack_init.h"
+#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
+#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
+#include "core/providers/xnnpack/xnnpack_execution_provider.h"
+#include "core/providers/xnnpack/detail/utils.h"
+#include "core/providers/xnnpack/detail/node_support_checker.h"
+#include "core/providers/xnnpack/xnnpack_init.h"
 
 namespace onnxruntime {
 
@@ -268,7 +268,7 @@ std::vector<std::unique_ptr<ComputeCapability>> XnnpackExecutionProvider::GetCap
   // Get all the NodeUnits in the GraphViewer so we can check if something is in a QDQ node group
   std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
   std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
-  std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph);
+  std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph);
 
   // This holds the result of whether a NodeUnit is supported or not,
   // to prevent nodes in a NodeUnit being checked for multiple times
diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc
index 7a233c57cfdf3..513aafcdadb7d 100644
--- a/onnxruntime/core/session/custom_ops.cc
+++ b/onnxruntime/core/session/custom_ops.cc
@@ -736,6 +736,32 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetLogger, _In_ const OrtKernelInfo* inf
   });
 }
 
+ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out) {
+  return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
+    onnxruntime::AllocatorPtr allocator = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetAllocator(mem_type);
+    if (!allocator) {
+      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available");
+    }
+    auto p = std::make_unique<onnxruntime::OrtAllocatorImplWrappingIAllocator>(std::move(allocator));
+    *out = p.release();
+    return nullptr;
+  });
+}
+
+ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out) {
+  if (count_or_bytes == 0) {
+    *out = nullptr;
+    return nullptr;
+  }
+  onnxruntime::AllocatorPtr allocator = reinterpret_cast<const onnxruntime::OpKernelContext*>(context)->GetAllocator(mem_info->device);
+  if (!allocator) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available");
+  }
+  onnxruntime::Stream* stream = reinterpret_cast<const onnxruntime::OpKernelContext*>(context)->GetComputeStream();
+  *out = AllocateBufferWithOptions(*allocator, count_or_bytes, false, stream, stream->GetWaitNotificationFn());
+  return nullptr;
+};
+
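The two entry points added above give custom-op kernels access to allocators and stream-ordered scratch memory through the C API. A hedged usage sketch from inside a kernel's Compute; error handling is omitted, and each call actually returns an OrtStatus* that should be checked and released:

#include <onnxruntime_c_api.h>

// Hedged sketch: how a custom op kernel might request scratch memory on its compute stream.
// The memory-info name "Cuda" and the byte count are illustrative.
void UseScratchBuffer(OrtKernelContext* context) {
  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);

  OrtMemoryInfo* mem_info = nullptr;
  api->CreateMemoryInfo("Cuda", OrtDeviceAllocator, /*device_id*/ 0, OrtMemTypeDefault, &mem_info);

  void* scratch = nullptr;  // stream-ordered scratch allocation, sized in bytes here
  api->KernelContext_GetScratchBuffer(context, mem_info, /*count_or_bytes*/ 1024, &scratch);

  // ... enqueue work that uses `scratch` on the kernel's compute stream ...

  api->ReleaseMemoryInfo(mem_info);
}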
 #if ENABLE_CUSTOM_OP_API
 #include "core/framework/customregistry.h"
 namespace onnxruntime {
@@ -1040,59 +1066,120 @@ Status IsCompatible(const ONNX_NAMESPACE::OpSchema& schema, const OrtCustomOp* o
   return Status::OK();
 }
 
-void InferOutputTypes(const InlinedVector<const KernelDef*>& kernel_defs,
-                      ONNX_NAMESPACE::InferenceContext& infer_ctx) {
-  for (const auto& kernel_def : kernel_defs) {
+// This function attempts to do its best for older custom ops (most of them) that do not have
+// their own type and shape inference function. However, it falls short in some cases, and we leave
+// those for the user to handle in their own inference function.
+static void InferOutputTypes(const ONNX_NAMESPACE::OpSchema& schema, gsl::span<const KernelDef* const> kernel_defs,
+                             ONNX_NAMESPACE::InferenceContext& infer_ctx) {
+  const auto& inputs = schema.inputs();
+  const auto node_input_num = infer_ctx.getNumInputs();
+
+  const KernelDef* def_selected = nullptr;
+  bool is_variadic_input = false;
+  bool is_homogeneous_input = false;
+  int32_t output_propagate{0};
+
+  for (size_t kernel_index = 0;
+       kernel_index < kernel_defs.size() && def_selected == nullptr;
+       ++kernel_index) {
+    const auto* kernel_def = kernel_defs[kernel_index];
     const auto& type_constraints = kernel_def->TypeConstraints();
-    auto num_inputs = infer_ctx.getNumInputs();
-    bool matched = true;
-    ONNXTensorElementDataType undef = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
-    // first, make sure there is a constraint for every input
-    for (size_t i = 0; i < num_inputs && matched; ++i) {
-      auto input_name = "Input" + std::to_string(i);
-      auto input_type = infer_ctx.getInputType(i);
-      if (input_type) {
-        auto elem_type = static_cast<ONNXTensorElementDataType>(input_type->tensor_type().elem_type());
-        auto tc_iter = type_constraints.find(input_name);
-        if (tc_iter != type_constraints.end()) {
-          if (tc_iter->second.size() > 1) {
-            undef = elem_type;
-          } else if (tc_iter->second.size() != 1 ||
-                     tc_iter->second[0] != DataTypeImpl::TensorTypeFromONNXEnum(elem_type)) {
-            matched = false;
+    def_selected = kernel_def;
+
+    for (size_t i = 0; i < node_input_num; ++i) {
+      const auto input_type = infer_ctx.getInputType(i);
+
+      // Guard against variadic parameter index
+      const size_t schema_input_index = (i < inputs.size()) ? i : inputs.size() - 1;
+      const auto& param = inputs[schema_input_index];
+      const auto& input_name = param.GetName();
+      if (input_type == nullptr) {
+        if (param.GetOption() == ONNX_NAMESPACE::OpSchema::FormalParameterOption::Optional)
+          continue;
+
+        ORT_THROW("[CustomOP type inferencing error]: kernel Input: ", input_name,
+                  " is absent, but not optional. Op : ", schema.Name());
+      }
+
+      is_variadic_input = (param.GetOption() == ONNX_NAMESPACE::OpSchema::FormalParameterOption::Variadic);
+      is_homogeneous_input = param.GetIsHomogeneous();
+
+      if (!is_variadic_input || is_homogeneous_input) {
+        auto hit = type_constraints.find(input_name);
+        if (hit != type_constraints.end()) {
+          const auto& types = hit->second;
+          // For custom ops kernel constraints are never empty
+          assert(!types.empty());
+          if (!std::any_of(types.cbegin(), types.cend(),
+                           [input_type](const DataTypeImpl* type) {
+                             return type->IsCompatible(*input_type);
+                           })) {
+            def_selected = nullptr;
+            output_propagate = 0;
+            break;
+          }
+
+          // If multiple types are possible from the constraints, record the last input type
+          // and use it to guess the output type when the output may also have multiple types.
+          // This works well for symmetric single input/output ops; otherwise we give up and
+          // let the user supply their own inference function.
+          if (types.size() > 1) {
+            output_propagate = input_type->tensor_type().elem_type();
           }
         } else {
-          matched = false;
+          ORT_THROW("[CustomOP type inferencing error]: no type constraint found for input: ",
+                    input_name, " Op: ", schema.Name());
         }
-      } else {
-        matched = false;
-      }
-    }  // for
-    // next, ensure that there is a constraint for every output
-    auto num_outputs = infer_ctx.getNumOutputs();
-    for (size_t i = 0; i < num_outputs && matched; i++) {
-      auto output_name = "Output" + std::to_string(i);
-      auto tc_iter = type_constraints.find(output_name);
-      if (tc_iter == type_constraints.end() || tc_iter->second.size() < 1) {
-        matched = false;
       }
     }
-    if (matched) {
-      for (size_t i = 0; i < num_outputs; i++) {
-        auto output_name = "Output" + std::to_string(i);
-        auto output_type = infer_ctx.getOutputType(i);
-        auto tc_iter = type_constraints.find(output_name);
-        if (tc_iter->second.size() > 1) {
-          output_type->mutable_tensor_type()->set_elem_type(undef);
-        } else {
-          output_type->mutable_tensor_type()->set_elem_type(
-              tc_iter->second[0]->GetTypeProto()->tensor_type().elem_type());
-        }
-      }
+  }
+
+  if (def_selected == nullptr) {
+    ORT_THROW("[CustomOP type inferencing error]: no kernel def matches node inputs for Op: ", schema.Name());
+  }
+
+  const auto& outputs = schema.outputs();
+  const auto node_output_num = infer_ctx.getNumOutputs();
+  const auto& selected_type_constraints = def_selected->TypeConstraints();
+
+  for (size_t i = 0; i < node_output_num; ++i) {
+    auto output_type = infer_ctx.getOutputType(i);
+    // Account for variadic outputs
+    const size_t schema_output_index = (i < outputs.size()) ? i : outputs.size() - 1;
+    const auto& param = outputs[schema_output_index];
+    const auto& output_name = param.GetName();
+
+    const bool is_variadic_output = (param.GetOption() == ONNX_NAMESPACE::OpSchema::FormalParameterOption::Variadic);
+    const bool is_homogeneous = param.GetIsHomogeneous();
+
+    // We give up on variadic non-homogeneous outputs
+    // Let the user handle it in their inference function
+    if (is_variadic_output && !is_homogeneous) {
       break;
     }
+
+    auto hit = selected_type_constraints.find(output_name);
+    if (hit != selected_type_constraints.end()) {
+      const auto& types = hit->second;
+      assert(!types.empty());
+
+      if (types.size() == 1) {
+        // Use the constraint type
+        output_type->mutable_tensor_type()->set_elem_type(
+            types[0]->GetTypeProto()->tensor_type().elem_type());
+      } else if (!is_variadic_input || is_homogeneous_input) {
+        // If the input is non-variadic or homogeneous and multiple types are possible,
+        // guess the output type from the last recorded input type; this works for
+        // symmetric single input/output ops. Otherwise give up and let the user
+        // supply their own inference function.
+        output_type->mutable_tensor_type()->set_elem_type(output_propagate);
+      }
+    } else {
+      ORT_THROW("[CustomOP type inferencing error]: no type constraint found for output: ",
+                output_name, " Op: ", schema.Name());
+    }
   }
 }
+
 #endif
 
 common::Status CreateCustomRegistry(gsl::span<OrtCustomOpDomain* const> op_domains,
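The selection rule implemented above can be summarized with a small standalone sketch in plain std types; it illustrates the logic, it is not ORT code. A kernel def is accepted when every present input's element type appears in that parameter's constraint list, and when a constraint lists several types the last matching input's type is remembered for guessing multi-type outputs.

#include <algorithm>
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>

using TypeConstraints = std::map<std::string, std::vector<int>>;  // param name -> allowed elem types

// Returns the index of the first matching kernel def, remembering a type to propagate to outputs.
std::optional<size_t> SelectKernelDef(const std::vector<TypeConstraints>& kernel_defs,
                                      const std::vector<std::pair<std::string, int>>& inputs,
                                      int& propagated_type) {
  for (size_t k = 0; k < kernel_defs.size(); ++k) {
    bool matched = true;
    int last_multi_type = 0;
    for (const auto& [name, elem_type] : inputs) {
      auto it = kernel_defs[k].find(name);
      if (it == kernel_defs[k].end() ||
          std::find(it->second.begin(), it->second.end(), elem_type) == it->second.end()) {
        matched = false;
        break;
      }
      if (it->second.size() > 1) last_multi_type = elem_type;  // candidate for output guessing
    }
    if (matched) {
      propagated_type = last_multi_type;
      return k;
    }
  }
  return std::nullopt;  // no kernel def matches; the real code throws in this case
}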
@@ -1152,13 +1239,13 @@ common::Status CreateCustomRegistry(gsl::span<OrtCustomOpDomain* const> op_domai
     }
 
     std::vector<ONNX_NAMESPACE::OpSchema> schemas;
-    for (auto schema_iter : schema_map) {
-      schemas.push_back(schema_iter.second);
-      InlinedVector<const KernelDef*> kernel_defs = std::move(kernel_def_map[schema_iter.first]);
+    for (auto& [name, schema] : schema_map) {
+      schemas.push_back(schema);
       auto infer_fn = schemas.back().GetTypeAndShapeInferenceFunction();
       ONNX_NAMESPACE::InferenceFunction extended_infer_fn =
-          [infer_fn, kernel_defs](ONNX_NAMESPACE::InferenceContext& infer_ctx) {
-            InferOutputTypes(kernel_defs, infer_ctx);
+          [sch = schema, infer_fn = std::move(infer_fn),
+           kernel_defs = std::move(kernel_def_map[name])](ONNX_NAMESPACE::InferenceContext& infer_ctx) {
+            InferOutputTypes(sch, kernel_defs, infer_ctx);
             if (infer_fn) {
               infer_fn(infer_ctx);
             }
diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc
index 80a0cb673c199..318c76645bdf5 100644
--- a/onnxruntime/core/session/environment.cc
+++ b/onnxruntime/core/session/environment.cc
@@ -240,12 +240,10 @@ Status Environment::Initialize(std::unique_ptr<logging::LoggingManager> logging_
 // Register contributed schemas.
 // The corresponding kernels are registered inside the appropriate execution provider.
 #ifndef DISABLE_CONTRIB_OPS
-#ifndef ORT_MINIMAL_BUILD
       RegisterOpSetSchema<contrib::OpSet_Microsoft_ver1>();
       RegisterOpSetSchema<contrib::OpSet_ONNX_Deprecated>();
       // internal opset that has NHWC versions of ONNX operators
       RegisterOpSetSchema<internal_nhwc_onnx::OpSet_Internal_NHWC_ONNX>();
-#endif
       contrib::RegisterContribSchemas();
 #endif
 
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index b045f30a59797..ece224ef206fc 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -60,6 +60,7 @@
 #include "core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h"
 #include "core/providers/dml/dml_session_options_config_keys.h"
 #include "core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h"
+#include "core/optimizer/stft_decomposition.h"
 #endif
 #include "core/session/environment.h"
 #include "core/session/user_logging_sink.h"
@@ -1725,10 +1726,17 @@ common::Status InferenceSession::Initialize() {
         // graph optimization level and is generally always applied.
         bool dml_graph_fusion_enabled = session_options_.optimized_model_filepath.empty() &&
                                         session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigDisableDmlGraphFusion, "0") == "0";
+        std::string dml_graph_serialization_enabled_config_val = session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigEnableGraphSerialization, "0");
+        std::transform(dml_graph_serialization_enabled_config_val.begin(),
+                       dml_graph_serialization_enabled_config_val.end(),
+                       dml_graph_serialization_enabled_config_val.begin(),
+                       [](char ch) { return std::tolower(ch); });
+        bool dml_graph_serialization_enabled = dml_graph_serialization_enabled_config_val == "true";
 
         if (dml_graph_fusion_enabled) {
           std::unique_ptr<onnxruntime::GraphTransformer> dmlGraphFusionTransformer = std::make_unique<Dml::DmlGraphFusionTransformer>("DmlGraphFusionTransformer",
-                                                                                                                                      dmlExecutionProvider);
+                                                                                                                                      dmlExecutionProvider,
+                                                                                                                                      dml_graph_serialization_enabled);
           if (dmlGraphFusionTransformer == nullptr) {
             return Status(common::ONNXRUNTIME, common::FAIL, "DmlGraphFusionTransformer is nullptr");
           }
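The serialization flag is read from the session-option entry parsed above and compared case-insensitively against "true". A hedged sketch of opting in from the C++ API; the header providing the key constant is assumed:

#include <onnxruntime_cxx_api.h>
#include <onnxruntime_session_options_config_keys.h>

// Sketch: opt in to DML graph serialization via the config entry parsed above.
Ort::SessionOptions MakeDmlSerializationOptions() {
  Ort::SessionOptions so;
  so.AddConfigEntry(kOrtSessionOptionsConfigEnableGraphSerialization, "true");
  return so;
}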
@@ -1754,6 +1762,14 @@ common::Status InferenceSession::Initialize() {
           }
           ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr_.Register(std::move(dmlOperatorFusionTransformer), onnxruntime::TransformerLevel::Level2));
         }
+
+        const auto dml_ep_impl = static_cast<const Dml::ExecutionProvider*>(dmlExecutionProvider);
+        auto is_mcdm_device = dml_ep_impl->GetImpl()->IsMcdmDevice();
+        if (is_mcdm_device) {
+          const InlinedHashSet<std::string_view> dml_ep = {onnxruntime::kDmlExecutionProvider};
+          auto stft_decomposition_transformer = std::make_unique<STFTDecomposition>(dml_ep);
+          ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr_.Register(std::move(stft_decomposition_transformer), onnxruntime::TransformerLevel::Level1));
+        }
       }
 #endif
 
@@ -2289,8 +2305,8 @@ Status InferenceSession::PartialRun(onnxruntime::RunOptions& run_options,
     // TODO: only call OnRunStart for all providers in-use
     for (auto& xp : execution_providers_) {
       // call OnRunStart and add to exec_providers_to_stop if successful
-      auto start_func = [&xp, &exec_providers_to_stop]() {
-        auto status = xp->OnRunStart();
+      auto start_func = [&xp, &exec_providers_to_stop, run_options]() {
+        auto status = xp->OnRunStart(run_options);
         if (status.IsOK())
           exec_providers_to_stop.push_back(xp.get());
 
@@ -2326,7 +2342,7 @@ Status InferenceSession::PartialRun(onnxruntime::RunOptions& run_options,
 
   // info all execution providers InferenceSession:Run ended
   for (auto* xp : exec_providers_to_stop) {
-    auto status = xp->OnRunEnd(/*sync_stream*/ false);
+    auto status = xp->OnRunEnd(/*sync_stream*/ false, run_options);
     ORT_CHECK_AND_SET_RETVAL(status);
   }
 
@@ -2376,21 +2392,32 @@ Status InferenceSession::Run(const RunOptions& run_options,
   Status retval = Status::OK();
   const Env& env = Env::Default();
 
+  int graph_annotation_id = 0;
+  const std::string& graph_annotation_str =
+      run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigCudaGraphAnnotation, "");
+  if (!graph_annotation_str.empty()) {
+    if (!TryParseStringWithClassicLocale<int>(graph_annotation_str, graph_annotation_id)) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to parse the cuda graph annotation id: ",
+                             graph_annotation_str);
+    }
+  }
+
   // Increment/decrement concurrent_num_runs_ and control
   // session threads spinning as configured. Do nothing for graph replay except the counter.
   const bool control_spinning = use_per_session_threads_ &&
                                 force_spinning_stop_between_runs_ &&
-                                !cached_execution_provider_for_graph_replay_.IsGraphCaptured();
+                                !cached_execution_provider_for_graph_replay_.IsGraphCaptured(graph_annotation_id);
   auto* intra_tp = (control_spinning) ? thread_pool_.get() : nullptr;
   auto* inter_tp = (control_spinning) ? inter_op_thread_pool_.get() : nullptr;
   ThreadPoolSpinningSwitch runs_refcounter_and_tp_spin_control(intra_tp, inter_tp, current_num_runs_);
 
   // Check if this Run() is simply going to be a CUDA Graph replay.
-  if (cached_execution_provider_for_graph_replay_.IsGraphCaptured()) {
+  if (cached_execution_provider_for_graph_replay_.IsGraphCaptured(graph_annotation_id)) {
     LOGS(*session_logger_, INFO) << "Replaying the captured "
                                  << cached_execution_provider_for_graph_replay_.Type()
-                                 << " CUDA Graph for this model with tag: " << run_options.run_tag;
-    ORT_RETURN_IF_ERROR_SESSIONID_(cached_execution_provider_for_graph_replay_.ReplayGraph());
+                                 << " CUDA Graph for this model with tag: " << run_options.run_tag
+                                 << " with graph annotation id: " << graph_annotation_id;
+    ORT_RETURN_IF_ERROR_SESSIONID_(cached_execution_provider_for_graph_replay_.ReplayGraph(graph_annotation_id));
   } else {
     InlinedVector<IExecutionProvider*> exec_providers_to_stop;
     exec_providers_to_stop.reserve(execution_providers_.NumProviders());
@@ -2448,8 +2475,8 @@ Status InferenceSession::Run(const RunOptions& run_options,
       // TODO: only call OnRunStart for all providers in-use
       for (auto& xp : execution_providers_) {
         // call OnRunStart and add to exec_providers_to_stop if successful
-        auto start_func = [&xp, &exec_providers_to_stop]() {
-          auto status = xp->OnRunStart();
+        auto start_func = [&xp, &exec_providers_to_stop, &run_options]() {
+          auto status = xp->OnRunStart(run_options);
           if (status.IsOK())
             exec_providers_to_stop.push_back(xp.get());
 
@@ -2490,7 +2517,7 @@ Status InferenceSession::Run(const RunOptions& run_options,
       // info all execution providers InferenceSession:Run ended
       for (auto* xp : exec_providers_to_stop) {
         bool synchronize_execution_providers = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0";
-        auto status = xp->OnRunEnd(synchronize_execution_providers);
+        auto status = xp->OnRunEnd(synchronize_execution_providers, run_options);
         ORT_CHECK_AND_SET_RETVAL(status);
       }
 
@@ -2552,7 +2579,8 @@ Status InferenceSession::Run(const RunOptions& run_options,
   // N is defined in min_num_runs_before_hip_graph_capture_ for ROCM EP,
   // and the value could be different for other EP.
   if (retval.IsOK() && cached_execution_provider_for_graph_replay_.IsGraphCaptureEnabled() &&
-      !cached_execution_provider_for_graph_replay_.IsGraphCaptured()) {
+      cached_execution_provider_for_graph_replay_.AllowGraphCaptureOnRun(graph_annotation_id) &&
+      !cached_execution_provider_for_graph_replay_.IsGraphCaptured(graph_annotation_id)) {
     LOGS(*session_logger_, INFO) << "Start another run for necessary memory allocation or graph capture.";
     ORT_RETURN_IF_ERROR(Run(run_options, feed_names, feeds, output_names, p_fetches, p_fetches_device_info));
   }
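
For orientation, a minimal Python sketch of how the DML-related session options touched in this hunk would be driven from user code. add_session_config_entry and the DirectML provider name are existing API, but the literal key strings below merely stand in for kOrtSessionOptionsConfigDisableDmlGraphFusion and kOrtSessionOptionsConfigEnableGraphSerialization and should be resolved from onnxruntime_session_options_config_keys.h; the model path is a placeholder. Note the serialization value is lower-cased before the comparison with "true", so "True"/"TRUE" also work.

    import onnxruntime as ort

    so = ort.SessionOptions()
    # Assumed key strings -- confirm against onnxruntime_session_options_config_keys.h.
    so.add_session_config_entry("ep.dml.disable_graph_fusion", "0")           # keep DML graph fusion enabled
    so.add_session_config_entry("ep.dml.enable_graph_serialization", "true")  # lower-cased, then compared to "true"
    sess = ort.InferenceSession("model.onnx", so, providers=["DmlExecutionProvider"])
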
diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h
index f8211bfd2dd4e..3038c8d22ec80 100644
--- a/onnxruntime/core/session/inference_session.h
+++ b/onnxruntime/core/session/inference_session.h
@@ -675,7 +675,6 @@ class InferenceSession {
    * If we encounter an invalid request, we return an error
    * back to the user.
    */
-
   [[nodiscard]] common::Status ValidateAndParseShrinkArenaString(const std::string& ort_device_list,
                                                                  /*out*/ InlinedVector<AllocatorPtr>& arenas_to_shrink) const;
 
@@ -867,14 +866,17 @@ class InferenceSession {
       return cached_execution_provider_for_graph_replay_ != nullptr && cached_execution_provider_for_graph_replay_->IsGraphCaptureEnabled();
     }
 
-    bool IsGraphCaptured() const {
-      return cached_execution_provider_for_graph_replay_ != nullptr && cached_execution_provider_for_graph_replay_->IsGraphCaptured();
+    bool IsGraphCaptured(int graph_annotation_id) const {
+      return cached_execution_provider_for_graph_replay_ != nullptr && cached_execution_provider_for_graph_replay_->IsGraphCaptured(graph_annotation_id);
+    }
+
+    bool AllowGraphCaptureOnRun(int graph_annotation_id) const {
+      return cached_execution_provider_for_graph_replay_ != nullptr && graph_annotation_id != kGraphAnnotationSkip;
     }
 
-    Status ReplayGraph() {
-      ORT_ENFORCE(IsGraphCaptured());
+    Status ReplayGraph(int graph_annotation_id) {
       if (cached_execution_provider_for_graph_replay_) {
-        return cached_execution_provider_for_graph_replay_->ReplayGraph();
+        return cached_execution_provider_for_graph_replay_->ReplayGraph(graph_annotation_id);
       }
       return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Cached EP instance for graph replay is not set yet before calling ReplayGraph()");
     }
@@ -884,6 +886,8 @@ class InferenceSession {
     }
 
     IExecutionProvider* cached_execution_provider_for_graph_replay_ = nullptr;
+    // TODO(wy): Same as kCudaGraphAnnotationSkip in cuda_graph.h. Move to a common place.
+    constexpr static int kGraphAnnotationSkip = -1;
   };
 
   CachedExecutionProviderForGraphReplay cached_execution_provider_for_graph_replay_;
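
To make the new per-annotation graph capture/replay path concrete, a hedged Python sketch follows. RunOptions.add_run_config_entry and the enable_cuda_graph provider option are existing API; the literal key "gpu_graph_id" is an assumption for kOrtRunOptionsConfigCudaGraphAnnotation, and the input name and shape are placeholders. Passing -1 corresponds to kGraphAnnotationSkip, i.e. no capture is attempted for that run.

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession(
        "model.onnx",
        providers=[("CUDAExecutionProvider", {"enable_cuda_graph": "1"})],
    )
    x = np.zeros((1, 3, 224, 224), dtype=np.float32)  # placeholder input

    ro = ort.RunOptions()
    ro.add_run_config_entry("gpu_graph_id", "1")   # assumed key for kOrtRunOptionsConfigCudaGraphAnnotation
    sess.run(None, {"input": x}, run_options=ro)   # first run under id 1 captures; later runs replay

    ro.add_run_config_entry("gpu_graph_id", "-1")  # -1 == kGraphAnnotationSkip: skip capture for this run
    sess.run(None, {"input": x}, run_options=ro)
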
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index dec8754ea244f..270b3490689c4 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -2725,6 +2725,8 @@ static constexpr OrtApi ort_api_1_to_18 = {
     &OrtApis::KernelContext_ParallelFor,
     &OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO_V2,
     &OrtApis::SessionOptionsAppendExecutionProvider_VitisAI,
+    &OrtApis::KernelContext_GetScratchBuffer,
+    &OrtApis::KernelInfoGetAllocator,
 };
 
 // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h
index 9ce94ba89a942..3591c96234aa3 100644
--- a/onnxruntime/core/session/ort_apis.h
+++ b/onnxruntime/core/session/ort_apis.h
@@ -513,4 +513,8 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_OpenVINO_V2,
 ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessionOptions* options,
                     _In_reads_(num_keys) const char* const* provider_options_keys,
                     _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);
+
+ORT_API_STATUS_IMPL(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out);
+
+ORT_API_STATUS_IMPL(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out);
 }  // namespace OrtApis
diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc
index 3bec9aa146f76..c7cf5963fa10f 100644
--- a/onnxruntime/core/session/provider_bridge_ort.cc
+++ b/onnxruntime/core/session/provider_bridge_ort.cc
@@ -14,6 +14,7 @@
 #include "core/framework/execution_provider.h"
 #include "core/framework/kernel_registry.h"
 #include "core/framework/provider_shutdown.h"
+#include "core/framework/run_options.h"
 #include "core/framework/tensorprotoutils.h"
 #include "core/framework/TensorSeq.h"
 #include "core/framework/provider_options.h"
@@ -527,6 +528,7 @@ struct ProviderHostImpl : ProviderHost {
   void TensorProto__add_dims(ONNX_NAMESPACE::TensorProto* p, int64_t value) override { p->add_dims(value); }
   bool TensorProto__has_data_location(const ONNX_NAMESPACE::TensorProto* p) override { return p->has_data_location(); }
   int TensorProto__data_location(const ONNX_NAMESPACE::TensorProto* p) override { return p->data_location(); }
+  void TensorProto__set_data_location(ONNX_NAMESPACE::TensorProto* p, ONNX_NAMESPACE::TensorProto_DataLocation data_location) override { return p->set_data_location(data_location); }
   bool TensorProto__has_raw_data(const ONNX_NAMESPACE::TensorProto* p) override { return p->has_raw_data(); }
   const std::string& TensorProto__raw_data(const ONNX_NAMESPACE::TensorProto* p) override { return p->raw_data(); }
   std::string* TensorProto__mutable_raw_data(ONNX_NAMESPACE::TensorProto* p) override { return p->mutable_raw_data(); }
@@ -676,6 +678,9 @@ struct ProviderHostImpl : ProviderHost {
     return p->GetConfigEntry(config_key);
   }
 
+  // OrtRunOptions (wrapped)
+  const ConfigOptions& RunOptions__GetConfigOptions(const RunOptions* p) override { return p->config_options; }
+
   // ComputeCapability (wrapped)
   std::unique_ptr<ComputeCapability> ComputeCapability__construct(std::unique_ptr<IndexedSubGraph> t_sub_graph) override { return std::make_unique<ComputeCapability>(std::move(t_sub_graph)); }
   void ComputeCapability__operator_delete(ComputeCapability* p) override { delete p; }
@@ -962,8 +967,9 @@ struct ProviderHostImpl : ProviderHost {
 
   // Model (wrapped)
   std::unique_ptr<Model> Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path,
+                                          const IOnnxRuntimeOpSchemaRegistryList* local_registries,
                                           const logging::Logger& logger) override {
-    return std::make_unique<Model>(model_proto, model_path, nullptr, logger);
+    return std::make_unique<Model>(model_proto, model_path, local_registries, logger);
   }
   void Model__operator_delete(Model* p) override { delete p; }
   Graph& Model__MainGraph(Model* p) override { return p->MainGraph(); }
@@ -1043,6 +1049,7 @@ struct ProviderHostImpl : ProviderHost {
   Node* Graph__GetNode(Graph* p, NodeIndex node_index) noexcept override { return p->GetNode(node_index); }
   const Node* Graph__GetNode(const Graph* p, NodeIndex node_index) const override { return p->GetNode(node_index); }
   const NodeArg* Graph__GetNodeArg(const Graph* p, const std::string& name) const override { return p->GetNodeArg(name); }
+  IOnnxRuntimeOpSchemaCollectionPtr Graph__GetSchemaRegistry(const Graph* p) const override { return p->GetSchemaRegistry(); }
 
   // GraphViewer (wrapped)
   void GraphViewer__operator_delete(GraphViewer* p) override { delete p; }
diff --git a/onnxruntime/core/util/matrix_layout.h b/onnxruntime/core/util/matrix_layout.h
index a0405e32034ae..783a29d8a2055 100644
--- a/onnxruntime/core/util/matrix_layout.h
+++ b/onnxruntime/core/util/matrix_layout.h
@@ -17,7 +17,6 @@
 #include <cstdint>
 #include "core/common/gsl.h"
 
-// TODO!! Already have this in cuda, what about cpu code though?
 #if defined(_MSC_VER)
 #define ORT_FORCEINLINE __forceinline
 #else
diff --git a/onnxruntime/core/util/thread_utils.cc b/onnxruntime/core/util/thread_utils.cc
index a5a165e150cf1..2a6c14ff1b058 100644
--- a/onnxruntime/core/util/thread_utils.cc
+++ b/onnxruntime/core/util/thread_utils.cc
@@ -93,22 +93,31 @@ static std::unique_ptr<ThreadPool>
 CreateThreadPoolHelper(Env* env, OrtThreadPoolParams options) {
   ThreadOptions to;
   if (options.thread_pool_size <= 0) {  // default
-    auto default_affinities = Env::Default().GetDefaultThreadAffinities();
-    if (default_affinities.size() <= 1) {
-      return nullptr;
-    }
-    options.thread_pool_size = static_cast<int>(default_affinities.size());
     if (options.auto_set_affinity) {
 #ifdef _WIN32
       // Only set thread affinity on Server with auto affinity.
       // On client best to let OS scheduler handle.
       // On big (P-Core) / little (E-Core) CPU designs affinity overrides QoS and has high power usage
       if (IsWindowsServer()) {
+        auto default_affinities = Env::Default().GetDefaultThreadAffinities();
+        if (default_affinities.size() <= 1) {
+          return nullptr;
+        }
+        options.thread_pool_size = static_cast<int>(default_affinities.size());
         to.affinities = std::move(default_affinities);
+      } else {
+        options.thread_pool_size = Env::Default().GetNumPhysicalCpuCores();
       }
 #else
+      auto default_affinities = Env::Default().GetDefaultThreadAffinities();
+      if (default_affinities.size() <= 1) {
+        return nullptr;
+      }
+      options.thread_pool_size = static_cast<int>(default_affinities.size());
       to.affinities = std::move(default_affinities);
 #endif
+    } else {
+      options.thread_pool_size = Env::Default().GetNumPhysicalCpuCores();
     }
   }
   if (options.thread_pool_size <= 1) {
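
The sizing change above only matters when the caller leaves the pool size at its default. A small sketch of the two paths from the Python API, with an arbitrary explicit count shown for contrast; the model path is a placeholder.

    import onnxruntime as ort

    so = ort.SessionOptions()
    so.intra_op_num_threads = 0   # default: ORT sizes the pool (physical cores when affinities are not auto-set)
    # so.intra_op_num_threads = 4 # explicit count: bypasses the default sizing logic entirely
    sess = ort.InferenceSession("model.onnx", so)
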
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index 9c36eb635ffcf..7b56f0c68427a 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -937,6 +937,20 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
             ORT_THROW("Invalid value passed for disable_dynamic_shapes: ", option.second);
           }
           OV_provider_options_map[option.first] = option.second;
+        } else if (option.first == "enable_dynamic_shapes") {
+          LOGS_DEFAULT(WARNING) << "Deprecation notice: 'enable_dynamic_shapes' is deprecated; use the 'disable_dynamic_shapes' parameter instead. "
+                                   "Please refer to https://onnxruntime.ai/docs/execution-providers/OpenVINO-ExecutionProvider.html#requirements to ensure all dependencies are met.";
+          std::string value;
+          if (!(option.second == "True" || option.second == "true" ||
+                option.second == "False" || option.second == "false")) {
+            ORT_THROW("Invalid value passed for enable_dynamic_shapes: ", option.second);
+          }
+          if (option.second == "True" || option.second == "true") {
+            value = "false";
+          } else {
+            value = "true";
+          }
+          OV_provider_options_map["disable_dynamic_shapes"] = value;
         } else if (option.first == "device_id") {
           OV_provider_options_map[option.first] = option.second;
           continue;
@@ -967,7 +981,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
       if (!Env::Default().GetEnvironmentVar("INTEL_OPENVINO_DIR").empty()) {
         ORT_THROW("INTEL_OPENVINO_DIR is set but OpenVINO library wasn't able to be loaded. Please install a supported version of OpenVINO as mentioned in the requirements page (https://onnxruntime.ai/docs/execution-providers/OpenVINO-ExecutionProvider.html#requirements), ensure dependency libraries are in the PATH and your hardware is supported.");
       } else {
-        LOGS_DEFAULT(WARNING) << "Failed to create " << type << ". Please reference https://onnxruntime.ai/docs/execution-providers/OpenVINO-ExecutionProvider.html#requirements to ensure all dependencies are met.";
+        LOGS_DEFAULT(WARNING) << "Failed to create " << type << ". Please refer to https://onnxruntime.ai/docs/execution-providers/OpenVINO-ExecutionProvider.html#requirements to ensure all dependencies are met.";
       }
     }
 #endif
@@ -1327,14 +1341,14 @@ void addGlobalMethods(py::module& m) {
 
 #ifdef ENABLE_ATEN
   m.def("register_aten_op_executor",
-        [](const std::string& is_cpu_argument_address_str, const std::string& aten_op_executor_address_str) -> void {
-          size_t is_cpu_argument_address_int, aten_op_executor_address_int;
+        [](const std::string& is_tensor_argument_address_str, const std::string& aten_op_executor_address_str) -> void {
+          size_t is_tensor_argument_address_int, aten_op_executor_address_int;
           ORT_THROW_IF_ERROR(
-              ParseStringWithClassicLocale(is_cpu_argument_address_str, is_cpu_argument_address_int));
+              ParseStringWithClassicLocale(is_tensor_argument_address_str, is_tensor_argument_address_int));
           ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(aten_op_executor_address_str, aten_op_executor_address_int));
-          void* p_is_cpu_argument = reinterpret_cast<void*>(is_cpu_argument_address_int);
+          void* p_is_tensor_argument = reinterpret_cast<void*>(is_tensor_argument_address_int);
           void* p_aten_op_executor = reinterpret_cast<void*>(aten_op_executor_address_int);
-          contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_cpu_argument, p_aten_op_executor);
+          contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_tensor_argument, p_aten_op_executor);
         });
 #endif
 }
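
For reference, both spellings of the OpenVINO dynamic-shapes option as they would be passed from Python; the deprecated form is logged and rewritten to disable_dynamic_shapes with the value inverted, so the two sessions below end up configured identically. The model path is a placeholder.

    import onnxruntime as ort

    # Preferred option
    sess_new = ort.InferenceSession(
        "model.onnx",
        providers=[("OpenVINOExecutionProvider", {"disable_dynamic_shapes": "false"})],
    )

    # Deprecated spelling: warns, then is mapped to disable_dynamic_shapes="false"
    sess_old = ort.InferenceSession(
        "model.onnx",
        providers=[("OpenVINOExecutionProvider", {"enable_dynamic_shapes": "True"})],
    )
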
diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h
index 6827f2c9dfd91..22314610dbee9 100644
--- a/onnxruntime/python/onnxruntime_pybind_state_common.h
+++ b/onnxruntime/python/onnxruntime_pybind_state_common.h
@@ -60,11 +60,8 @@ struct OrtStatus {
 #elif OPENVINO_CONFIG_GPU_FP16
 #define BACKEND_OPENVINO "-OPENVINO_GPU_FP16"
 
-#elif OPENVINO_CONFIG_NPU_FP16
-#define BACKEND_OPENVINO "-OPENVINO_NPU_FP16"
-
-#elif OPENVINO_CONFIG_NPU_U8
-#define BACKEND_OPENVINO "-OPENVINO_NPU_U8"
+#elif OPENVINO_CONFIG_NPU
+#define BACKEND_OPENVINO "-OPENVINO_NPU"
 
 #elif OPENVINO_CONFIG_MULTI
 #define BACKEND_OPENVINO "-OPENVINO_MULTI"
diff --git a/onnxruntime/python/onnxruntime_validation.py b/onnxruntime/python/onnxruntime_validation.py
index 16cbc8e8099e1..10d9f469863c4 100644
--- a/onnxruntime/python/onnxruntime_validation.py
+++ b/onnxruntime/python/onnxruntime_validation.py
@@ -22,7 +22,7 @@ def check_distro_info():
         __my_distro__ = __my_system__
         __my_distro_ver__ = platform.release().lower()
 
-        if __my_distro_ver__ != "10":
+        if __my_distro_ver__ not in ["10", "11"]:
             warnings.warn(
                 "Unsupported Windows version (%s). ONNX Runtime supports Windows 10 and above, only."
                 % __my_distro_ver__
diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py
index e32cb032798fc..400a9d8a7a187 100644
--- a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py
+++ b/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py
@@ -35,7 +35,11 @@ def sigmoid_function(x):
     return 1.0 / (1.0 + np.exp(-x))
 
 
-def group_norm(input_x, gamma, beta, num_groups, epsilon, with_swish):
+def group_norm(input_x, skip_x, bias_x, gamma, beta, num_groups, epsilon, with_silu, has_skip):
+    add_output = None
+    if has_skip:
+        input_x = input_x + skip_x + bias_x
+        add_output = input_x
     n, h, w, c = input_x.shape
     input_x = input_x.transpose([0, 3, 1, 2])
     assert c % num_groups == 0
@@ -45,46 +49,82 @@ def group_norm(input_x, gamma, beta, num_groups, epsilon, with_swish):
     x = x.transpose([0, 2, 3, 1])
     x = x * gamma + beta
 
-    if with_swish:
+    if with_silu:
         x = x * sigmoid_function(x)
-    return x
+    return x, add_output
 
 
-def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups: int, dtype: str, swish: bool, func):
+def run_group_norm(
+    batch_size: int, height: int, num_channels: int, num_groups: int, dtype: str, silu: bool, has_skip: bool, func
+):
     np.random.seed(0)
     width = height
     input_x = np.random.rand(batch_size, height, width, num_channels).astype(np.float32)
     gamma = np.random.rand(num_channels).astype(np.float32)
     beta = np.random.rand(num_channels).astype(np.float32)
     # the size of workspace is defined in onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h L18
-    workspace = np.random.rand((np.dtype(np.float32).itemsize * 2) * 32 * 32).astype(np.float32)
+    workspace = np.random.rand((np.dtype(np.float32).itemsize * 2) * batch_size * num_groups).astype(np.float32)
     epsilon = 1e-05
     output_y = np.random.rand(batch_size, height, width, num_channels).astype(dtype)
-    use_swish = swish
 
-    host_x = input_x.astype(dtype)
-    input_d = ke.DeviceArray(host_x)
+    skip_x = (
+        np.random.rand(batch_size, height, width, num_channels).astype(np.float32)
+        if has_skip
+        else np.empty((0), dtype=dtype)
+    )
+    bias_x = np.random.rand(num_channels).astype(np.float32) if has_skip else np.empty((0), dtype=dtype)
+    add_output = (
+        np.random.rand(batch_size, height, width, num_channels).astype(dtype)
+        if has_skip
+        else np.empty((0), dtype=dtype)
+    )
+    use_silu = silu
+    broadcast_skip = False
+    if has_skip:
+        skip_x_shape = skip_x.shape
+        b2 = len(skip_x_shape) == 2 and skip_x_shape[0] == batch_size and skip_x_shape[1] == num_channels
+        b4 = (
+            len(skip_x_shape) == 4
+            and skip_x_shape[0] == batch_size
+            and skip_x_shape[1] == 1
+            and skip_x_shape[2] == 1
+            and skip_x_shape[3] == num_channels
+        )
+        if b2 or b4:
+            broadcast_skip = True
+    channels_per_block = 0  # Compute in params initialization
+
+    input_d = ke.DeviceArray(input_x.astype(dtype))
+    skip_d = ke.DeviceArray(skip_x.astype(dtype))
+    bias_d = ke.DeviceArray(bias_x.astype(dtype))
     gamma_d = ke.DeviceArray(gamma)
     beta_d = ke.DeviceArray(beta)
     workspace_d = ke.DeviceArray(workspace)
     y_d = ke.DeviceArray(output_y)
+    y_add_d = ke.DeviceArray(add_output)
     f = getattr(ke, func)
 
     my_op = f(
         y_d,
-        workspace_d,
+        y_add_d,
         input_d,
+        skip_d,
+        bias_d,
         gamma_d,
         beta_d,
+        workspace_d,
+        epsilon,
         batch_size,
+        num_channels,
         height,
         width,
-        num_channels,
         num_groups,
-        epsilon,
-        use_swish,
+        use_silu,
+        broadcast_skip,
+        channels_per_block,
     )
-    y_ref = group_norm(input_x, gamma, beta, num_groups, epsilon, use_swish).astype(dtype)
+    y_ref, y_add_d_ref = group_norm(input_x, skip_x, bias_x, gamma, beta, num_groups, epsilon, use_silu, has_skip)
+    y_ref = y_ref.astype(dtype)
 
     for impl in my_op.ListOps():
         if not my_op.SelectOp(impl):
@@ -95,6 +135,10 @@ def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups:
         y_d.UpdateHostNumpyArray()
 
         np.testing.assert_allclose(y_ref, output_y, atol=1e-02)
+        if has_skip:
+            y_add_d_ref = y_add_d_ref.astype(dtype)
+            y_add_d.UpdateHostNumpyArray()
+            np.testing.assert_allclose(y_add_d_ref, add_output, atol=1e-02)
 
 
 dtypes = ["float32", "float16"]
@@ -102,19 +146,21 @@ def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups:
 
 @pytest.mark.parametrize("sd_sizes", get_sd_sizes())
 @pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.parametrize("swish", [True])
-def test_group_norm(sd_sizes, dtype, swish):
+@pytest.mark.parametrize("silu", [True])
+@pytest.mark.parametrize("has_skip", [True, False])
+def test_group_norm(sd_sizes, dtype, silu, has_skip):
     for func in dtype_to_funcs(dtype):
-        run_group_norm(*sd_sizes, dtype, swish, func)
+        run_group_norm(*sd_sizes, dtype, silu, has_skip, func)
 
 
 @pytest.mark.parametrize("sd_sizes", get_sd_sizes())
 @pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.parametrize("swish", [True])
-def test_group_norm_ck(sd_sizes, dtype, swish):
-    swish_suffix = "Swish" if swish else "Pass"
-    ck_f_name = "CKGroupNormNHWC" + swish_suffix + "_" + dtype_to_suffix(dtype)
-    run_group_norm(*sd_sizes, dtype, swish, ck_f_name)
+@pytest.mark.parametrize("silu", [True])
+@pytest.mark.parametrize("has_skip", [False])
+def test_group_norm_ck(sd_sizes, dtype, silu, has_skip):
+    silu_suffix = "Silu" if silu else "Pass"
+    ck_f_name = "CKGroupNormNHWC" + silu_suffix + "_" + dtype_to_suffix(dtype)
+    run_group_norm(*sd_sizes, dtype, silu, has_skip, ck_f_name)
 
 
 @dataclass
@@ -136,37 +182,67 @@ def report(self):
 
 
 def profile_group_norm_func(
-    batch_size: int, height: int, width: int, num_channels: int, num_groups: int, dtype: str, swish: bool, func
+    batch_size: int,
+    height: int,
+    width: int,
+    num_channels: int,
+    num_groups: int,
+    dtype: str,
+    silu: bool,
+    has_skip: bool,
+    func,
 ):
     np.random.seed(0)
     input_x = np.random.rand(batch_size, height, width, num_channels).astype(dtype)
     gamma = np.random.rand(num_channels).astype(np.float32)
     beta = np.random.rand(num_channels).astype(np.float32)
-    workspace = np.random.rand(np.dtype(np.float32).itemsize * 2 * 32 * 32).astype(np.float32)
+    workspace = np.random.rand(np.dtype(np.float32).itemsize * 2 * batch_size * num_groups).astype(np.float32)
     epsilon = 0.05
     output_y = np.random.rand(batch_size, height, width, num_channels).astype(dtype)
-    use_swish = swish
+
+    skip_x = (
+        np.random.rand(batch_size, height, width, num_channels).astype(dtype)
+        if has_skip
+        else np.empty((0), dtype=dtype)
+    )
+    bias_x = np.random.rand(num_channels).astype(dtype) if has_skip else np.empty((0), dtype=dtype)
+    add_output = (
+        np.random.rand(batch_size, height, width, num_channels).astype(dtype)
+        if has_skip
+        else np.empty((0), dtype=dtype)
+    )
+    use_silu = silu
+    broadcast_skip = False
+    channels_per_block = 0  # Compute in params initialization
 
     input_d = ke.DeviceArray(input_x)
+    skip_d = ke.DeviceArray(skip_x)
+    bias_d = ke.DeviceArray(bias_x)
     gamma_d = ke.DeviceArray(gamma)
     beta_d = ke.DeviceArray(beta)
     workspace_d = ke.DeviceArray(workspace)
     y_d = ke.DeviceArray(output_y)
+    y_add_d = ke.DeviceArray(add_output)
     f = getattr(ke, func)
 
     my_op = f(
         y_d,
-        workspace_d,
+        y_add_d,
         input_d,
+        skip_d,
+        bias_d,
         gamma_d,
         beta_d,
+        workspace_d,
+        epsilon,
         batch_size,
+        num_channels,
         height,
         width,
-        num_channels,
         num_groups,
-        epsilon,
-        use_swish,
+        use_silu,
+        broadcast_skip,
+        channels_per_block,
     )
     for impl in my_op.ListOps():
         duration_ms = -1
@@ -181,14 +257,14 @@ def profile_group_norm_func(
         )
 
 
-def profile_with_args(batch_size, height, width, num_channels, num_groups, dtype, swish=True, sort=True):
+def profile_with_args(batch_size, height, width, num_channels, num_groups, dtype, silu=True, has_skip=True, sort=True):
     with ke.benchmark(sort):
         for func in dtype_to_funcs(dtype):
-            profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, swish, func)
+            profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, silu, has_skip, func)
         # ck function
-        swish_suffix = "Swish" if swish else "Pass"
-        ck_f_name = "CKGroupNormNHWC" + swish_suffix + "_" + dtype_to_suffix(dtype)
-        profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, swish, ck_f_name)
+        silu_suffix = "Silu" if silu else "Pass"
+        ck_f_name = "CKGroupNormNHWC" + silu_suffix + "_" + dtype_to_suffix(dtype)
+        profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, silu, has_skip, ck_f_name)
 
 
 sd_profile_sizes = [
@@ -227,7 +303,8 @@ def profile():
     group.add_argument("num_channels", type=int)
     group.add_argument("num_groups", type=int)
     group.add_argument("dtype", choices=dtypes)
-    group.add_argument("--swish", action="store_true")
+    group.add_argument("--silu", action="store_true")
+    group.add_argument("--has_skip", action="store_true")
     group.add_argument("--sort", action="store_true")
 
     if len(sys.argv) == 1:
@@ -241,6 +318,7 @@ def profile():
             args.num_channels,
             args.num_groups,
             args.dtype,
-            args.swish,
+            args.silu,
+            args.has_skip,
             args.sort,
         )
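
A short usage sketch for the updated benchmark entry point, assuming a ROCm build with the kernel explorer module available and this test file on the import path; the sizes are arbitrary Stable-Diffusion-like shapes.

    from groupnorm_test import profile_with_args  # assumes the kernel explorer build is importable

    profile_with_args(
        batch_size=2, height=64, width=64, num_channels=320, num_groups=32,
        dtype="float16", silu=True, has_skip=True, sort=True,
    )
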
diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu
index 0bd47b2c0387e..6af163ab94b10 100644
--- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu
+++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu
@@ -12,17 +12,21 @@
 #include "python/tools/kernel_explorer/kernel_explorer_interface.h"
 
 namespace py = pybind11;
-
+using onnxruntime::contrib::rocm::GetGroupNormWorkspaceSizeInBytes;
 namespace onnxruntime {
 
 template <typename T, int ThreadsPerBlock, int VecSize>
 class GroupNormNHWC : public IKernelExplorer {
  public:
-  GroupNormNHWC(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta,
-                int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish)
-      : params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<float*>(workspace.ptr()),
-                static_cast<T*>(input.ptr()), static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()),
-                batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {
+  GroupNormNHWC(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, DeviceArray& bias,
+                DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, float epsilon,
+                int batch_size, int num_channels, int height, int width, int num_groups, bool use_silu,
+                bool broadcast_skip, int channels_per_block)
+      : params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(add_output.ptr()),
+                static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(bias.ptr()),
+                static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()), static_cast<float*>(workspace.ptr()),
+                epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip,
+                channels_per_block) {
     type_string_ = "GroupNormNHWC_" + std::to_string(ThreadsPerBlock) + "_" + std::to_string(VecSize);
   }
 
@@ -40,7 +44,7 @@ class GroupNormNHWC : public IKernelExplorer {
   }
 
  private:
-  using ParamsT = contrib::rocm::GroupNormNHWCParams<T>;
+  using ParamsT = contrib::rocm::GroupNormNHWCTunableParams<T>;
   ParamsT params_{};
   contrib::rocm::GroupNormNHWCOp<T, ThreadsPerBlock, VecSize> op_{};
   std::string type_string_{};
@@ -49,11 +53,15 @@ class GroupNormNHWC : public IKernelExplorer {
 template <typename T>
 class GroupNormNHWCStaticSelection : public IKernelExplorer {
  public:
-  GroupNormNHWCStaticSelection(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta,
-                               int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish)
-      : params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<float*>(workspace.ptr()),
-                static_cast<T*>(input.ptr()), static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()),
-                batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {
+  GroupNormNHWCStaticSelection(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip,
+                               DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace,
+                               float epsilon, int batch_size, int num_channels, int height, int width, int num_groups,
+                               bool use_silu, bool broadcast_skip, int channels_per_block)
+      : params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(add_output.ptr()),
+                static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(bias.ptr()),
+                static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()), static_cast<float*>(workspace.ptr()),
+                epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip,
+                channels_per_block) {
     type_string_ = "GroupNormNHWCStaticSelection";
   }
 
@@ -71,7 +79,7 @@ class GroupNormNHWCStaticSelection : public IKernelExplorer {
   }
 
  private:
-  using ParamsT = contrib::rocm::GroupNormNHWCParams<T>;
+  using ParamsT = contrib::rocm::GroupNormNHWCTunableParams<T>;
   ParamsT params_{};
   std::string type_string_{};
 };
@@ -79,11 +87,15 @@ class GroupNormNHWCStaticSelection : public IKernelExplorer {
 template <typename T>
 class GroupNormNHWCTunable : public IKernelExplorer {
  public:
-  GroupNormNHWCTunable(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta,
-                       int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish)
-      : params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<float*>(workspace.ptr()),
-                static_cast<T*>(input.ptr()), static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()),
-                batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {
+  GroupNormNHWCTunable(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip,
+                       DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace,
+                       float epsilon, int batch_size, int num_channels, int height, int width, int num_groups,
+                       bool use_silu, bool broadcast_skip, int channels_per_block)
+      : params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(add_output.ptr()),
+                static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(bias.ptr()),
+                static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()), static_cast<float*>(workspace.ptr()),
+                epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip,
+                channels_per_block) {
     params_.TuningContext()->EnableTunableOpAndTuning();
   }
 
@@ -100,21 +112,25 @@ class GroupNormNHWCTunable : public IKernelExplorer {
   }
 
  private:
-  using ParamsT = contrib::rocm::GroupNormNHWCParams<T>;
+  using ParamsT = contrib::rocm::GroupNormNHWCTunableParams<T>;
   ParamsT params_{};
   contrib::rocm::GroupNormNHWCTunableOp<T> op_{};
 };
 
 #ifdef USE_COMPOSABLE_KERNEL
-template <typename T, bool WithSwish>
+template <typename T, bool WithSilu>
 class CKGroupNormNHWC : public IKernelExplorer {
  public:
-  CKGroupNormNHWC(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta,
-                  int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish)
-      : params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<float*>(workspace.ptr()),
-                static_cast<T*>(input.ptr()), static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()),
-                batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {
-    for (auto&& [type_string, op] : contrib::rocm::GetCKGroupNormNHWCTypeStringAndOps<T, float, WithSwish>()) {
+  CKGroupNormNHWC(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip,
+                  DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace,
+                  float epsilon, int batch_size, int num_channels, int height, int width, int num_groups,
+                  bool use_silu, bool broadcast_skip, int channels_per_block)
+      : params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(add_output.ptr()),
+                static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(bias.ptr()),
+                static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()), static_cast<float*>(workspace.ptr()),
+                epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip,
+                channels_per_block) {
+    for (auto&& [type_string, op] : contrib::rocm::GetCKGroupNormNHWCTypeStringAndOps<T, float, WithSilu>()) {
       type_strings_.emplace_back(std::move(type_string));
       ops_.emplace_back(std::move(op));
     }
@@ -141,7 +157,7 @@ class CKGroupNormNHWC : public IKernelExplorer {
   }
 
  private:
-  using ParamsT = contrib::rocm::GroupNormNHWCParams<T>;
+  using ParamsT = contrib::rocm::GroupNormNHWCTunableParams<T>;
   using OpT = rocm::tunable::Op<ParamsT>;
   ParamsT params_{};
   std::vector<OpT> ops_;
@@ -151,15 +167,19 @@ class CKGroupNormNHWC : public IKernelExplorer {
 #endif  // USE_COMPOSABLE_KERNEL
 
 #ifdef USE_TRITON_KERNEL
-template <typename T, bool WithSwish>
+template <typename T, bool WithSilu>
 class GroupNormNHWCTriton : public IKernelExplorer {
  public:
-  GroupNormNHWCTriton(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta,
-                      int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish)
-      : params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<float*>(workspace.ptr()),
-                static_cast<T*>(input.ptr()), static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()),
-                batch_size, height, width, num_channels, num_groups, epsilon, use_swish) {
-    for (auto&& [name, op] : contrib::rocm::GetTritonGroupNormNHWCTypeStringAndOps<T, WithSwish>()) {
+  GroupNormNHWCTriton(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip,
+                      DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace,
+                      float epsilon, int batch_size, int num_channels, int height, int width, int num_groups,
+                      bool use_silu, bool broadcast_skip, int channels_per_block)
+      : params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(add_output.ptr()),
+                static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(bias.ptr()),
+                static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()), static_cast<float*>(workspace.ptr()),
+                epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip,
+                channels_per_block) {
+    for (auto&& [name, op] : contrib::rocm::GetTritonGroupNormNHWCTypeStringAndOps<T, WithSilu>()) {
       name_strings_.emplace_back(name);
       ops_.emplace_back(std::move(op));
     }
@@ -186,7 +206,7 @@ class GroupNormNHWCTriton : public IKernelExplorer {
   }
 
  private:
-  using ParamsT = contrib::rocm::GroupNormNHWCParams<T>;
+  using ParamsT = contrib::rocm::GroupNormNHWCTunableParams<T>;
   using OpT = rocm::tunable::Op<ParamsT>;
   ParamsT params_{};
   std::vector<OpT> ops_;
@@ -198,7 +218,8 @@ class GroupNormNHWCTriton : public IKernelExplorer {
 #define REGISTER_OP(name, type, threads_per_block, vec_size)                                                   \
   py::class_<name<type, threads_per_block, vec_size>>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \
       .def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&,                      \
-                    int, int, int, int, int, float, bool>())                                                   \
+                    DeviceArray&, DeviceArray&, DeviceArray&, float,                                           \
+                    int, int, int, int, int, bool, bool, int>())                                               \
       .def("SetRepeats", &name<type, threads_per_block, vec_size>::SetRepeats)                                 \
       .def("Profile", &name<type, threads_per_block, vec_size>::Profile)                                       \
       .def("Run", &name<type, threads_per_block, vec_size>::Run)                                               \
@@ -220,7 +241,8 @@ class GroupNormNHWCTriton : public IKernelExplorer {
 #define REGISTER_COMMON(name, type, ...)                                                  \
   py::class_<type<__VA_ARGS__>>(m, name)                                                  \
       .def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
-                    int, int, int, int, int, float, bool>())                              \
+                    DeviceArray&, DeviceArray&, DeviceArray&, float,                      \
+                    int, int, int, int, int, bool, bool, int>())                          \
       .def("SetRepeats", &type<__VA_ARGS__>::SetRepeats)                                  \
       .def("Profile", &type<__VA_ARGS__>::Profile)                                        \
       .def("Run", &type<__VA_ARGS__>::Run)                                                \
@@ -230,11 +252,11 @@ class GroupNormNHWCTriton : public IKernelExplorer {
 #define REGISTER_OP_TYPED(name, type) \
   REGISTER_COMMON(#name "_" #type, name, type)
 
-#define REGISTER_CK(type, with_swish, swish_suffix) \
-  REGISTER_COMMON("CKGroupNormNHWC" swish_suffix "_" #type, CKGroupNormNHWC, type, with_swish)
+#define REGISTER_CK(type, with_silu, silu_suffix) \
+  REGISTER_COMMON("CKGroupNormNHWC" silu_suffix "_" #type, CKGroupNormNHWC, type, with_silu)
 
-#define REGISTER_TRITON(type, with_swish, swish_suffix) \
-  REGISTER_COMMON("GroupNormNHWCTriton" swish_suffix "_" #type, GroupNormNHWCTriton, type, with_swish)
+#define REGISTER_TRITON(type, with_silu, silu_suffix) \
+  REGISTER_COMMON("GroupNormNHWCTriton" silu_suffix "_" #type, GroupNormNHWCTriton, type, with_silu)
 
 KE_REGISTER(m) {
   REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWC, half);
@@ -248,16 +270,16 @@ KE_REGISTER(m) {
 
 #ifdef USE_COMPOSABLE_KERNEL
   REGISTER_CK(half, false, "Pass");
-  REGISTER_CK(half, true, "Swish");
+  REGISTER_CK(half, true, "Silu");
   REGISTER_CK(float, false, "Pass");
-  REGISTER_CK(float, true, "Swish");
+  REGISTER_CK(float, true, "Silu");
 #endif  // USE_COMPOSABLE_KERNEL
 
 #ifdef USE_TRITON_KERNEL
   REGISTER_TRITON(half, false, "Pass");
-  REGISTER_TRITON(half, true, "Swish");
+  REGISTER_TRITON(half, true, "Silu");
   REGISTER_TRITON(float, false, "Pass");
-  REGISTER_TRITON(float, true, "Swish");
+  REGISTER_TRITON(float, true, "Silu");
 #endif
 }
 
diff --git a/onnxruntime/python/tools/microbench/benchmark.py b/onnxruntime/python/tools/microbench/benchmark.py
index a52740d45956c..a5936afcfe13e 100644
--- a/onnxruntime/python/tools/microbench/benchmark.py
+++ b/onnxruntime/python/tools/microbench/benchmark.py
@@ -147,20 +147,17 @@ def __init__(self, args):
 
     @classmethod
     @abstractmethod
-    def create_inputs_outputs(cls, op_param):
-        ...
+    def create_inputs_outputs(cls, op_param): ...
 
     def add_case(self, op_param, model):
         self.cases += [(op_param, model)]
 
     @abstractmethod
-    def create_cases(self):
-        ...
+    def create_cases(self): ...
 
     @classmethod
     @abstractmethod
-    def case_profile(cls, op_param, time):
-        ...
+    def case_profile(cls, op_param, time): ...
 
     def benchmark(self):
         self.create_cases()
diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py
new file mode 100644
index 0000000000000..667d7047c1fbd
--- /dev/null
+++ b/onnxruntime/python/tools/quantization/base_quantizer.py
@@ -0,0 +1,723 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+import logging
+from typing import Any, Dict
+
+import numpy as np
+import onnx
+import onnx.numpy_helper
+
+try:
+    from onnx.reference.op_run import to_array_extended
+except ImportError:
+    # old version of onnx.
+    to_array_extended = None
+
+from .calibrate import TensorData
+from .onnx_model import ONNXModel
+from .quant_utils import (
+    ONNX_TYPE_TO_NP_TYPE,
+    TENSOR_NAME_QUANT_SUFFIX,
+    QuantizedValue,
+    QuantizedValueType,
+    QuantType,
+    compute_scale_zp,
+    compute_scale_zp_float8,
+    find_by_name,
+    get_qmin_qmax_for_qType,
+    model_has_infer_metadata,
+    quantize_data,
+    quantize_nparray,
+    save_and_reload_model_with_shape_infer,
+    tensor_proto_to_array,
+)
+
+
+class QuantizationParams:
+    def __init__(self, **data: Dict[str, Any]):
+        self.data = {}
+        for k, v in data.items():
+            if not isinstance(k, str):
+                raise TypeError(f"Keys must be strings, not {type(k)}, for k={k!r}.")
+            if not isinstance(v, (int, str, np.ndarray)):
+                raise TypeError(f"Values must be numpy arrays, int, float, or str, not {type(v)} for k={k!r}.")
+            if k == "scale" and v.dtype not in (np.float32, np.float16):
+                raise ValueError(f"scale must be a float32 or float16 numpy array, but is {v.dtype} for k={k!r}.")
+            self.data[k] = v
+
+    def __iter__(self):
+        yield from self.data
+
+    def __getitem__(self, key):
+        return self.data[key]
+
+    def __len__(self):
+        return len(self.data)
+
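+# Illustrative note (comment only): values passed to QuantizationParams must be int, str,
+# or numpy arrays, and a "scale" entry must be a float32/float16 array, e.g.
+#   QuantizationParams(zero_point=np.array([0], dtype=np.uint8),
+#                      scale=np.array([0.1], dtype=np.float32))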
+
+class BaseQuantizer:
+    def __init__(
+        self,
+        model,
+        per_channel,
+        reduce_range,
+        weight_qType,
+        activation_qType,
+        tensors_range,
+        nodes_to_quantize,
+        nodes_to_exclude,
+        op_types_to_quantize,
+        extra_options=None,
+    ):
+        if not model_has_infer_metadata(model):
+            model = save_and_reload_model_with_shape_infer(model)
+        self.value_infos = {vi.name: vi for vi in model.graph.value_info}
+        self.value_infos.update({ot.name: ot for ot in model.graph.output})
+        self.value_infos.update({it.name: it for it in model.graph.input})
+
+        self.model = ONNXModel(model)
+        self.per_channel = per_channel  # weight-pack per channel
+        self.reduce_range = reduce_range
+
+        self.extra_options = extra_options if extra_options else {}
+        self.enable_subgraph_quantization = (
+            "EnableSubgraph" in self.extra_options and self.extra_options["EnableSubgraph"]
+        )
+        self.parent = None
+        self.force_quantize_no_input_check = (
+            "ForceQuantizeNoInputCheck" in self.extra_options and self.extra_options["ForceQuantizeNoInputCheck"]
+        )
+        self.is_weight_symmetric = self.extra_options.get(
+            "WeightSymmetric", weight_qType in (QuantType.QInt8, QuantType.QInt16, QuantType.QFLOAT8E4M3FN)
+        )
+        self.is_activation_symmetric = self.extra_options.get("ActivationSymmetric", False)
+        self.min_real_range = self.extra_options.get("MinimumRealRange")
+
+        self.activation_qType = getattr(activation_qType, "tensor_type", activation_qType)
+        self.weight_qType = getattr(weight_qType, "tensor_type", weight_qType)
+
+        """
+            Dictionary specifying the min and max values for tensors. It has following format:
+                {
+                    "param_name": [min, max]
+                }
+            example:
+                {
+                    'Conv_3:0': [np.float32(0), np.float32(0.5)],
+                    'Conv_4:0': [np.float32(1), np.float32(3.5)]
+                }
+        """
+        if tensors_range is not None and any(map(lambda t: not isinstance(t, TensorData), tensors_range.values())):
+            raise TypeError(
+                f"tensors_range contains unexpected types {set(type(v) for v in tensors_range.values())}, not TensorData."
+            )
+        self.tensors_range = tensors_range
+        self.nodes_to_quantize = nodes_to_quantize  # specific nodes to quantize
+        self.nodes_to_exclude = nodes_to_exclude  # specific nodes to exclude
+        self.op_types_to_quantize = op_types_to_quantize
+
+        self.opset_version = self.check_opset_version()
+
+        # Map of all original value names to quantized value names
+        self.quantized_value_map = {}
+
+        self.tensor_quant_overrides, self.tensor_quant_override_types = self._get_and_check_tensor_quant_overrides()
+        self.quantization_params = self.calculate_quantization_params()
+
+        # to store specified scale and zeropoint instead of calculated value, tensor_name->(scale, zeropoint)
+        self.used_scale_zp_map = {}
+
+    def set_quant_scale_zp(self, tensor_name, value):
+        assert isinstance(value, tuple) and len(value) == 2, "value must be a (scale, zero_point) tuple with a float32 or float16 scale"
+        assert hasattr(value[0], "dtype")
+        assert tensor_name not in self.used_scale_zp_map, f"{tensor_name} has already been set"
+        self.used_scale_zp_map[tensor_name] = value
+
+    def find_quant_scale_zp(self, input_name):
+        if input_name in self.used_scale_zp_map:
+            return self.used_scale_zp_map[input_name]
+        if self.parent is not None:
+            return self.parent.find_quantized_value(input_name)
+        return (None, None)
+
+    def quantize_model(self):
+        raise NotImplementedError
+
+    def is_input_a_initializer(self, input_name):
+        initializer = find_by_name(input_name, self.model.initializer())
+        return initializer is not None
+
+    def is_per_channel(self):
+        return self.per_channel
+
+    def is_valid_quantize_weight(self, weight_name):
+        weight = find_by_name(weight_name, self.model.initializer())
+        if weight is not None:
+            return weight.data_type in (onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT16)
+        if (not self.enable_subgraph_quantization) or (self.parent is None):
+            return False
+        return self.parent.is_valid_quantize_weight(weight_name)
+
+    def should_quantize_node(self, node):
+        if (
+            self.nodes_to_quantize is not None
+            and len(self.nodes_to_quantize) != 0
+            and node.name not in self.nodes_to_quantize
+        ):
+            return False
+
+        if node.op_type not in self.op_types_to_quantize:
+            return False
+
+        if self.nodes_to_exclude is not None and node.name in self.nodes_to_exclude:
+            return False
+
+        return True
+
+    def check_opset_version(self):
+        ai_onnx_domain = [
+            opset for opset in self.model.model.opset_import if not opset.domain or opset.domain == "ai.onnx"
+        ]
+        if len(ai_onnx_domain) != 1:
+            raise ValueError("Failed to find proper ai.onnx domain")
+        opset_version = ai_onnx_domain[0].version
+
+        if opset_version == 10:
+            logging.warning(
+                f"The original model opset version is {opset_version}, which does not support node fusions. Please update the model to opset >= 11 for better performance."
+            )
+            return 10
+
+        if opset_version < 10:
+            logging.warning(
+                f"The original model opset version is {opset_version}, which does not support quantization. The model will be updated to opset 11 automatically; please verify the quantized model, or re-export the original model at opset >= 11."
+            )
+            self.model.model.opset_import.remove(ai_onnx_domain[0])
+            self.model.model.opset_import.extend([onnx.helper.make_opsetid("", 11)])
+            opset_version = 11
+
+        if opset_version < 19 and self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN:
+            logging.warning(
+                f"The original model opset version is {opset_version}, which does not support quantization to float 8. "
+                "Please update the model to opset >= 19. Updating the model automatically to opset 19. "
+                "Please verify the quantized model."
+            )
+            self.model.model.opset_import.remove(ai_onnx_domain[0])
+            self.model.model.opset_import.extend([onnx.helper.make_opsetid("", 19)])
+            self.model.model.ir_version = 9
+            opset_version = 19
+
+        return opset_version
+
+    def quantize_bias_static(self, bias_name, input_name, weight_name, beta=1.0):
+        """
+        Quantize the bias. Zero point == 0 and scale == input_scale * weight_scale * beta.
+        """
+
+        # Handle case where bias already in quantization map
+        if bias_name in self.quantized_value_map:
+            return self.quantized_value_map[bias_name].q_name
+
+        # get scale for weight
+        weight_scale_name = self.quantized_value_map[weight_name].scale_name
+        weight_initializer = find_by_name(weight_scale_name, self.model.initializer())
+        weight_scale = tensor_proto_to_array(weight_initializer)
+
+        # get bias
+        bias_initializer = find_by_name(bias_name, self.model.initializer())
+        bias_data = tensor_proto_to_array(bias_initializer)
+        quantized_bias_name = bias_name + TENSOR_NAME_QUANT_SUFFIX
+
+        # get scale for input
+        if input_name in self.quantized_value_map:
+            input_scale_name = self.quantized_value_map[input_name].scale_name
+        elif input_name in self.quantization_params:
+            _, input_scale_name, _, _, _ = self._get_quantization_params(input_name)
+        else:
+            raise ValueError(f"Expected {input_name} to be in quantized value map for static quantization")
+
+        inputscale_initializer = find_by_name(input_scale_name, self.model.initializer())
+        input_scale = tensor_proto_to_array(inputscale_initializer)
+
+        # quantize bias
+        if self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN:
+            data = np.asarray(bias_data)
+            if data.dtype == np.float16:
+                node_qtype = onnx.TensorProto.FLOAT16
+            elif data.dtype == np.float32:
+                node_qtype = onnx.TensorProto.FLOAT
+            else:
+                raise TypeError(f"Only float16 or float32 are supported with float 8, but the bias dtype is {data.dtype}.")
+            quantized_data = data.astype(np.float32)
+            bias_scale = np.array([1], dtype=quantized_data.dtype)
+            bias_scale_data = bias_scale.reshape(-1)
+            packed_bias_initializer = onnx.numpy_helper.from_array(quantized_data, quantized_bias_name)
+            self.model.initializer_extend([packed_bias_initializer])
+            node_type = "Cast"
+        else:
+            # calculate scale for bias
+            # TODO: This formula should be explained including why the scale is not estimated for the bias as well.
+            bias_scale = input_scale * weight_scale * beta
+
+            quantized_data = (np.asarray(bias_data) / bias_scale).round().astype(np.int32)
+
+            # update bias initializer
+            bias_np_data = np.asarray(quantized_data, dtype=np.int32).reshape(bias_initializer.dims)
+            packed_bias_initializer = onnx.numpy_helper.from_array(bias_np_data, quantized_bias_name)
+            self.model.initializer_extend([packed_bias_initializer])
+            bias_scale_data = np.asarray(bias_scale, dtype=np.float32).reshape(-1)
+            node_type = "DequantizeLinear"
+            node_qtype = self.weight_qType
+
+        # update scale initializer
+        quantized_bias_scale_name = quantized_bias_name + "_scale"
+        packed_bias_scale_initializer = onnx.numpy_helper.from_array(bias_scale_data, quantized_bias_scale_name)
+        self.model.initializer_extend([packed_bias_scale_initializer])
+
+        # update zero initializer
+        if self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN:
+            tensor_type = self.weight_qType
+        else:
+            tensor_type = onnx.TensorProto.INT32
+
+        quantized_bias_zp_name = quantized_bias_name + "_zero_point"
+        if self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN:
+            packed_bias_zp_initializer = onnx.helper.make_tensor(quantized_bias_zp_name, self.weight_qType, [1], [0.0])
+        elif self.is_per_channel():
+            bias_zp_data = np.zeros(bias_scale.shape, dtype=np.int32).reshape(-1)
+            packed_bias_zp_initializer = onnx.numpy_helper.from_array(bias_zp_data, quantized_bias_zp_name)
+        else:
+            packed_bias_zp_initializer = onnx.helper.make_tensor(quantized_bias_zp_name, tensor_type, [], [0])
+        self.model.initializer_extend([packed_bias_zp_initializer])
+
+        assert bias_name not in self.quantized_value_map
+        quantized_value = QuantizedValue(
+            bias_name,
+            quantized_bias_name,
+            quantized_bias_scale_name,
+            quantized_bias_zp_name,
+            QuantizedValueType.Initializer,
+            0 if bias_scale_data.size > 1 else None,
+            node_type=node_type,
+            node_qtype=node_qtype,
+        )
+        self.quantized_value_map[bias_name] = quantized_value
+
+        return quantized_bias_name
+
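For reference, a minimal NumPy sketch of the rule the docstring above states for static bias quantization (zero point 0, scale = input_scale * weight_scale). The values are illustrative only and this is not the ONNX Runtime implementation:

```python
import numpy as np

# Illustrative per-tensor scales (in practice these come from calibration and weight quantization).
input_scale = np.float32(0.02)
weight_scale = np.float32(0.001)
bias = np.array([0.25, -0.1, 0.0], dtype=np.float32)

# Static bias quantization rule: zero point == 0, scale == input_scale * weight_scale.
bias_scale = input_scale * weight_scale
q_bias = np.round(bias / bias_scale).astype(np.int32)

# Dequantizing recovers an approximation of the original bias.
approx_bias = q_bias.astype(np.float32) * bias_scale
```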
+    def quantize_initializer(self, weight, qType, reduce_range=False, keep_float_weight=False):
+        """
+        :param weight: TensorProto initializer
+        :param qType: type to quantize to
+        :param keep_float_weight: Whether to keep the weight in float. In some cases we only want to compute the
+                                  scale and zero point; if keep_float_weight is False, the weight is also quantized.
+        :return: quantized weight name, zero point name, scale name
+        """
+        # Find if this input is already quantized
+        if weight.name in self.quantized_value_map:
+            quantized_value = self.quantized_value_map[weight.name]
+            return (
+                quantized_value.q_name,
+                quantized_value.zp_name,
+                quantized_value.scale_name,
+            )
+
+        q_weight_name = weight.name + TENSOR_NAME_QUANT_SUFFIX
+        zp_name = weight.name + "_zero_point"
+        scale_name = weight.name + "_scale"
+
+        # Quantize weight data. Use quantization overrides if provided by the user.
+        weight_data = tensor_proto_to_array(weight)
+        quant_overrides = self.get_per_tensor_quant_overrides(weight.name)
+        if "quant_type" in quant_overrides:
+            qType = quant_overrides["quant_type"].tensor_type  # noqa: N806
+
+        if "scale" in quant_overrides and "zero_point" in quant_overrides:
+            zero_point = np.array(quant_overrides["zero_point"], dtype=ONNX_TYPE_TO_NP_TYPE[qType])
+            scale = np.array(quant_overrides["scale"])
+            q_weight_data = quantize_nparray(qType, weight_data.flatten(), scale, zero_point)
+            assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
+            assert (
+                zero_point.dtype != np.float32 and zero_point.dtype != np.float16
+            ), f"Unexpected dtype {zero_point.dtype}"
+            assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
+
+        else:
+            _, _, zero_point, scale, q_weight_data = quantize_data(
+                weight_data.flatten(),
+                qType,
+                quant_overrides.get("symmetric", self.is_weight_symmetric),
+                reduce_range=quant_overrides.get("reduce_range", self.reduce_range and reduce_range),
+                min_real_range=self.min_real_range,
+                rmin_override=quant_overrides.get("rmin"),
+                rmax_override=quant_overrides.get("rmax"),
+            )
+
+            assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
+            assert (
+                zero_point.dtype != np.float32 and zero_point.dtype != np.float16
+            ), f"Unexpected dtype {zero_point.dtype}"
+            assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
+
+        scale_dtype = weight.data_type
+        scale_initializer = onnx.helper.make_tensor(scale_name, scale_dtype, [], scale.reshape((-1,)).tolist())
+        zero_initializer = onnx.helper.make_tensor(zp_name, qType, [], zero_point.reshape((-1,)).tolist())
+        self.model.initializer_extend([scale_initializer, zero_initializer])
+
+        if not keep_float_weight:
+            if self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN:
+                q_weight_initializer = onnx.TensorProto()
+                q_weight_initializer.data_type = self.weight_qType
+                q_weight_initializer.dims.extend(weight.dims)
+                q_weight_initializer.name = q_weight_name
+                # Do not remove .flatten().copy() numpy is not clear about data persistence.
+                q_weight_initializer.raw_data = q_weight_data.flatten().copy().tobytes()
+                if to_array_extended is not None:
+                    # This test should not be needed but it helped catch some issues
+                    # with data persistence and tobytes.
+                    check = to_array_extended(q_weight_initializer)
+                    if check.shape != weight_data.shape or check.tobytes() != q_weight_data.tobytes():
+                        raise RuntimeError(
+                            f"The initializer of shape {weight_data.shape} could not be created, expecting "
+                            f"{q_weight_data.tobytes()[:10]}, got {check.tobytes()[:10]} and shape={weight.shape}"
+                            f"\nraw={str(q_weight_initializer)[:200]}."
+                        )
+            else:
+                q_weight_data = np.asarray(q_weight_data, dtype=onnx.helper.tensor_dtype_to_np_dtype(qType)).reshape(
+                    weight.dims
+                )
+                q_weight_initializer = onnx.numpy_helper.from_array(q_weight_data, q_weight_name)
+            self.model.initializer_extend([q_weight_initializer])
+
+        # Log entry for this quantized weight
+        quantized_value = QuantizedValue(
+            weight.name,
+            q_weight_name,
+            scale_name,
+            zp_name,
+            QuantizedValueType.Initializer,
+            None,
+        )
+        self.quantized_value_map[weight.name] = quantized_value
+        return q_weight_name, zp_name, scale_name
+
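As a rough illustration of what the non-override path computes for a per-tensor symmetric int8 weight, here is a simplified stand-in for quantize_data (assuming no reduce_range or min_real_range adjustments):

```python
import numpy as np

def symmetric_int8_quantize(weight: np.ndarray):
    # Symmetric per-tensor quantization: zero point is fixed at 0,
    # and the scale maps the largest magnitude onto the int8 range.
    rmax = float(np.max(np.abs(weight)))
    scale = np.float32(rmax / 127.0) if rmax > 0 else np.float32(1.0)
    q = np.clip(np.round(weight / scale), -127, 127).astype(np.int8)
    return q, scale, np.int8(0)

q_w, w_scale, w_zp = symmetric_int8_quantize(np.random.randn(4, 8).astype(np.float32))
```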
+    def quantize_weight_per_channel(
+        self,
+        weight_name,
+        weight_qType,
+        channel_axis,
+        reduce_range=True,
+        keep_float_weight=False,
+    ):
+        # Find if this input is already quantized
+        if weight_name in self.quantized_value_map:
+            quantized_value = self.quantized_value_map[weight_name]
+            return (
+                quantized_value.q_name,
+                quantized_value.zp_name,
+                quantized_value.scale_name,
+            )
+
+        initializer = find_by_name(weight_name, self.model.initializer())
+        if initializer is None:
+            raise ValueError("{} is not an initializer", weight_name)
+
+        weights = tensor_proto_to_array(initializer)
+        channel_count = weights.shape[channel_axis]
+        quant_overrides_for_channels = self.get_per_channel_quant_overrides(weight_name, channel_count)
+
+        # If user provides per-channel quantization overrides, all channels must use the same quantization type.
+        # So, just use the first channel's type.
+        if "quant_type" in quant_overrides_for_channels[0]:
+            weight_qType = quant_overrides_for_channels[0]["quant_type"].tensor_type  # noqa: N806
+
+        zero_point_list = []
+        scale_list = []
+        quantized_per_channel_data_list = []
+        for i in range(channel_count):
+            per_channel_data = weights.take(i, channel_axis)
+            channel_quant_overrides = quant_overrides_for_channels[i]
+
+            if "scale" in channel_quant_overrides and "zero_point" in channel_quant_overrides:
+                zero_point = np.array(channel_quant_overrides["zero_point"], dtype=ONNX_TYPE_TO_NP_TYPE[weight_qType])
+                scale = np.array(channel_quant_overrides["scale"])
+                quantized_per_channel_data = quantize_nparray(
+                    weight_qType, per_channel_data.flatten(), scale, zero_point
+                )
+                assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
+                assert (
+                    zero_point.dtype != np.float32 and zero_point.dtype != np.float16
+                ), f"Unexpected dtype {zero_point.dtype}"
+                assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
+                assert isinstance(
+                    quantized_per_channel_data, np.ndarray
+                ), f"Unexpected type {type(quantized_per_channel_data)}"
+
+            else:
+                symmetric = channel_quant_overrides.get(
+                    "symmetric",
+                    (
+                        self.is_weight_symmetric
+                        or weight_qType in (onnx.TensorProto.INT8, onnx.TensorProto.FLOAT8E4M3FN)
+                    ),
+                )
+                _, _, zero_point, scale, quantized_per_channel_data = quantize_data(
+                    per_channel_data.flatten(),
+                    weight_qType,
+                    symmetric,
+                    reduce_range=channel_quant_overrides.get("reduce_range", self.reduce_range and reduce_range),
+                    min_real_range=self.min_real_range,
+                    rmin_override=channel_quant_overrides.get("rmin"),
+                    rmax_override=channel_quant_overrides.get("rmax"),
+                )
+
+                assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
+                assert (
+                    zero_point.dtype != np.float32 and zero_point.dtype != np.float16
+                ), f"Unexpected dtype {zero_point.dtype}"
+                assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
+                assert isinstance(
+                    quantized_per_channel_data, np.ndarray
+                ), f"Unexpected type {type(quantized_per_channel_data)}"
+
+            zero_point_list.append(zero_point)
+            scale_list.append(scale)
+            quantized_per_channel_data_list.append(quantized_per_channel_data)
+
+        # combine per_channel_data into one
+        reshape_dims = list(weights.shape)  # deep copy
+        reshape_dims[channel_axis] = 1  # only one per channel for reshape
+        quantized_weights = np.asarray(quantized_per_channel_data_list[0]).reshape(reshape_dims)
+        for i in range(1, len(quantized_per_channel_data_list)):
+            channel_weights = np.asarray(quantized_per_channel_data_list[i]).reshape(reshape_dims)
+            quantized_weights = np.concatenate((quantized_weights, channel_weights), channel_axis)
+
+        q_weight_name = weight_name + TENSOR_NAME_QUANT_SUFFIX
+        zp_name = weight_name + "_zero_point"
+        scale_name = weight_name + "_scale"
+
+        quantized_value = QuantizedValue(
+            weight_name,
+            q_weight_name,
+            scale_name,
+            zp_name,
+            QuantizedValueType.Initializer,
+            None,
+        )
+        self.quantized_value_map[weight_name] = quantized_value
+
+        # Update packed weight, zero point, and scale initializers
+        zero_scale_shape = [initializer.dims[channel_axis]]
+        scale_initializer = onnx.helper.make_tensor(
+            scale_name, initializer.data_type, zero_scale_shape, np.hstack(scale_list).tolist()
+        )
+        zero_initializer = onnx.helper.make_tensor(
+            zp_name, weight_qType, zero_scale_shape, np.hstack(zero_point_list).tolist()
+        )
+
+        self.model.initializer_extend([scale_initializer, zero_initializer])
+
+        if not keep_float_weight:
+            quantized_weights = np.asarray(
+                quantized_weights,
+                dtype=onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[weight_qType],
+            ).reshape(initializer.dims)
+            q_weight_initializer = onnx.numpy_helper.from_array(quantized_weights, q_weight_name)
+            self.model.initializer_extend([q_weight_initializer])
+
+        return q_weight_name, zp_name, scale_name
+
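The per-channel path above quantizes each slice along channel_axis independently and stacks one scale (and zero point) per channel. A simplified NumPy sketch of that idea, assuming symmetric int8 and no overrides:

```python
import numpy as np

def per_channel_symmetric_int8(weight: np.ndarray, channel_axis: int = 0):
    scales, q_slices = [], []
    for i in range(weight.shape[channel_axis]):
        w = weight.take(i, axis=channel_axis)
        rmax = float(np.max(np.abs(w)))
        scale = np.float32(rmax / 127.0) if rmax > 0 else np.float32(1.0)
        scales.append(scale)
        q_slices.append(np.clip(np.round(w / scale), -127, 127).astype(np.int8))
    # Quantized slices are re-stacked along the channel axis; one scale per channel.
    return np.stack(q_slices, axis=channel_axis), np.array(scales, dtype=np.float32)
```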
+    def _get_and_check_tensor_quant_overrides(self):
+        """
+        Get tensor quantization overrides and check correctness.
+        """
+        tensor_quant_overrides = self.extra_options.get("TensorQuantOverrides", {})
+        tensor_quant_override_types = set()
+
+        # Validate that compatible/valid overrides are provided.
+        if tensor_quant_overrides:
+            initializer_names = self.model.get_initializer_name_set()
+            value_info_names = set(self.value_infos.keys())
+            keys_unsupported_with_scale_zp = {"symmetric", "reduce_range", "rmax", "rmin"}
+
+            for tensor_name, quant_overrides_list in tensor_quant_overrides.items():
+                if tensor_name not in initializer_names and tensor_name not in value_info_names:
+                    raise ValueError(f"Tensor '{tensor_name}' in TensorQuantOverrides is not present in the model")
+
+                if not isinstance(quant_overrides_list, list):
+                    raise ValueError(f"Tensor quantization overrides for '{tensor_name}' are not in a list")
+
+                is_initializer = tensor_name in initializer_names
+                if not is_initializer and len(quant_overrides_list) > 1:
+                    raise ValueError(
+                        f"Tensor '{tensor_name}' has a list of per-channel overrides, but is not an initializer"
+                    )
+
+                quant_type = None
+                for index, quant_overrides in enumerate(quant_overrides_list):
+                    if not isinstance(quant_overrides, dict):
+                        raise ValueError(
+                            f"Tensor quantization overrides at index {index} for '{tensor_name}' are not in a dict"
+                        )
+
+                    # For per-channel quantization, all channels must use the same quantization type.
+                    # Therefore, if the user tries to override the quant_type for a channel, it must match in all
+                    # other channels.
+                    if index == 0:
+                        quant_type = quant_overrides.get("quant_type")
+                        if quant_type:
+                            tensor_quant_override_types.add(quant_type.tensor_type)
+                    elif quant_type != quant_overrides.get("quant_type"):
+                        raise ValueError(
+                            "Channel quantization types for tensor '{tensor_name}' do not match at index {index}."
+                        )
+
+                    has_scale = "scale" in quant_overrides
+                    has_zero_point = "zero_point" in quant_overrides
+
+                    if (has_scale and not has_zero_point) or (has_zero_point and not has_scale):
+                        raise ValueError(
+                            "Must provide both 'scale' and 'zero_point' if one of the overrides is provided"
+                        )
+
+                    if has_scale:
+                        for key in keys_unsupported_with_scale_zp:
+                            if key in quant_overrides:
+                                raise ValueError(
+                                    f"Tensor override option '{key}' is invalid with 'scale' and 'zero_point'"
+                                )
+
+        return tensor_quant_overrides, tensor_quant_override_types
+
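For context, a hypothetical TensorQuantOverrides payload that passes these checks (the tensor names are made up; "act_out" stands for a regular activation tensor and "conv_w" for a 2-channel initializer):

```python
from onnxruntime.quantization import QuantType

extra_options = {
    "TensorQuantOverrides": {
        # Per-tensor override for an activation: quant type and symmetry, no fixed scale/zero_point.
        "act_out": [{"quant_type": QuantType.QUInt16, "symmetric": False}],
        # Per-channel overrides for an initializer: one dict per channel,
        # and 'scale'/'zero_point' must always be provided together.
        "conv_w": [
            {"scale": 0.01, "zero_point": 0},
            {"scale": 0.02, "zero_point": 0},
        ],
    }
}
```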
+    def get_per_tensor_quant_overrides(self, tensor_name):
+        quant_overrides_list = self.tensor_quant_overrides.get(tensor_name, [{}])
+        num_overrides = len(quant_overrides_list)
+        if num_overrides > 1:
+            raise ValueError(
+                f"Expected tensor '{tensor_name}' to use per-tensor quantization overrides, "
+                f"but found {num_overrides} per-channel overrides."
+            )
+
+        return quant_overrides_list[0] if num_overrides > 0 else {}
+
+    def get_per_channel_quant_overrides(self, tensor_name, num_channels):
+        quant_overrides_list = self.tensor_quant_overrides.get(tensor_name, [{} for i in range(num_channels)])
+
+        if len(quant_overrides_list) != num_channels:
+            raise ValueError(
+                f"Expected tensor '{tensor_name}' to have {num_channels} per-channel quantization overrides, "
+                f"but found {len(quant_overrides_list)} instead."
+            )
+
+        return quant_overrides_list
+
+    def _get_quantization_params(self, param_name, use_scale=None, use_zeropoint=None):
+        """
+        Create initializers and inputs in the graph for zero point and scale of output.
+        Zero point and scale values are obtained from self.quantization_params if specified.
+            parameter param_name: Name of the quantization parameter.
+            return: result, scale_name, zero_point_name, scale_shape, zero_point_shape.
+        """
+        zero_point_type = self.activation_qType
+
+        if use_scale is None or use_zeropoint is None:
+            if self.quantization_params is None or param_name not in self.quantization_params:
+                logging.info(f'Quantization parameters for tensor:"{param_name}" not specified')
+                return False, "", "", "", ""
+
+            params = self.quantization_params[param_name]
+            if not isinstance(params, QuantizationParams):
+                raise TypeError(f"Unexpected type {type(params)} for {param_name!r}.")
+            if params is None or len(params) != 3:
+                raise ValueError(
+                    "Quantization parameters should contain zero point, scale, quant type. "
+                    f"Specified values for output {param_name}: {params}"
+                )
+
+            zero_point_values = np.array([params["zero_point"]])
+            if not hasattr(params["scale"], "dtype") or params["scale"].dtype not in (np.float32, np.float16):
+                raise ValueError(f"Unexpected type {type(params['scale'])} and param_name={param_name!r}")
+            scale_values = np.array([params["scale"]])
+            assert scale_values.dtype != np.float64
+            zero_point_type = params["quant_type"]
+        else:
+            zero_point_values = np.array([use_zeropoint])
+            scale_values = np.array([use_scale])
+            params = self.quantization_params[param_name]
+            if "scale" in params:
+                dtype = params["scale"].dtype
+                scale_values = scale_values.astype(dtype)
+            assert scale_values.dtype != np.float64
+
+        zero_point_shape = []
+        zero_point_name = param_name + "_zero_point"
+        scale_shape = []
+        scale_name = param_name + "_scale"
+
+        # Add initializers
+        init_zp = onnx.helper.make_tensor(
+            zero_point_name, zero_point_type, zero_point_shape, zero_point_values.ravel().tolist()
+        )
+        self.model.add_initializer(init_zp)
+        if scale_values.dtype == np.float32:
+            scale_type = onnx.TensorProto.FLOAT
+        elif scale_values.dtype == np.float16:
+            scale_type = onnx.TensorProto.FLOAT16
+        else:
+            raise ValueError(f"Unexpected dtype={scale_values.dtype} for param_name={param_name!r}")
+        init_scale = onnx.helper.make_tensor(scale_name, scale_type, scale_shape, scale_values.reshape((-1,)).tolist())
+        self.model.add_initializer(init_scale)
+
+        return True, scale_name, zero_point_name, scale_shape, zero_point_shape
+
+    def calculate_quantization_params(self):
+        if self.tensors_range is None:
+            return {}
+
+        # adjust tensor_ranges for input of Clip and Relu node
+        for node in self.model.nodes():
+            if node.op_type not in ["Clip", "Relu"]:
+                continue
+            if self.is_activation_symmetric:
+                continue
+            if not self.should_quantize_node(node):
+                continue
+            if len(self.model.input_name_to_nodes()[node.input[0]]) != 1:
+                continue
+            if node.input[0] not in self.tensors_range or node.output[0] not in self.tensors_range:
+                continue
+            td = self.tensors_range[node.output[0]]
+            if not isinstance(td, TensorData):
+                raise TypeError(f"Unexpected type {type(td)} for {node.output[0]!r}.")
+            self.tensors_range[node.input[0]] = td
+
+        quantization_params = {}
+        for tensor_name in self.tensors_range:
+            td = self.tensors_range[tensor_name]
+            if not isinstance(td, TensorData):
+                raise TypeError(f"Unexpected type {type(td)} for {tensor_name!r}.")
+
+            quant_overrides = self.get_per_tensor_quant_overrides(tensor_name)
+
+            quant_type = self.activation_qType
+            if "quant_type" in quant_overrides:
+                quant_type = quant_overrides["quant_type"].tensor_type
+
+            if "scale" in quant_overrides and "zero_point" in quant_overrides:
+                zero, scale = quant_overrides["zero_point"], quant_overrides["scale"]
+            elif quant_type == onnx.TensorProto.FLOAT8E4M3FN:
+                zero, scale = compute_scale_zp_float8(quant_type, td.avg_std[1])
+            else:
+                rmin = quant_overrides.get("rmin", td.range_value[0])
+                rmax = quant_overrides.get("rmax", td.range_value[1])
+                symmetric = quant_overrides.get("symmetric", self.is_activation_symmetric)
+                reduce_range = quant_overrides.get("reduce_range", False)
+                qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=reduce_range, symmetric=symmetric)
+                zero, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, self.min_real_range)
+
+            quantization_params[tensor_name] = QuantizationParams(zero_point=zero, scale=scale, quant_type=quant_type)
+
+        return quantization_params
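The final branch above derives scale and zero point from the calibrated range of each tensor. A minimal sketch of the usual asymmetric uint8 computation, as a simplified stand-in for compute_scale_zp (ignoring min_real_range and reduce_range):

```python
import numpy as np

def asymmetric_uint8_scale_zp(rmin: float, rmax: float):
    # Make sure zero is representable, then map [rmin, rmax] onto [0, 255].
    rmin, rmax = min(rmin, 0.0), max(rmax, 0.0)
    qmin, qmax = 0, 255
    scale = (rmax - rmin) / (qmax - qmin) if rmax > rmin else 1.0
    zero_point = int(round(qmin - rmin / scale))
    return np.float32(scale), np.uint8(np.clip(zero_point, qmin, qmax))

scale, zp = asymmetric_uint8_scale_zp(-0.8, 2.4)
```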
diff --git a/onnxruntime/python/tools/quantization/calibrate.py b/onnxruntime/python/tools/quantization/calibrate.py
index 624049b244580..971cc203f4d73 100644
--- a/onnxruntime/python/tools/quantization/calibrate.py
+++ b/onnxruntime/python/tools/quantization/calibrate.py
@@ -918,11 +918,7 @@ def compute_entropy(self):
         thresholds_dict = {}  # per tensor thresholds
 
         print(f"Number of tensors : {len(histogram_dict)}")
-        print(
-            "Number of histogram bins : {} (The number may increase depends on the data it collects)".format(
-                self.num_bins
-            )
-        )
+        print(f"Number of histogram bins : {self.num_bins} (The number may increase depends on the data it collects)")
         print(f"Number of quantized bins : {self.num_quantized_bins}")
 
         for tensor, histogram in histogram_dict.items():
diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py
index 9ebf400498e0e..fbf954febdda4 100644
--- a/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py
+++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py
@@ -122,6 +122,11 @@ def fuse(
 
         self.nodes_to_remove.extend(subgraph_nodes)
         fused_node = onnx.helper.make_node(
-            self.fused_op_type, inputs=[subgraph_input], outputs=[subgraph_output], p=2, axis=-1
+            self.fused_op_type,
+            name=self.create_unique_node_name(),
+            inputs=[subgraph_input],
+            outputs=[subgraph_output],
+            p=2,
+            axis=-1,
         )
         self.nodes_to_add.append(fused_node)
diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py
index becbaceab184e..e584a65574520 100644
--- a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py
+++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py
@@ -3,6 +3,8 @@
 # Licensed under the MIT License. See License.txt in the project root for
 # license information.
 # --------------------------------------------------------------------------
+from __future__ import annotations
+
 import logging
 from pathlib import Path
 
@@ -13,7 +15,72 @@
 from .fusion_lpnorm import FusionLpNormalization
 
 
-def qnn_preprocess_model(model_input: Path, model_output: Path, fuse_layernorm: bool = False) -> bool:
+def qnn_preprocess_model(
+    model_input: Path,
+    model_output: Path,
+    fuse_layernorm: bool = False,
+    save_as_external_data: bool = False,
+    all_tensors_to_one_file: bool = False,
+    external_data_location: str | None = None,
+    external_data_size_threshold: int = 1024,
+    external_data_convert_attribute: bool = False,
+    inputs_to_make_channel_last: list[str] | None = None,
+    outputs_to_make_channel_last: list[str] | None = None,
+) -> bool:
+    """
+    If necessary, this method creates a new "pre-processed" model in preparation for
+    quantization of a model to be used with QNN EP. Returns True if a new model was created.
+
+    This method performs the following operations:
+    - Fuse Erf sequence into a single Gelu node.
+    - Fuse ReduceL2 sequence into a single LpNormalization node (p == 2).
+    - (Optional) Fuse ReduceMean sequence into a single LayerNormalization node.
+
+    Args:
+        model_input: Path to the input model file.
+        model_output: Path to the output model file, which is only created if this method returns True.
+        fuse_layernorm: True if ReduceMean sequences should be fused into LayerNormalization nodes.
+            Defaults to False.
+        save_as_external_data: True if output model should be saved with external data. Defaults to false.
+        all_tensors_to_one_file: Effective only if save_as_external_data is true. Defaults to false.
+            If true, save all tensors to one external file specified by external_data_location.
+            If false, save each tensor to a file named with the tensor name.
+        external_data_location: Effective only if save_as_external_data is true. Defaults to None.
+            Specify the external file to which all tensors are saved. Path is relative
+            to the model path. If not specified, the model's name is used.
+        external_data_size_threshold: Effective only if save_as_external_data is true. Defaults to 1024.
+            Tensors with a data size >= external_data_size_threshold are converted to external data.
+            To convert every tensor with raw data to external data, set to 0.
+        external_data_convert_attribute: Effective only if save_as_external_data is true. Defaults to false.
+            If true, convert all tensors to external data.
+            If false, convert only non-attribute tensors to external data.
+        inputs_to_make_channel_last: List of graph input names to transpose to be "channel-last". For example,
+            if "input0" originally has the shape (N, C, D1, D2, ..., Dn), the resulting model will change input0's
+            shape to (N, D1, D2, ..., Dn, C) and add a transpose node after it.
+
+            Original:
+                input0 (N, C, D1, D2, ..., Dn) --> <Nodes>
+
+            Updated:
+                input0 (N, D1, D2, ..., Dn, C) --> Transpose --> input0_chanfirst (N, C, D1, D2, ..., Dn) --> <Nodes>
+
+            This can potentially improve inference latency for QDQ models running on QNN EP because the
+            additional transpose node may allow other transpose nodes inserted during ORT layout transformation
+            to cancel out.
+        outputs_to_make_channel_last: List of graph output names to transpose to be "channel-last". For example,
+            if "output0" originally has the shape (N, C, D1, D2, ..., Dn), the resulting model will change output0's
+            shape to (N, D1, D2, ..., Dn, C) and add a transpose node before it.
+
+            Original:
+                <Nodes> --> output0 (N, C, D1, D2, ..., Dn)
+
+            Updated:
+                <Nodes> --> output0_chanfirst (N, C, D1, D2, ..., Dn) --> Transpose --> output0 (N, D1, D2, ..., Dn, C)
+
+            This can potentially improve inference latency for QDQ models running on QNN EP because the
+            additional transpose node may allow other transpose nodes inserted during ORT layout transformation
+            to cancel out.
+    """
     modified = False
     model = onnx.load_model(model_input)
     onnx_model = ONNXModel(model)
@@ -44,8 +111,197 @@ def qnn_preprocess_model(model_input: Path, model_output: Path, fuse_layernorm:
             if fusion_layernorm.apply():
                 modified = True
 
+    # Optionally, transpose inputs and/or outputs to make them "channel-last".
+    if inputs_to_make_channel_last or outputs_to_make_channel_last:
+        transpose_node_prefix = "Transpose_channel_"
+        transpose_node_suffix: int = onnx_model.get_largest_node_name_suffix(transpose_node_prefix) + 1
+        update_io_to_channel_last(
+            onnx_model.model,
+            inputs_to_make_channel_last,
+            outputs_to_make_channel_last,
+            transpose_node_name_prefix=transpose_node_prefix,
+            transpose_node_name_start_suffix=transpose_node_suffix,
+        )
+        modified = True
+
+    # Make sure all nodes have a name.
+    unnamed_node_prefix = "qnn_preproc_node_"
+    available_suffix = onnx_model.get_largest_node_name_suffix(unnamed_node_prefix) + 1
+    for node in onnx_model.model.graph.node:
+        if node.op_type != "Constant" and not node.name:
+            new_node_name = f"{unnamed_node_prefix}{available_suffix!s}"
+            available_suffix += 1
+            node.name = new_node_name
+            modified = True
+            logging.warning(f"Node of type {node.op_type} does not have a name. Renamed to {new_node_name}.")
+
     if modified:
         onnx_model.topological_sort()
-        onnx.save_model(model, model_output)
+        onnx.save_model(
+            model,
+            model_output,
+            save_as_external_data=save_as_external_data,
+            all_tensors_to_one_file=all_tensors_to_one_file,
+            location=external_data_location,
+            size_threshold=external_data_size_threshold,
+            convert_attribute=external_data_convert_attribute,
+        )
 
     return modified
+
+
+class InputOutputNameMap:
+    def __init__(
+        self,
+        orig_tensor_names: set[str],
+        orig_graph_inputs: dict[str, onnx.ValueInfoProto],
+        orig_graph_outputs: dict[str, onnx.ValueInfoProto],
+    ):
+        self.orig_tensor_names = orig_tensor_names
+        self.orig_graph_inputs = orig_graph_inputs
+        self.orig_graph_outputs = orig_graph_outputs
+        self.updated_io_names = {}
+        self.new_value_infos = []
+
+    def get_new_name(self, orig_name: str):
+        if orig_name in self.updated_io_names:
+            return self.updated_io_names[orig_name]
+
+        # Make a new tensor name that is unique among all tensors in the graph.
+        prefix: str = f"{orig_name}_channel_first_"
+        suffix: int = -1
+        for tensor_name in self.orig_tensor_names:
+            if tensor_name.startswith(prefix) and tensor_name[len(prefix) :].isdigit():
+                index = int(tensor_name[len(prefix) :])
+                suffix = max(suffix, index)
+
+        suffix += 1  # This is the first available suffix.
+        new_name = f"{prefix}{suffix!s}"
+
+        # Add new value_info objects for these new tensors.
+        orig_value_info = self.orig_graph_inputs.get(orig_name) or self.orig_graph_outputs[orig_name]
+        value_info_proto = onnx.ValueInfoProto()
+        value_info_proto.CopyFrom(orig_value_info)
+        value_info_proto.name = new_name
+        self.new_value_infos.append(value_info_proto)
+
+        self.updated_io_names[orig_name] = new_name
+        return self.updated_io_names[orig_name]
+
+
+def update_io_to_channel_last(
+    model: onnx.ModelProto,
+    inputs_to_update: list[str] | None,
+    outputs_to_update: list[str] | None,
+    transpose_node_name_prefix: str = "Transpose_channel_",
+    transpose_node_name_start_suffix: int = 0,
+):
+    inputs_to_update = set(inputs_to_update or [])
+    outputs_to_update = set(outputs_to_update or [])
+
+    if not inputs_to_update and not outputs_to_update:
+        return
+
+    graph = model.graph
+    orig_graph_inputs = {ginput.name: ginput for ginput in graph.input}
+    orig_graph_outputs = {goutput.name: goutput for goutput in graph.output}
+
+    # Check that the user passed in actual input and output names.
+    for input_name in inputs_to_update:
+        if input_name not in orig_graph_inputs:
+            raise ValueError(f"{input_name} is not a graph input")
+
+    for output_name in outputs_to_update:
+        if output_name not in orig_graph_outputs:
+            raise ValueError(f"{output_name} is not a graph output")
+
+    orig_tensor_names = set()
+    orig_tensor_names.update(set(orig_graph_inputs))
+    orig_tensor_names.update(set(orig_graph_outputs))
+    orig_tensor_names.update(input_name for node in graph.node for input_name in node.input if input_name)
+
+    # Maps original input (or output) name to its updated name used within the graph.
+    io_map = InputOutputNameMap(orig_tensor_names, orig_graph_inputs, orig_graph_outputs)
+
+    # Update each node's inputs/outputs to use the transposed versions.
+    for node in graph.node:
+        for i in range(len(node.input)):
+            if node.input[i] and (node.input[i] in inputs_to_update or node.input[i] in outputs_to_update):
+                node.input[i] = io_map.get_new_name(node.input[i])
+
+        for i in range(len(node.output)):
+            if node.output[i] in outputs_to_update:
+                node.output[i] = io_map.get_new_name(node.output[i])
+
+    # Update graph inputs to be channel-last and add a Transpose (to channel-first) after each.
+    for g_input_name in inputs_to_update:
+        g_input = orig_graph_inputs[g_input_name]
+
+        if not g_input.type.HasField("tensor_type") or not g_input.type.tensor_type.HasField("shape"):
+            raise ValueError(f"Expected input {g_input.name} to have a tensor_type with a shape")
+
+        input_shape = g_input.type.tensor_type.shape
+        input_rank = len(input_shape.dim)
+
+        if input_rank < 3:
+            raise ValueError(f"Expected input {g_input.name} to be of rank >= 3")
+
+        channel_dim = onnx.TensorShapeProto.Dimension()
+        channel_dim.CopyFrom(input_shape.dim[1])
+        for i in range(1, input_rank - 1):
+            input_shape.dim[i].CopyFrom(input_shape.dim[i + 1])
+        input_shape.dim[input_rank - 1].CopyFrom(channel_dim)
+
+        transpose_perm = list(range(input_rank))
+        for i in range(input_rank):
+            transpose_perm[i] = i if i < 1 else i - 1
+        transpose_perm[1] = input_rank - 1
+
+        transpose_node = onnx.helper.make_node(
+            "Transpose",
+            name=f"{transpose_node_name_prefix}{transpose_node_name_start_suffix!s}",
+            inputs=[g_input.name],
+            outputs=[io_map.get_new_name(g_input.name)],
+            perm=transpose_perm,
+        )
+        transpose_node_name_start_suffix += 1
+
+        graph.node.extend([transpose_node])
+
+    # Update graph outputs to be channel-last and add a Transpose (from channel-first) before each.
+    for g_output_name in outputs_to_update:
+        g_output = orig_graph_outputs[g_output_name]
+        if not g_output.type.HasField("tensor_type") or not g_output.type.tensor_type.HasField("shape"):
+            raise ValueError(f"Expected output {g_output.name} to have a tensor_type with a shape")
+
+        output_shape = g_output.type.tensor_type.shape
+        output_rank = len(output_shape.dim)
+
+        if output_rank < 3:
+            raise ValueError(f"Expected output {g_output.name} to be of rank >= 3")
+
+        channel_dim = onnx.TensorShapeProto.Dimension()
+        channel_dim.CopyFrom(output_shape.dim[1])
+        for i in range(1, output_rank - 1):
+            output_shape.dim[i].CopyFrom(output_shape.dim[i + 1])
+        output_shape.dim[output_rank - 1].CopyFrom(channel_dim)
+
+        transpose_perm = list(range(output_rank))
+        for i in range(output_rank):
+            transpose_perm[i] = i if i == 0 else i + 1
+        transpose_perm[output_rank - 1] = 1
+
+        transpose_node = onnx.helper.make_node(
+            "Transpose",
+            name=f"{transpose_node_name_prefix}{transpose_node_name_start_suffix!s}",
+            inputs=[io_map.get_new_name(g_output.name)],
+            outputs=[g_output.name],
+            perm=transpose_perm,
+        )
+        transpose_node_name_start_suffix += 1
+
+        graph.node.extend([transpose_node])
+
+    graph.value_info.extend(io_map.new_value_infos)
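Putting the new options together, a typical call might look like the following. The model paths and tensor names are hypothetical, and the import path assumes the public QNN quantization module; for a rank-4 input, the Transpose inserted to restore channel-first order uses perm=[0, 3, 1, 2].

```python
from onnxruntime.quantization.execution_providers.qnn import qnn_preprocess_model

changed = qnn_preprocess_model(
    "model.onnx",
    "model.preproc.onnx",
    fuse_layernorm=True,
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    external_data_location="model.preproc.onnx.data",
    inputs_to_make_channel_last=["input0"],
    outputs_to_make_channel_last=["output0"],
)
print("pre-processed model written:", changed)
```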
diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py
index 7c2fa4f65ae1b..e9affae7ac263 100644
--- a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py
+++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py
@@ -15,6 +15,7 @@
 Q16_TYPES = {QuantType.QInt16, QuantType.QUInt16}
 Q8_TYPES = {QuantType.QInt8, QuantType.QUInt8}
 OP_TYPES_TO_EXCLUDE = {"Cast"}
+MODEL_SIZE_THRESHOLD = 2147483648  # Quant model should use external data if >= 2GB
 
 
 def get_qnn_qdq_config(
@@ -28,14 +29,21 @@ def get_qnn_qdq_config(
     if per_channel:
         raise ValueError("QNN EP does not yet support per-channel quantization.")
 
-    # Process model nodes to setup overrides.
-    model = onnx.load_model(model_input)
+    model = onnx.load_model(model_input, load_external_data=False)
 
     op_types = set()
     tensor_quant_overrides = {}
+    model_has_external_data = False
+    name_to_initializer = {}
 
-    name_to_initializer = {initializer.name: initializer for initializer in model.graph.initializer}
+    # Build map of initializers (name -> initializer) and
+    # check if the model has external data.
+    for initializer in model.graph.initializer:
+        name_to_initializer[initializer.name] = initializer
+        if onnx.external_data_helper.uses_external_data(initializer):
+            model_has_external_data = True
 
+    # Setup quantization overrides for specific operator types
     for node in model.graph.node:
         op_types.add(node.op_type)
 
@@ -89,5 +97,6 @@ def get_qnn_qdq_config(
         activation_type=activation_type,
         weight_type=weight_type,
         op_types_to_quantize=list(op_types.difference(OP_TYPES_TO_EXCLUDE)),
+        use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD),
         extra_options=extra_options,
     )
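For context, the config produced here is typically fed straight into quantize(). A sketch of that flow, assuming a toy CalibrationDataReader with a hypothetical input name and shape:

```python
import numpy as np
from onnxruntime.quantization import CalibrationDataReader, QuantType, quantize
from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config

class RandomDataReader(CalibrationDataReader):
    """Toy calibration reader feeding a few random samples for a single input named 'input0'."""
    def __init__(self, num_samples=4):
        self._data = iter(
            [{"input0": np.random.rand(1, 3, 224, 224).astype(np.float32)} for _ in range(num_samples)]
        )

    def get_next(self):
        return next(self._data, None)

qnn_config = get_qnn_qdq_config(
    "model.preproc.onnx",
    RandomDataReader(),
    activation_type=QuantType.QUInt16,
    weight_type=QuantType.QUInt8,
)
quantize("model.preproc.onnx", "model.qdq.onnx", qnn_config)
```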
diff --git a/onnxruntime/python/tools/quantization/fusions/fusion.py b/onnxruntime/python/tools/quantization/fusions/fusion.py
index b54b421226f1a..4bdc5c26cc946 100644
--- a/onnxruntime/python/tools/quantization/fusions/fusion.py
+++ b/onnxruntime/python/tools/quantization/fusions/fusion.py
@@ -24,6 +24,9 @@ def __init__(self, model: ONNXModel, fused_op_type: str, search_op_type: str):
         self.nodes_to_remove: list = []
         self.nodes_to_add: list = []
 
+        self._new_node_name_prefix = self.fused_op_type + "_fused_" + self.search_op_type + "_"
+        self._new_node_name_suffix = None  # int|None used to create unique node names for the fused ops.
+
     def fuse(
         self,
         node: onnx.NodeProto,
@@ -57,6 +60,18 @@ def apply(self) -> bool:
 
         return graph_updated
 
+    def create_unique_node_name(self):
+        prefix = self._new_node_name_prefix
+
+        if self._new_node_name_suffix is None:
+            largest_suffix: int = self.model.get_largest_node_name_suffix(prefix)
+            self._new_node_name_suffix = largest_suffix + 1
+
+        new_name = f"{prefix}{self._new_node_name_suffix!s}"
+        self._new_node_name_suffix += 1
+
+        return new_name
+
     @staticmethod
     def is_safe_to_fuse_nodes(
         nodes_to_remove: list[onnx.NodeProto],
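The naming scheme behind create_unique_node_name (and InputOutputNameMap.get_new_name in the QNN preprocessor) is simply "scan existing names for `<prefix><int>` and continue from the largest suffix". A standalone sketch of that idea, independent of ONNXModel:

```python
def largest_suffix(names, prefix: str) -> int:
    # Largest integer suffix among names of the form f"{prefix}{int}"; -1 if there are none.
    largest = -1
    for name in names:
        if name.startswith(prefix) and name[len(prefix):].isdigit():
            largest = max(largest, int(name[len(prefix):]))
    return largest

existing = ["Gelu_fused_Erf_0", "Gelu_fused_Erf_3", "other_node"]
next_name = f"Gelu_fused_Erf_{largest_suffix(existing, 'Gelu_fused_Erf_') + 1}"  # "Gelu_fused_Erf_4"
```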
diff --git a/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py b/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py
index a20d6dbffd7a7..42c4a11833641 100644
--- a/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py
+++ b/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py
@@ -112,7 +112,9 @@ def fuse_1(
             return False
 
         self.nodes_to_remove.extend(subgraph_nodes)
-        fused_node = onnx.helper.make_node("Gelu", inputs=[subgraph_input], outputs=[subgraph_output])
+        fused_node = onnx.helper.make_node(
+            "Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[subgraph_output]
+        )
         fused_node.domain = "com.microsoft"
         self.nodes_to_add.append(fused_node)
         return True
@@ -173,11 +175,9 @@ def fuse_2(
             if not self.has_constant_input(sqrt_node, 2.0):
                 return False
 
-        root_node = self.model.get_parent(div, 0, output_name_to_node)
-        if root_node is None:
-            return False
+        subgraph_input = div.input[0]
 
-        if root_node.output[0] not in mul.input:
+        if subgraph_input not in mul.input:
             return False
 
         subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul]
@@ -188,7 +188,9 @@ def fuse_2(
             return False
 
         self.nodes_to_remove.extend(subgraph_nodes)
-        fused_node = onnx.helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[mul.output[0]])
+        fused_node = onnx.helper.make_node(
+            "Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[mul.output[0]]
+        )
         fused_node.domain = "com.microsoft"
         self.nodes_to_add.append(fused_node)
         return True
@@ -239,9 +241,8 @@ def fuse_3(
         if i < 0:
             return False
 
-        root_node = self.model.get_parent(first_mul, 0 if i == 1 else 1, output_name_to_node)
-        if root_node is None:
-            return False
+        root_input_index = 1 - i
+        subgraph_input = first_mul.input[root_input_index]
 
         if mul_half.output[0] not in input_name_to_nodes:
             return False
@@ -250,7 +251,7 @@ def fuse_3(
             return False
         last_mul = children[0]
 
-        if not (last_mul.input[0] == root_node.output[0] or last_mul.input[1] == root_node.output[0]):
+        if not (last_mul.input[0] == subgraph_input or last_mul.input[1] == subgraph_input):
             return False
 
         subgraph_nodes = [first_mul, erf_node, add_after_erf, mul_half, last_mul]
@@ -263,7 +264,9 @@ def fuse_3(
             return False
 
         self.nodes_to_remove.extend(subgraph_nodes)
-        fused_node = onnx.helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[last_mul.output[0]])
+        fused_node = onnx.helper.make_node(
+            "Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[last_mul.output[0]]
+        )
         fused_node.domain = "com.microsoft"
         self.nodes_to_add.append(fused_node)
         return True
diff --git a/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py b/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py
index d7fb89236d3d2..7d58c1c180822 100644
--- a/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py
+++ b/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py
@@ -127,6 +127,7 @@ def fuse(
 
         normalize_node = onnx.helper.make_node(
             "LayerNormalization",
+            name=self.create_unique_node_name(),
             inputs=[reduce_mean_node.input[0], weight_input, bias_input],
             outputs=[last_add_node.output[0]],
         )
diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
index 3e9f9a6544a71..f4bcd508960a1 100644
--- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
+++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
@@ -65,7 +65,7 @@ def __init__(
         self,
         calibration_data_reader: CalibrationDataReader,
         percdamp=0.01,
-        blocksize=128,
+        block_size=128,
         actorder=False,
         mse=False,
         perchannel=True,
@@ -79,7 +79,7 @@ def __init__(
                 a calibration data reader. It enumerates calibration data and generates inputs for the original model.
             percdamp:
                 percent of the average Hessian diagonal to use for dampening.
-            blocksize (int, optional):
+            block_size (int, optional):
                 channel number in one block to execute a GPTQ quantization iteration.
             actorder (bool, optional):
                 whether rearrange Hessian matrix considering the diag's value.
@@ -93,42 +93,285 @@ def __init__(
         )
         self.calibration_data_reader = calibration_data_reader
         self.percdamp = percdamp
-        self.blocksize = blocksize
+        self.block_size = block_size
         self.actorder = actorder
         self.mse = mse
         self.perchannel = perchannel
 
 
-class MatMul4BitsQuantizer:
-    """Perform 4b quantization of constant MatMul weights"""
+class HQQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
+    def __init__(
+        self,
+        block_size=128,
+        bits=4,
+        axis=1,
+    ):
+        """
+        This is a class for the HQQ algorithm's weight-only quantization configuration.
+        The HQQ algorithm quantizes weights without needing calibration data.
 
+        Args:
+            block_size (int, optional):
+                number of weight elements in one quantization block (the group size).
+            bits (int, optional):
+                number of bits used to represent each weight.
+            axis (int, optional):
+                0 or 1; the axis along which to quantize. See https://arxiv.org/pdf/2309.15531.pdf
+        """
+        super().__init__(
+            algorithm="HQQ",
+        )
+        self.block_size = block_size
+        self.bits = bits
+        self.axis = axis
+
+
+class DefaultWeightOnlyQuantConfig(WeightOnlyQuantConfig):
     def __init__(
         self,
-        model: ModelProto | str,
-        block_size: int,
-        is_symmetric: bool,
+        block_size: int = 128,
+        is_symmetric: bool = False,
         accuracy_level: int | None = None,
-        nodes_to_exclude=None,
-        algo_config: WeightOnlyQuantConfig = None,
     ):
-        if nodes_to_exclude is None:
-            nodes_to_exclude = []
-        self.model = ONNXModel(onnx.load(model)) if isinstance(model, str) else ONNXModel(model)
-        self.model_path = model if isinstance(model, str) else None
+        super().__init__(algorithm="DEFAULT")
         self.block_size = block_size
         self.is_symmetric = is_symmetric
+        self.bits = 4
         self.accuracy_level = accuracy_level
-        self.nodes_to_exclude = set(nodes_to_exclude)
-        self.algo_config = algo_config
 
+
+def is_divisible(val1, val2):
+    return int(val2 * np.ceil(val1 / val2)) == val1
+
+
+class HQQWeightOnlyQuantizer:
+    def __init__(
+        self,
+        config: HQQWeightOnlyQuantConfig,
+    ):
+        self.config = config
+
+    # Proximal solver || weight - dequantize(quantize(weight))||_p^p
     @staticmethod
-    def __get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]:
-        for gid in range(len(graph_path) - 1, -1, -1):
-            graph = graph_path[gid]
-            for tensor in graph.initializer:
-                if tensor.name == name:
-                    return tensor, graph
-        return None, None
+    def optimize_weights(
+        tensor,
+        scale,
+        zero,
+        min_max: list[int],
+        axis: int = 0,
+        opt_params: dict = None,  # noqa: RUF013
+        verbose=False,
+    ):
+        import torch
+
+        opt_params = {"lp_norm": 0.7, "beta": 1e1, "kappa": 1.01, "iters": 20} if opt_params is None else opt_params
+        lp_norm, beta, kappa, iters = (
+            opt_params["lp_norm"],
+            opt_params["beta"],
+            opt_params["kappa"],
+            opt_params["iters"],
+        )
+
+        dtype = torch.float16 if tensor.is_cuda else torch.float32
+        w_f = tensor.to(dtype)
+        scale = scale.to(dtype)
+        zero = zero.to(dtype)
+
+        if lp_norm == 1:
+
+            def shrink_op(x, beta):
+                return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta)
+
+        else:
+
+            def shrink_op(x, beta, p=lp_norm):
+                return torch.sign(x) * torch.nn.functional.relu(
+                    torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x) + 1e-8, p - 1)
+                )
+
+        best_error = 1e4
+        for i in range(iters):
+            w_q = torch.round(w_f * scale + zero).clamp(min_max[0], min_max[1])
+            w_r = (w_q - zero) / scale
+            w_e = shrink_op(w_f - w_r, beta)
+            zero = torch.mean(w_q - (w_f - w_e) * scale, axis=axis, keepdim=True)
+            beta *= kappa
+
+            current_error = float(torch.abs(w_f - w_r).mean())
+            if verbose:
+                print(i, np.round(current_error, 6))
+            if current_error < best_error:
+                best_error = current_error
+            else:
+                break
+
+        del w_f, w_q, w_r, w_e
+
+        return scale, zero
+
+    @staticmethod
+    def pack_on_row_fast_248bit(pack_tensor, ori_int_tensor, bits):
+        if pack_tensor.shape[0] == ori_int_tensor.shape[0]:
+            ori_int_tensor = ori_int_tensor.T
+            pack_tensor = pack_tensor.T
+        if bits in [2, 4, 8]:
+            compress_ratio = pack_tensor.element_size() * 8 // bits
+            for j in range(compress_ratio):
+                pack_tensor[0:] |= ori_int_tensor[j::compress_ratio] << (bits * (j))
+        else:
+            raise NotImplementedError("Only 2,4,8 bits are supported.")
+
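A minimal 1-D NumPy sketch of the nibble-packing scheme this helper applies row-wise for 4-bit weights (two values per uint8, low nibble first); illustrative only:

```python
import numpy as np

# Pack unsigned 4-bit values (0..15) into uint8, two per byte, low nibble first.
vals = np.array([1, 2, 3, 4, 5, 6], dtype=np.uint8)
packed = np.zeros(len(vals) // 2, dtype=np.uint8)
compress_ratio = 8 // 4  # two 4-bit values per uint8
for j in range(compress_ratio):
    packed |= vals[j::compress_ratio] << (4 * j)

# Unpack to verify round-tripping.
unpacked = np.stack([(packed >> (4 * j)) & 0xF for j in range(compress_ratio)], axis=1).reshape(-1)
assert np.array_equal(unpacked, vals)
```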
+    # From the official implementation of Half-Quadratic Quantization (HQQ).
+    def quantize_internal(
+        self, tensor, bits=4, channel_wise=True, group_size=64, optimize=True, round_zero=True, axis=1
+    ):
+        import torch
+
+        weight = tensor.float()
+        ori_shape = weight.shape
+
+        pad_len = (group_size - ori_shape[axis] % group_size) % group_size
+        if axis == 1:
+            weight = torch.nn.functional.pad(weight, (0, pad_len), "constant", 0)
+        else:
+            weight = torch.nn.functional.pad(weight, (0, 0, 0, pad_len), "constant", 0)
+        shape = weight.shape
+
+        # Reshape for grouping
+        if (group_size is not None) and channel_wise:
+            weight = weight.reshape([-1, group_size]) if (axis == 1) else weight.reshape([group_size, -1])
+
+        # Get min/max values
+        if channel_wise is False:
+            _min, _max = weight.min(), weight.max()
+            optimize = False
+        else:
+            _min = weight.min(axis=axis, keepdim=True)[0]
+            _max = weight.max(axis=axis, keepdim=True)[0]
+
+        max_v = 2**bits - 1
+        min_v = 0
+        min_max = [min_v, max_v]
+
+        # Note: here we work with the inverse of the scale to avoid division; we quantize via weight*scale + zero,
+        # and the scale is inverted back later on. Clamp to avoid half-precision problems.
+        scale = (max_v / (_max - _min)).clamp(max=2e4)
+        # Guard against division by zero when a group is constant (max == min).
+        min_max_axis = _max - _min
+        if (min_max_axis == 0).sum().item() > 0:
+            min_max_axis[min_max_axis == 0] = max_v
+            scale = (max_v / min_max_axis).clamp(max=2e4)
+        zero = -_min * scale
+
+        if round_zero:
+            zero = torch.round(zero)
+
+        # Fine-tune weights
+        if optimize:
+            scale, zero = self.optimize_weights(tensor=weight, scale=scale, zero=zero, min_max=min_max, axis=axis)
+
+        # Quantize
+        # Necessary for fake quantization backprop
+        w_q = torch.round(weight * scale + zero).clamp(min_max[0], min_max[1])
+        w_q = w_q.reshape(shape).int()
+
+        scale = 1.0 / scale
+        if axis == 1:
+            scale = scale.reshape(shape[0], -1)
+            zero = zero.reshape(shape[0], -1)
+        else:
+            scale = scale.reshape(-1, shape[-1])
+            zero = zero.reshape(-1, shape[-1])
+        # cleanup
+        del weight, _min, _max
+
+        return w_q, scale.to(tensor.dtype), zero.to(tensor.dtype)
+
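To make the "inverse scale" convention concrete, here is a small NumPy sketch of the 4-bit affine scheme quantize_internal follows (values are illustrative; the real code operates group-wise on torch tensors):

```python
import numpy as np

# Quantize with the *inverse* scale (w * scale + zero), then invert it for dequantization.
w = np.array([-0.5, 0.0, 0.7], dtype=np.float32)
scale = np.float32(15.0 / (w.max() - w.min()))   # inverse scale used during quantization
zero = np.round(-w.min() * scale)
q = np.clip(np.round(w * scale + zero), 0, 15)   # 4-bit range 0..15

scale_out = 1.0 / scale                          # stored scale after inversion
w_approx = (q - zero) * scale_out                # approximate reconstruction of w
```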
+    def quantize(self, node: NodeProto, graph_stack: list[GraphProto]):
+        """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node"""
+        if node.op_type != "MatMul":
+            return node  # only care about MatMul for now
+        import torch
+
+        logger.info(f"start to quantize {node.name} ...")
+        inputB = node.input[1]  # noqa: N806
+        b_pb, bs_graph = get_initializer(inputB, graph_stack)
+        if b_pb is None:
+            logger.info("MatMul doesn't have const weight. Skip to quantize")
+            return node  # only care about constant weight
+
+        b_array = onnx.numpy_helper.to_array(b_pb)
+        if len(b_array.shape) != 2:
+            logger.info("MatMul weight is not 2D. Skip to quantize")
+            return node  # can only process 2-D matrix
+        b_array_torch = torch.from_numpy(b_array)
+        if torch.cuda.is_available():
+            b_array_torch = b_array_torch.cuda()
+        quant_weight_torch, scales_torch, zero_points_torch = self.quantize_internal(
+            b_array_torch.T, bits=self.config.bits, group_size=self.config.block_size
+        )
+        quant_weight_torch = quant_weight_torch.contiguous()
+        scales_torch = scales_torch.contiguous()
+        zero_points_torch = zero_points_torch.contiguous()
+
+        packed_torch = torch.zeros(
+            (quant_weight_torch.shape[0], quant_weight_torch.shape[1] // 2),
+            dtype=torch.uint8,
+            device=quant_weight_torch.device,
+        )
+        self.pack_on_row_fast_248bit(packed_torch, quant_weight_torch, self.config.bits)
+        scales = scales_torch.cpu().numpy()
+        zero_points = zero_points_torch.cpu().numpy()
+        b_quant = onnx.numpy_helper.from_array(packed_torch.cpu().numpy())
+        b_quant.name = b_pb.name + "_Q4"
+        for input in bs_graph.input:
+            if input.name == inputB:
+                bs_graph.input.remove(input)
+                break
+
+        scales_tensor = onnx.numpy_helper.from_array(scales)
+        scales_tensor.name = b_pb.name + "_scales"
+        bs_graph.initializer.extend([b_quant, scales_tensor])
+
+        input_names = [node.input[0], b_quant.name, scales_tensor.name]
+        zp_tensor = onnx.numpy_helper.from_array(zero_points)
+        zp_tensor.name = b_pb.name + "_zero_points"
+        bs_graph.initializer.extend([zp_tensor])
+        input_names.append(zp_tensor.name)
+
+        kwargs = {}
+        rows, cols = b_array.shape
+        kwargs["K"] = rows
+        kwargs["N"] = cols
+        kwargs["bits"] = self.config.bits
+        kwargs["block_size"] = self.config.block_size
+
+        matmul_q4_node = onnx.helper.make_node(
+            "MatMulNBits",
+            inputs=input_names,
+            outputs=[node.output[0]],
+            name=node.name + "_Q4" if node.name else "",
+            domain="com.microsoft",
+            **kwargs,
+        )
+
+        logger.info(f"complete quantization of {node.name} ...")
+
+        return matmul_q4_node
+
+
+def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto | None, GraphProto | None]:
+    for gid in range(len(graph_path) - 1, -1, -1):
+        graph = graph_path[gid]
+        for tensor in graph.initializer:
+            if tensor.name == name:
+                return tensor, graph
+    return None, None
+
+
+class DefaultWeightOnlyQuantizer:
+    def __init__(self, config: DefaultWeightOnlyQuantConfig):
+        self.config = config
 
     def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray:
         """4b quantize fp32 weight to a blob"""
@@ -137,7 +380,7 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray:
             raise ValueError("Current int4 block quantization only supports 2D tensors!")
         rows, cols = fp32weight.shape
 
-        block_size = self.block_size
+        block_size = self.config.block_size
         blob_size = block_size // 2
         k_blocks = (rows + block_size - 1) // block_size
         padded_rows = k_blocks * block_size
@@ -149,23 +392,19 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray:
         packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8")
         scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype)
         zero_point = np.zeros(cols * ((k_blocks + 1) // 2), dtype="uint8")
-        quantize_matmul_4bits(packed, fp32weight, scales, zero_point, block_size, cols, rows, self.is_symmetric)
+        quantize_matmul_4bits(packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric)
 
         return (packed, scales, zero_point)
 
-    def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto:
+    def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto:
         """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node"""
 
         if node.op_type != "MatMul":
             return node  # only care about MatMul for now
 
         logger.info(f"start to quantize {node.name} ...")
-        if node.name in self.nodes_to_exclude:
-            logger.info(f"exclude to quantize {node.name} as specified by nodes_to_exclude...")
-            return node
-
         inputB = node.input[1]  # noqa: N806
-        B, Bs_graph = MatMul4BitsQuantizer.__get_initializer(inputB, graph_stack)  # noqa: N806
+        B, Bs_graph = get_initializer(inputB, graph_stack)  # noqa: N806
         if B is None:
             logger.info("MatMul doesn't have const weight. Skip to quantize")
             return node  # only care about constant weight
@@ -188,7 +427,7 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto])
         Bs_graph.initializer.extend([B_quant, scales_tensor])
 
         input_names = [node.input[0], B_quant.name, scales_tensor.name]
-        if not self.is_symmetric:
+        if not self.config.is_symmetric:
             zp_tensor = onnx.numpy_helper.from_array(zero_points)
             zp_tensor.name = B.name + "_zero_points"
             Bs_graph.initializer.extend([zp_tensor])
@@ -199,8 +438,8 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto])
         kwargs["K"] = rows
         kwargs["N"] = cols
         kwargs["bits"] = 4
-        kwargs["block_size"] = self.block_size
-        if self.accuracy_level is not None:
+        kwargs["block_size"] = self.config.block_size
+        if self.config.accuracy_level is not None:
             kwargs["accuracy_level"] = self.accuracy_level
 
         matmul_q4_node = onnx.helper.make_node(
@@ -216,6 +455,38 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto])
 
         return matmul_q4_node
 
+
+class MatMul4BitsQuantizer:
+    """Perform 4b quantization of constant MatMul weights"""
+
+    def __init__(
+        self,
+        model: ModelProto | str,
+        block_size: int = 128,
+        is_symmetric: bool = False,
+        accuracy_level: int | None = None,
+        nodes_to_exclude=None,
+        algo_config: WeightOnlyQuantConfig | None = None,
+    ):
+        if nodes_to_exclude is None:
+            nodes_to_exclude = []
+        self.model = ONNXModel(onnx.load(model)) if isinstance(model, str) else ONNXModel(model)
+        self.model_path = model if isinstance(model, str) else None
+        self.block_size = block_size
+        self.is_symmetric = is_symmetric
+        self.accuracy_level = accuracy_level
+        self.nodes_to_exclude = set(nodes_to_exclude)
+        self.node_quantizer = None
+        if algo_config is None:
+            algo_config = DefaultWeightOnlyQuantConfig(
+                block_size=block_size, is_symmetric=is_symmetric, accuracy_level=accuracy_level
+            )
+        self.algo_config = algo_config
+        if algo_config.algorithm == "HQQ":
+            self.node_quantizer = HQQWeightOnlyQuantizer(self.algo_config)
+        elif algo_config.algorithm == "DEFAULT":
+            self.node_quantizer = DefaultWeightOnlyQuantizer(self.algo_config)
+
     def _process_subgraph(self, graph_stack: list[GraphProto]):
         new_nodes = []
         graph = graph_stack[-1]
@@ -246,8 +517,15 @@ def _process_subgraph(self, graph_stack: list[GraphProto]):
                 node = onnx.helper.make_node(  # noqa: PLW2901
                     node.op_type, node.input, node.output, name=node.name, **kwargs
                 )
-
-            new_nodes.append(self._q4_matmul_node_weight(node, graph_stack))
+            out_node = None
+            if node.name in self.nodes_to_exclude:
+                logger.info(f"skipping quantization of {node.name} as specified by nodes_to_exclude...")
+                out_node = node
+            else:
+                out_node = self.node_quantizer.quantize(node, graph_stack)
+            new_nodes.append(out_node)
 
         graph.ClearField("node")
         graph.node.extend(new_nodes)
@@ -300,7 +578,7 @@ def inc_dataloader():
             from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize
 
             kwargs["percdamp"] = self.algo_config.percdamp
-            kwargs["blocksize"] = self.algo_config.blocksize
+            kwargs["blocksize"] = self.algo_config.block_size
             kwargs["actorder"] = self.algo_config.actorder
             kwargs["mse"] = self.algo_config.mse
             kwargs["perchannel"] = self.algo_config.perchannel
@@ -316,7 +594,7 @@ def inc_dataloader():
         logger.info(f"complete quantization of model with {algorithm} algorithm.")
 
     def process(self):
-        if self.algo_config is None:
+        if self.algo_config.algorithm in ["HQQ", "DEFAULT"]:
             # use a stack to keep track of sub-graphs
             graph_stack = [self.model.graph()]
             opset_import = self.model.opset_import()
@@ -327,7 +605,6 @@ def process(self):
                     has_ms_domain = True
             if not has_ms_domain:
                 opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)])
-
             self._process_subgraph(graph_stack)
             self.model.clean_initializers()
         else:
@@ -349,6 +626,10 @@ def process(self):
             self.int4_quant_algo()
 
 
+def ort_convert_str_to_bool(value):
+    return value.lower() in ("true", "1")
+
+
 def parse_args():
     parser = argparse.ArgumentParser(
         description="""Blockwise int4 quantization for MatMul 2D weight matrices.
@@ -362,11 +643,22 @@ def parse_args():
     parser.add_argument("--input_model", required=True, help="Path to the input model file")
     parser.add_argument("--output_model", required=True, help="Path to the output model file")
     parser.add_argument("--block_size", required=False, default=32, type=int, help="Block size for quantization")
+    parser.add_argument(
+        "--quant_method",
+        default="default",
+        type=str,
+        choices=["default", "hqq"],
+        help="the algorithm used to quantize weight",
+    )
+    parser.add_argument("--bits", default=4, type=int, help="the target bits to represent weight")
     parser.add_argument(
         "--symmetric",
         required=False,
         default=True,
-        type=bool,
+        const=True,
+        nargs="?",
+        type=ort_convert_str_to_bool,
+        choices=[True, False],
         help="Indicate whether to quantize the model symmetrically",
     )
     parser.add_argument(
@@ -404,12 +696,24 @@ def parse_args():
         raise Exception(f"file {output_model_path} already exists")
 
     model = onnx.load(input_model_path)
+    if args.quant_method == "hqq":
+        quant_config = HQQWeightOnlyQuantConfig(block_size=args.block_size, bits=args.bits)
+    elif args.quant_method == "default":
+        quant_config = DefaultWeightOnlyQuantConfig(
+            block_size=args.block_size, is_symmetric=args.symmetric, accuracy_level=args.accuracy_level
+        )
+    elif args.quant_method == "rtn":
+        quant_config = RTNWeightOnlyQuantConfig()
+    elif args.quant_method == "gptq":
+        quant_config = GPTQWeightOnlyQuantConfig(block_size=args.block_size)
+    else:
+        raise ValueError(f"Unsupported quantization method: {args.quant_method}")
+
     quant = MatMul4BitsQuantizer(
         model=model,
-        block_size=args.block_size,
-        is_symmetric=args.symmetric,
         accuracy_level=args.accuracy_level,
         nodes_to_exclude=args.nodes_to_exclude,
+        algo_config=quant_config,
     )
     quant.process()
     quant.model.save_model_to_file(output_model_path, True)
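
For reference, a minimal usage sketch of the quantizer API reworked above, mirroring what the CLI path in this file does. The import path onnxruntime.quantization.matmul_4bits_quantizer and the file names are illustrative assumptions; the class and argument names follow the code in this patch.

    import onnx
    from onnxruntime.quantization.matmul_4bits_quantizer import (
        DefaultWeightOnlyQuantConfig,
        HQQWeightOnlyQuantConfig,
        MatMul4BitsQuantizer,
    )

    model = onnx.load("model.onnx")  # hypothetical input model

    # Blockwise int4 config used by the "default" CLI path ...
    algo_config = DefaultWeightOnlyQuantConfig(block_size=32, is_symmetric=True, accuracy_level=None)
    # ... or the HQQ-based config introduced in this patch:
    # algo_config = HQQWeightOnlyQuantConfig(block_size=32, bits=4)

    quant = MatMul4BitsQuantizer(model=model, nodes_to_exclude=[], algo_config=algo_config)
    quant.process()
    quant.model.save_model_to_file("model_int4.onnx", True)
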
diff --git a/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py b/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py
index 951746a089305..2bf47fe1680e9 100644
--- a/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py
+++ b/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py
@@ -199,14 +199,14 @@ def parse_args():
         "--quant_type",
         required=False,
         default=1,
-        options=[MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4],
+        choices=[MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4],
         help="Quantization data type. 0: FP4, 1: NF4",
     )
     parser.add_argument(
         "--block_size",
         required=False,
         default=64,
-        description="Block size for blockwise quantization. Note: bnb.nn.Linear4bit only uses block_size=64",
+        help="Block size for blockwise quantization. Note: bnb.nn.Linear4bit only uses block_size=64",
     )
     parser.add_argument("-v", "--verbose", required=False, action="store_true")
     parser.set_defaults(verbose=False)
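
The two keyword fixes above matter because argparse's add_argument has no `options` or `description` parameters; value restriction is done with `choices` and per-argument text with `help`. A tiny standalone sketch of the corrected pattern, with illustrative values:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--quant_type",
        required=False,
        default=1,
        type=int,
        choices=[0, 1],  # 0: FP4, 1: NF4
        help="Quantization data type. 0: FP4, 1: NF4",
    )
    args = parser.parse_args(["--quant_type", "0"])
    print(args.quant_type)  # -> 0
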
diff --git a/onnxruntime/python/tools/quantization/onnx_model.py b/onnxruntime/python/tools/quantization/onnx_model.py
index 4591c9c950e6e..716dd1eacec6a 100644
--- a/onnxruntime/python/tools/quantization/onnx_model.py
+++ b/onnxruntime/python/tools/quantization/onnx_model.py
@@ -79,11 +79,7 @@ def _clean_initializers_helper(graph, model):
                 graph.input.remove(name_to_input[initializer.name])
             except StopIteration:
                 if model.ir_version < 4:
-                    print(
-                        "Warning: invalid weight name {} found in the graph (not a graph input)".format(
-                            initializer.name
-                        )
-                    )
+                    print(f"Warning: invalid weight name {initializer.name} found in the graph (not a graph input)")
 
     requesting_tensor_names.difference_update(input.name for input in graph.input)
 
@@ -283,6 +279,23 @@ def find_node_by_name(self, node_name, new_nodes_list, graph):
         node = find_by_name(node_name, graph_nodes_list)
         return node
 
+    def get_largest_node_name_suffix(self, node_name_prefix):
+        """
+        Gets the largest node name (int) suffix for all node names that begin with `node_name_prefix`.
+        Example: for nodes my_prefix_0 and my_prefix_3, this method returns 3.
+        """
+        suffix = -1
+
+        for node in self.model.graph.node:
+            if node.name and node.name.startswith(node_name_prefix):
+                try:
+                    index = int(node.name[len(node_name_prefix) :])
+                    suffix = max(index, suffix)
+                except ValueError:
+                    continue
+
+        return suffix
+
     def find_nodes_by_initializer(self, graph, initializer):
         """
         Find all nodes with given initializer as an input.
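
The new `get_largest_node_name_suffix` helper is useful for minting unique node names. A small sketch, assuming the wrapper is importable as onnxruntime.quantization.onnx_model and the model path is hypothetical:

    import onnx
    from onnxruntime.quantization.onnx_model import ONNXModel

    onnx_model = ONNXModel(onnx.load("model.onnx"))  # hypothetical model path
    prefix = "my_prefix_"
    # With existing nodes my_prefix_0 and my_prefix_3 this returns 3;
    # -1 means no node currently uses the prefix.
    next_index = onnx_model.get_largest_node_name_suffix(prefix) + 1
    unique_name = f"{prefix}{next_index}"
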
diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py
index a72d21c03a8a6..e2044db04303d 100644
--- a/onnxruntime/python/tools/quantization/onnx_quantizer.py
+++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py
@@ -4,9 +4,7 @@
 # license information.
 # --------------------------------------------------------------------------
 import logging
-from typing import Any, Dict
 
-import numpy as np
 import onnx
 import onnx.numpy_helper
 from onnx import onnx_pb as onnx_proto
@@ -17,57 +15,25 @@
     # old version of onnx.
     to_array_extended = None
 
-from .calibrate import TensorData
+from .base_quantizer import BaseQuantizer
 from .onnx_model import ONNXModel
 from .quant_utils import (
-    ONNX_TYPE_TO_NP_TYPE,
     TENSOR_NAME_QUANT_SUFFIX,
     QuantizationMode,
     QuantizedValue,
-    QuantizedValueType,
-    QuantType,
     __producer__,
     __version__,
     add_infer_metadata,
     attribute_to_kwarg,
-    compute_scale_zp,
-    compute_scale_zp_float8,
     find_by_name,
-    get_qmin_qmax_for_qType,
     get_qrange_for_qType,
-    model_has_infer_metadata,
     ms_domain,
-    quantize_data,
-    quantize_nparray,
     save_and_reload_model_with_shape_infer,
-    tensor_proto_to_array,
 )
 from .registry import CreateOpQuantizer
 
 
-class QuantizationParams:
-    def __init__(self, **data: Dict[str, Any]):
-        self.data = {}
-        for k, v in data.items():
-            if not isinstance(k, str):
-                raise TypeError(f"Keys must be strings not {type(k)} for k={k!r}.")
-            if not isinstance(v, (int, str, np.ndarray)):
-                raise TypeError(f"Values must be numpy arrays, int, float, str not {type(v)} for k={k!r}.")
-            if k == "scale" and v.dtype not in (np.float32, np.float16):
-                raise ValueError(f"scale must a float32 or float16 numpy element but is {v.dtype} for k={k!r}")
-            self.data[k] = v
-
-    def __iter__(self):
-        yield from self.data
-
-    def __getitem__(self, key):
-        return self.data[key]
-
-    def __len__(self):
-        return len(self.data)
-
-
-class ONNXQuantizer:
+class ONNXQuantizer(BaseQuantizer):
     def __init__(
         self,
         model,
@@ -83,13 +49,20 @@ def __init__(
         op_types_to_quantize,
         extra_options=None,
     ):
-        if not model_has_infer_metadata(model):
-            model = save_and_reload_model_with_shape_infer(model)
-        self.value_infos = {vi.name: vi for vi in model.graph.value_info}
-        self.value_infos.update({ot.name: ot for ot in model.graph.output})
-        self.value_infos.update({it.name: it for it in model.graph.input})
+        BaseQuantizer.__init__(
+            self,
+            model,
+            per_channel,
+            reduce_range,
+            weight_qType,
+            activation_qType,
+            tensors_range,
+            nodes_to_quantize,
+            nodes_to_exclude,
+            op_types_to_quantize,
+            extra_options,
+        )
 
-        self.model = ONNXModel(model)
         if not static:
             self.model.replace_gemm_with_matmul()
             # We need to update value_infos.
@@ -99,49 +72,12 @@ def __init__(
             self.value_infos.update({it.name: it for it in model.graph.input})
             self.model = ONNXModel(model)
 
-        self.per_channel = per_channel  # weight-pack per channel
-        self.reduce_range = reduce_range
         self.mode = mode  # QuantizationMode.Value
         self.static = static  # use static quantization for inputs.
-        self.fuse_dynamic_quant = False
+        self.fuse_dynamic_quant = self.opset_version > 10
 
-        self.extra_options = extra_options if extra_options else {}
-        self.enable_subgraph_quantization = (
-            "EnableSubgraph" in self.extra_options and self.extra_options["EnableSubgraph"]
-        )
-        self.force_quantize_no_input_check = (
-            "ForceQuantizeNoInputCheck" in self.extra_options and self.extra_options["ForceQuantizeNoInputCheck"]
-        )
         self.q_matmul_const_b_only = "MatMulConstBOnly" in self.extra_options and self.extra_options["MatMulConstBOnly"]
-        self.is_weight_symmetric = self.extra_options.get(
-            "WeightSymmetric", weight_qType in (QuantType.QInt8, QuantType.QInt16, QuantType.QFLOAT8E4M3FN)
-        )
-        self.is_activation_symmetric = self.extra_options.get("ActivationSymmetric", False)
-        self.min_real_range = self.extra_options.get("MinimumRealRange")
-
-        self.activation_qType = getattr(activation_qType, "tensor_type", activation_qType)
-        self.weight_qType = getattr(weight_qType, "tensor_type", weight_qType)
-        """
-            Dictionary specifying the min and max values for tensors. It has following format:
-                {
-                    "param_name": [min, max]
-                }
-            example:
-                {
-                    'Conv_3:0': [np.float32(0), np.float32(0.5)],
-                    'Conv_4:0': [np.float32(1), np.float32(3.5)]
-                }
-        """
-        if tensors_range is not None and any(map(lambda t: not isinstance(t, TensorData), tensors_range.values())):
-            raise TypeError(
-                f"tensors_range contains unexpected types {set(type(v) for v in tensors_range.values())}, not TensorData."
-            )
-        self.tensors_range = tensors_range
-        self.nodes_to_quantize = nodes_to_quantize  # specific nodes to quantize
-        self.nodes_to_exclude = nodes_to_exclude  # specific nodes to exclude
-        self.op_types_to_quantize = op_types_to_quantize
         self.new_nodes = []
-        self.parent = None
         self.graph_scope = "/"  # for human readable debug information
         self.tensor_names = {}  # in case the shape inference not totally working
         self.tensor_names.update({ot.name: 1 for ot in model.graph.output})
@@ -149,14 +85,9 @@ def __init__(
         for node in self.model.model.graph.node:
             self.tensor_names.update({output_name: 1 for output_name in node.output})
 
-        self.opset_version = self.check_opset_version()
-
         if self.mode not in QuantizationMode:
             raise ValueError(f"unsupported quantization mode {self.mode}")
 
-        self.tensor_quant_overrides = self._get_and_check_tensor_quant_overrides()
-        self.quantization_params = self.calculate_quantization_params()
-
         # QuantizeRange tensor name and zero tensor name for scale and zero point calculation.
         # Used when static is False
         self.fixed_qrange_uint8_name = "fixed_quantization_range_uint8"
@@ -166,94 +97,9 @@ def __init__(
         # For int8 data-type, zero point is always zero (respresented by fixed_zero_point_name tensor)
         self.fixed_zero_zp_name = "fixed_zero_zp"
 
-        # Map of all original value names to quantized value names
-        self.quantized_value_map = {}
         # some output from nodes will be quantized, yet itself should be treat as existing so
         # no dequantized will be applied when needed later
         self.generated_value_names = self.model.get_non_initializer_inputs()
-        # to store specified scale and zeropoint instead of calculated value, tensor_name->(scale, zeropoint)
-        self.used_scale_zp_map = {}
-
-    def _get_and_check_tensor_quant_overrides(self):
-        """
-        Get tensor quantization overrides and check correctness.
-        """
-        tensor_quant_overrides = self.extra_options.get("TensorQuantOverrides", {})
-
-        # Validate that compatible/valid overrides are provided.
-        if tensor_quant_overrides:
-            initializer_names = self.model.get_initializer_name_set()
-            value_info_names = set(self.value_infos.keys())
-            keys_unsupported_with_scale_zp = {"symmetric", "reduce_range", "rmax", "rmin"}
-
-            for tensor_name, quant_overrides_list in tensor_quant_overrides.items():
-                if tensor_name not in initializer_names and tensor_name not in value_info_names:
-                    raise ValueError(f"Tensor '{tensor_name}' in TensorQuantOverrides is not present in the model")
-
-                if not isinstance(quant_overrides_list, list):
-                    raise ValueError(f"Tensor quantization overrides for '{tensor_name}' are not in a list")
-
-                is_initializer = tensor_name in initializer_names
-                if not is_initializer and len(quant_overrides_list) > 1:
-                    raise ValueError(
-                        f"Tensor '{tensor_name}' has a list of per-channel overrides, but is not an initializer"
-                    )
-
-                quant_type = None
-                for index, quant_overrides in enumerate(quant_overrides_list):
-                    if not isinstance(quant_overrides, dict):
-                        raise ValueError(
-                            f"Tensor quantization overrides at index {index} for '{tensor_name}' are not in a dict"
-                        )
-
-                    # For per-channel quantization, all channels must use the same quantization type.
-                    # Therefore, if the user tries to override the quant_type for a channel, it must match in all
-                    # other channels.
-                    if index == 0:
-                        quant_type = quant_overrides.get("quant_type")
-                    elif quant_type != quant_overrides.get("quant_type"):
-                        raise ValueError(
-                            "Channel quantization types for tensor '{tensor_name}' do not match at index {index}."
-                        )
-
-                    has_scale = "scale" in quant_overrides
-                    has_zero_point = "zero_point" in quant_overrides
-
-                    if (has_scale and not has_zero_point) or (has_zero_point and not has_scale):
-                        raise ValueError(
-                            "Must provide both 'scale' and 'zero_point' if one of the overrides is provided"
-                        )
-
-                    if has_scale:
-                        for key in keys_unsupported_with_scale_zp:
-                            if key in quant_overrides:
-                                raise ValueError(
-                                    f"Tensor override option '{key}' is invalid with 'scale' and 'zero_point'"
-                                )
-
-        return tensor_quant_overrides
-
-    def get_per_tensor_quant_overrides(self, tensor_name):
-        quant_overrides_list = self.tensor_quant_overrides.get(tensor_name, [{}])
-        num_overrides = len(quant_overrides_list)
-        if num_overrides > 1:
-            raise ValueError(
-                f"Expected tensor '{tensor_name}' to use per-tensor quantization overrides, "
-                f"but found {num_overrides} per-channel overrides."
-            )
-
-        return quant_overrides_list[0] if num_overrides > 0 else {}
-
-    def get_per_channel_quant_overrides(self, tensor_name, num_channels):
-        quant_overrides_list = self.tensor_quant_overrides.get(tensor_name, [{} for i in range(num_channels)])
-
-        if len(quant_overrides_list) != num_channels:
-            raise ValueError(
-                f"Expected tensor '{tensor_name}' to have {num_channels} per-channel quantization overrides, "
-                f"but found {len(quant_overrides_list)} instead."
-            )
-
-        return quant_overrides_list
 
     # routines for subgraph support
     def quantize_subgraph(self, subgraph, graph_key):
@@ -321,46 +167,6 @@ def quantize_node_with_sub_graph(self, node):
             kwargs.update(kv)
         return onnx.helper.make_node(node.op_type, node.input, node.output, name=node.name, **kwargs)
 
-    def check_opset_version(self):
-        ai_onnx_domain = [
-            opset for opset in self.model.model.opset_import if not opset.domain or opset.domain == "ai.onnx"
-        ]
-        if len(ai_onnx_domain) != 1:
-            raise ValueError("Failed to find proper ai.onnx domain")
-        opset_version = ai_onnx_domain[0].version
-
-        if opset_version == 10:
-            logging.warning(
-                "The original model opset version is {}, which does not support node fusions. Please update the model to opset >= 11 for better performance.".format(
-                    opset_version
-                )
-            )
-            return 10
-
-        if opset_version < 10:
-            logging.warning(
-                "The original model opset version is {}, which does not support quantization. Please update the model to opset >= 11. Updating the model automatically to opset 11. Please verify the quantized model.".format(
-                    opset_version
-                )
-            )
-            self.model.model.opset_import.remove(ai_onnx_domain[0])
-            self.model.model.opset_import.extend([onnx.helper.make_opsetid("", 11)])
-            opset_version = 11
-
-        if opset_version < 19 and self.weight_qType == onnx_proto.TensorProto.FLOAT8E4M3FN:
-            logging.warning(
-                "The original model opset version is {}, which does not support quantization to float 8. "
-                "Please update the model to opset >= 19. Updating the model automatically to opset 19. "
-                "Please verify the quantized model.".format(opset_version)
-            )
-            self.model.model.opset_import.remove(ai_onnx_domain[0])
-            self.model.model.opset_import.extend([onnx.helper.make_opsetid("", 19)])
-            self.model.model.ir_version = 9
-            opset_version = 19
-
-        self.fuse_dynamic_quant = True
-        return opset_version
-
     def has_QDQ_nodes(self):  # noqa: N802
         """
         Detect if model already has QuantizeLinear or DequantizeLinear.
@@ -385,7 +191,7 @@ def add_new_nodes(self, nodes):
     def quantize_model(self):
         if self.has_QDQ_nodes():
             logging.warning(
-                "Please check if the model is already quantized."
+                "Please check if the model is already quantized. "
                 "Note you don't need to quantize a QAT model. OnnxRuntime support to run QAT model directly."
             )
 
@@ -427,20 +233,22 @@ def quantize_model(self):
 
         return self.model.model
 
-    def is_input_a_initializer(self, input_name):
-        initializer = find_by_name(input_name, self.model.initializer())
-        return initializer is not None
-
-    def is_per_channel(self):
-        return self.per_channel
-
-    def is_valid_quantize_weight(self, weight_name):
-        weight = find_by_name(weight_name, self.model.initializer())
-        if weight is not None:
-            return weight.data_type in (onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16)
-        if (not self.enable_subgraph_quantization) or (self.parent is None):
-            return False
-        return self.parent.is_valid_quantize_weight(weight_name)
+    def _get_default_tensor_type(self, tensor_name):
+        if "DefaultTensorType" in self.extra_options:
+            logging.info(
+                "get_tensor_type returns DefaultTensorType for tensor name %r, use %d",
+                tensor_name,
+                self.extra_options["DefaultTensorType"],
+            )
+            return self.extra_options["DefaultTensorType"]
+        raise RuntimeError(
+            f"Unable to find data type for weight_name={tensor_name!r}. "
+            f"shape_inference failed to return a type probably this node is "
+            f"from a different domain or using an input produced by such an operator. "
+            f"This may happen if you quantize a model already quantized. "
+            f"You may use extra_options `DefaultTensorType` to indicate "
+            f"the default weight type, usually `onnx.TensorProto.FLOAT`."
+        )
 
     def get_tensor_type(self, tensor_name, mandatory=False):
         weight = find_by_name(tensor_name, self.model.initializer())
@@ -450,11 +258,11 @@ def get_tensor_type(self, tensor_name, mandatory=False):
             vi = self.value_infos[tensor_name]
             if vi.type.HasField("tensor_type"):
                 if mandatory and vi.type.tensor_type.elem_type == 0:
-                    raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}")
+                    return self._get_default_tensor_type(tensor_name)
                 return vi.type.tensor_type.elem_type
         if (not self.enable_subgraph_quantization) or (self.parent is None):
             if mandatory:
-                raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}")
+                return self._get_default_tensor_type(tensor_name)
             return None
         otype = self.parent.is_valid_quantize_weight(tensor_name)
         if otype is not None:
@@ -464,7 +272,7 @@ def get_tensor_type(self, tensor_name, mandatory=False):
             if res is not None:
                 return res
         if mandatory:
-            raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}")
+            return self._get_default_tensor_type(tensor_name)
         return None
 
     def is_float_tensor(self, tensor_name):
@@ -492,22 +300,6 @@ def is_float_tensor(self, tensor_name):
         )
         return False
 
-    def should_quantize_node(self, node):
-        if (
-            self.nodes_to_quantize is not None
-            and len(self.nodes_to_quantize) != 0
-            and node.name not in self.nodes_to_quantize
-        ):
-            return False
-
-        if node.op_type not in self.op_types_to_quantize:
-            return False
-
-        if self.nodes_to_exclude is not None and node.name in self.nodes_to_exclude:
-            return False
-
-        return True
-
     def _get_dynamic_input_quantization_params(self, input_name, nodes_list, qType):
         """
         Create nodes for dynamic quantization of input and add them to nodes_list.
@@ -702,66 +494,6 @@ def _get_dynamic_input_quantization_params_uint8(self, input_name, nodes_list, i
 
         return input_scale_name, input_zp_name, [], []
 
-    def _get_quantization_params(self, param_name, use_scale=None, use_zeropoint=None):
-        """
-        Create initializers and inputs in the graph for zero point and scale of output.
-        Zero point and scale values are obtained from self.quantization_params if specified.
-            parameter param_name: Name of the quantization parameter.
-            return: result, scale_name, zero_point_name, scale_shape, zero_point_shape.
-        """
-        zero_point_type = self.activation_qType
-
-        if use_scale is None or use_zeropoint is None:
-            if self.quantization_params is None or param_name not in self.quantization_params:
-                logging.info(f'Quantization parameters for tensor:"{param_name}" not specified')
-                return False, "", "", "", ""
-
-            params = self.quantization_params[param_name]
-            if not isinstance(params, QuantizationParams):
-                raise TypeError(f"Unexpected type {type(params)} for {param_name!r}.")
-            if params is None or len(params) != 3:
-                raise ValueError(
-                    "Quantization parameters should contain zero point, scale, quant type. "
-                    f"Specified values for output {param_name}: {params}"
-                )
-
-            zero_point_values = np.array([params["zero_point"]])
-            if not hasattr(params["scale"], "dtype") or params["scale"].dtype not in (np.float32, np.float16):
-                raise ValueError(f"Unexpected type {type(params['scale'])} and param_name={param_name!r}")
-            scale_values = np.array([params["scale"]])
-            assert scale_values.dtype != np.float64
-            # zero_point_type = params["quant_type"]
-            assert zero_point_type == params["quant_type"]
-        else:
-            zero_point_values = np.array([use_zeropoint])
-            scale_values = np.array([use_scale])
-            params = self.quantization_params[param_name]
-            if "scale" in params:
-                dtype = params["scale"].dtype
-                scale_values = scale_values.astype(dtype)
-            assert scale_values.dtype != np.float64
-
-        zero_point_shape = []
-        zero_point_name = param_name + "_zero_point"
-        scale_shape = []
-        scale_name = param_name + "_scale"
-
-        # Add initializers
-        init_zp = onnx.helper.make_tensor(
-            zero_point_name, zero_point_type, zero_point_shape, zero_point_values.ravel().tolist()
-        )
-        self.model.add_initializer(init_zp)
-        if scale_values.dtype == np.float32:
-            scale_type = onnx_proto.TensorProto.FLOAT
-        elif scale_values.dtype == np.float16:
-            scale_type = onnx_proto.TensorProto.FLOAT16
-        else:
-            raise ValueError(f"Unexpected dtype={scale_values.dtype} for param_name={param_name!r}")
-        init_scale = onnx.helper.make_tensor(scale_name, scale_type, scale_shape, scale_values.reshape((-1,)).tolist())
-        self.model.add_initializer(init_scale)
-
-        return True, scale_name, zero_point_name, scale_shape, zero_point_shape
-
     def _get_quantize_input_nodes(self, node, input_index, qType, given_scale_name=None, given_zp_name=None):
         """
         Given an input for a node (which is not a initializer), this function
@@ -825,19 +557,6 @@ def _get_quantize_input_nodes(self, node, input_index, qType, given_scale_name=N
         self.quantized_value_map[input_name] = QuantizedValue(input_name, output_name, scale_name, zp_name, qType)
         return [*nodes, qlinear_node]
 
-    def set_quant_scale_zp(self, tensor_name, value):
-        assert isinstance(value, tuple) and len(value) == 2, "value must be scale(float or float16) and zeropoint"
-        assert hasattr(value[0], "dtype")
-        assert tensor_name not in self.used_scale_zp_map, f"{tensor_name} has been setted before"
-        self.used_scale_zp_map[tensor_name] = value
-
-    def find_quant_scale_zp(self, input_name):
-        if input_name in self.used_scale_zp_map:
-            return self.used_scale_zp_map[input_name]
-        if self.parent is not None:
-            return self.parent.find_quantized_value(input_name)
-        return (None, None)
-
     def find_quantized_value(self, input_name):
         if input_name in self.quantized_value_map:
             return self.quantized_value_map[input_name]
@@ -845,102 +564,6 @@ def find_quantized_value(self, input_name):
             return self.parent.find_quantized_value(input_name)
         return None
 
-    def quantize_bias_static(self, bias_name, input_name, weight_name, beta=1.0):
-        """
-        Quantized the bias. Zero Point == 0 and Scale == Input_Scale * Weight_Scale
-        """
-
-        # Handle case where bias already in quantization map
-        if bias_name in self.quantized_value_map:
-            return self.quantized_value_map[bias_name].q_name
-
-        # get scale for weight
-        weight_scale_name = self.quantized_value_map[weight_name].scale_name
-        weight_initializer = find_by_name(weight_scale_name, self.model.initializer())
-        weight_scale = tensor_proto_to_array(weight_initializer)
-
-        # get bias
-        bias_initializer = find_by_name(bias_name, self.model.initializer())
-        bias_data = tensor_proto_to_array(bias_initializer)
-        quantized_bias_name = bias_name + TENSOR_NAME_QUANT_SUFFIX
-
-        # get scale for input
-        if input_name in self.quantized_value_map:
-            input_scale_name = self.quantized_value_map[input_name].scale_name
-        elif input_name in self.quantization_params:
-            _, input_scale_name, _, _, _ = self._get_quantization_params(input_name)
-        else:
-            raise ValueError(f"Expected {input_name} to be in quantized value map for static quantization")
-
-        inputscale_initializer = find_by_name(input_scale_name, self.model.initializer())
-        input_scale = tensor_proto_to_array(inputscale_initializer)
-
-        # quantize bias
-        if self.weight_qType == onnx_proto.TensorProto.FLOAT8E4M3FN:
-            data = np.asarray(bias_data)
-            if data.dtype == np.float16:
-                node_qtype = onnx.TensorProto.FLOAT16
-            elif data.dtype == np.float32:
-                node_qtype = onnx.TensorProto.FLOAT
-            else:
-                raise TypeError(f"Only float16 or float32 are supported with float 8 but bias dtype is {data.dtype}.")
-            quantized_data = data.astype(np.float32)
-            bias_scale = np.array([1], dtype=quantized_data.dtype)
-            bias_scale_data = bias_scale.reshape(-1)
-            packed_bias_initializer = onnx.numpy_helper.from_array(quantized_data, quantized_bias_name)
-            self.model.initializer_extend([packed_bias_initializer])
-            node_type = "Cast"
-        else:
-            # calculate scale for bias
-            # TODO: This formula should be explained including why the scale is not estimated for the bias as well.
-            bias_scale = input_scale * weight_scale * beta
-
-            quantized_data = (np.asarray(bias_data) / bias_scale).round().astype(np.int32)
-
-            # update bias initializer
-            bias_np_data = np.asarray(quantized_data, dtype=np.int32).reshape(bias_initializer.dims)
-            packed_bias_initializer = onnx.numpy_helper.from_array(bias_np_data, quantized_bias_name)
-            self.model.initializer_extend([packed_bias_initializer])
-            bias_scale_data = np.asarray(bias_scale, dtype=np.float32).reshape(-1)
-            node_type = "DequantizeLinear"
-            node_qtype = self.weight_qType
-
-        # update scale initializer
-        quantized_bias_scale_name = quantized_bias_name + "_scale"
-        packed_bias_scale_initializer = onnx.numpy_helper.from_array(bias_scale_data, quantized_bias_scale_name)
-        self.model.initializer_extend([packed_bias_scale_initializer])
-
-        # update zero initializer
-        if self.weight_qType == onnx_proto.TensorProto.FLOAT8E4M3FN:
-            tensor_type = self.weight_qType
-        else:
-            tensor_type = onnx_proto.TensorProto.INT32
-
-        quantized_bias_zp_name = quantized_bias_name + "_zero_point"
-        if self.weight_qType == onnx_proto.TensorProto.FLOAT8E4M3FN:
-            packed_bias_zp_initializer = onnx.helper.make_tensor(quantized_bias_zp_name, self.weight_qType, [1], [0.0])
-        elif self.is_per_channel():
-            bias_zp_data = np.zeros(bias_scale.shape, dtype=np.int32).reshape(-1)
-            packed_bias_zp_initializer = onnx.numpy_helper.from_array(bias_zp_data, quantized_bias_zp_name)
-        else:
-            packed_bias_zp_initializer = onnx.helper.make_tensor(quantized_bias_zp_name, tensor_type, [], [0])
-        self.model.initializer_extend([packed_bias_zp_initializer])
-
-        assert bias_name not in self.quantized_value_map
-        quantized_value = QuantizedValue(
-            bias_name,
-            quantized_bias_name,
-            quantized_bias_scale_name,
-            quantized_bias_zp_name,
-            QuantizedValueType.Initializer,
-            0 if bias_scale_data.size > 1 else None,
-            node_type=node_type,
-            node_qtype=node_qtype,
-        )
-        self.quantized_value_map[bias_name] = quantized_value
-
-        return quantized_bias_name
-
     def contains_tensor(self, tensor_name):
         """
         only check for value info and newly generated tensor names, initializers are checked separately
@@ -1098,228 +721,6 @@ def __quantize_inputs(
 
         return quantized_input_names, zero_point_names, scale_names, nodes
 
-    def quantize_initializer(self, weight, qType, reduce_range=False, keep_float_weight=False):
-        """
-        :param weight: TensorProto initializer
-        :param qType: type to quantize to
-        :param keep_float_weight: Whether to quantize the weight. In some cases, we only want to qunatize scale and zero point.
-                                  If keep_float_weight is False, quantize the weight, or don't quantize the weight.
-        :return: quantized weight name, zero point name, scale name
-        """
-        # Find if this input is already quantized
-        if weight.name in self.quantized_value_map:
-            quantized_value = self.quantized_value_map[weight.name]
-            return (
-                quantized_value.q_name,
-                quantized_value.zp_name,
-                quantized_value.scale_name,
-            )
-
-        q_weight_name = weight.name + TENSOR_NAME_QUANT_SUFFIX
-        zp_name = weight.name + "_zero_point"
-        scale_name = weight.name + "_scale"
-
-        # Quantize weight data. Use quantization overrides if provided by the user.
-        weight_data = tensor_proto_to_array(weight)
-        quant_overrides = self.get_per_tensor_quant_overrides(weight.name)
-        if "quant_type" in quant_overrides:
-            qType = quant_overrides["quant_type"].tensor_type  # noqa: N806
-
-        if "scale" in quant_overrides and "zero_point" in quant_overrides:
-            zero_point = np.array(quant_overrides["zero_point"], dtype=ONNX_TYPE_TO_NP_TYPE[qType])
-            scale = np.array(quant_overrides["scale"])
-            q_weight_data = quantize_nparray(qType, weight_data.flatten(), scale, zero_point)
-            assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
-            assert (
-                zero_point.dtype != np.float32 and zero_point.dtype != np.float16
-            ), f"Unexpected dtype {zero_point.dtype}"
-            assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
-
-        else:
-            _, _, zero_point, scale, q_weight_data = quantize_data(
-                weight_data.flatten(),
-                qType,
-                quant_overrides.get("symmetric", self.is_weight_symmetric),
-                reduce_range=quant_overrides.get("reduce_range", self.reduce_range and reduce_range),
-                min_real_range=self.min_real_range,
-                rmin_override=quant_overrides.get("rmin"),
-                rmax_override=quant_overrides.get("rmax"),
-            )
-
-            assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
-            assert (
-                zero_point.dtype != np.float32 and zero_point.dtype != np.float16
-            ), f"Unexpected dtype {zero_point.dtype}"
-            assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
-
-        scale_dtype = weight.data_type
-        scale_initializer = onnx.helper.make_tensor(scale_name, scale_dtype, [], scale.reshape((-1,)).tolist())
-        zero_initializer = onnx.helper.make_tensor(zp_name, qType, [], zero_point.reshape((-1,)).tolist())
-        self.model.initializer_extend([scale_initializer, zero_initializer])
-
-        if not keep_float_weight:
-            if self.weight_qType == onnx_proto.TensorProto.FLOAT8E4M3FN:
-                q_weight_initializer = onnx.TensorProto()
-                q_weight_initializer.data_type = self.weight_qType
-                q_weight_initializer.dims.extend(weight.dims)
-                q_weight_initializer.name = q_weight_name
-                # Do not remove .flatten().copy() numpy is not clear about data persistence.
-                q_weight_initializer.raw_data = q_weight_data.flatten().copy().tobytes()
-                if to_array_extended is not None:
-                    # This test should not be needed but it helped catch some issues
-                    # with data persistence and tobytes.
-                    check = to_array_extended(q_weight_initializer)
-                    if check.shape != weight_data.shape or check.tobytes() != q_weight_data.tobytes():
-                        raise RuntimeError(
-                            f"The initializer of shape {weight_data.shape} could not be created, expecting "
-                            f"{q_weight_data.tobytes()[:10]}, got {check.tobytes()[:10]} and shape={weight.shape}"
-                            f"\nraw={str(q_weight_initializer)[:200]}."
-                        )
-            else:
-                q_weight_data = np.asarray(q_weight_data, dtype=onnx.helper.tensor_dtype_to_np_dtype(qType)).reshape(
-                    weight.dims
-                )
-                q_weight_initializer = onnx.numpy_helper.from_array(q_weight_data, q_weight_name)
-            self.model.initializer_extend([q_weight_initializer])
-
-        # Log entry for this quantized weight
-        quantized_value = QuantizedValue(
-            weight.name,
-            q_weight_name,
-            scale_name,
-            zp_name,
-            QuantizedValueType.Initializer,
-            None,
-        )
-        self.quantized_value_map[weight.name] = quantized_value
-        return q_weight_name, zp_name, scale_name
-
-    def quantize_weight_per_channel(
-        self,
-        weight_name,
-        weight_qType,
-        channel_axis,
-        reduce_range=True,
-        keep_float_weight=False,
-    ):
-        # Find if this input is already quantized
-        if weight_name in self.quantized_value_map:
-            quantized_value = self.quantized_value_map[weight_name]
-            return (
-                quantized_value.q_name,
-                quantized_value.zp_name,
-                quantized_value.scale_name,
-            )
-
-        initializer = find_by_name(weight_name, self.model.initializer())
-        if initializer is None:
-            raise ValueError("{} is not an initializer", weight_name)
-
-        weights = tensor_proto_to_array(initializer)
-        channel_count = weights.shape[channel_axis]
-        quant_overrides_for_channels = self.get_per_channel_quant_overrides(weight_name, channel_count)
-
-        # If user provides per-channel quantization overrides, all channels must use the same quantization type.
-        # So, just use the first channel's type.
-        if "quant_type" in quant_overrides_for_channels[0]:
-            weight_qType = quant_overrides_for_channels[0]["quant_type"].tensor_type  # noqa: N806
-
-        zero_point_list = []
-        scale_list = []
-        quantized_per_channel_data_list = []
-        for i in range(channel_count):
-            per_channel_data = weights.take(i, channel_axis)
-            channel_quant_overrides = quant_overrides_for_channels[i]
-
-            if "scale" in channel_quant_overrides and "zero_point" in channel_quant_overrides:
-                zero_point = np.array(channel_quant_overrides["zero_point"], dtype=ONNX_TYPE_TO_NP_TYPE[weight_qType])
-                scale = np.array(channel_quant_overrides["scale"])
-                quantized_per_channel_data = quantize_nparray(
-                    weight_qType, per_channel_data.flatten(), scale, zero_point
-                )
-                assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
-                assert (
-                    zero_point.dtype != np.float32 and zero_point.dtype != np.float16
-                ), f"Unexpected dtype {zero_point.dtype}"
-                assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
-                assert isinstance(
-                    quantized_per_channel_data, np.ndarray
-                ), f"Unexpected type {type(quantized_per_channel_data)}"
-
-            else:
-                symmetric = channel_quant_overrides.get(
-                    "symmetric",
-                    (
-                        self.is_weight_symmetric
-                        or weight_qType in (onnx_proto.TensorProto.INT8, onnx_proto.TensorProto.FLOAT8E4M3FN)
-                    ),
-                )
-                _, _, zero_point, scale, quantized_per_channel_data = quantize_data(
-                    per_channel_data.flatten(),
-                    weight_qType,
-                    symmetric,
-                    reduce_range=channel_quant_overrides.get("reduce_range", self.reduce_range and reduce_range),
-                    min_real_range=self.min_real_range,
-                    rmin_override=channel_quant_overrides.get("rmin"),
-                    rmax_override=channel_quant_overrides.get("rmax"),
-                )
-
-                assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
-                assert (
-                    zero_point.dtype != np.float32 and zero_point.dtype != np.float16
-                ), f"Unexpected dtype {zero_point.dtype}"
-                assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
-                assert isinstance(
-                    quantized_per_channel_data, np.ndarray
-                ), f"Unexpected type {type(quantized_per_channel_data)}"
-
-            zero_point_list.append(zero_point)
-            scale_list.append(scale)
-            quantized_per_channel_data_list.append(quantized_per_channel_data)
-
-        # combine per_channel_data into one
-        reshape_dims = list(weights.shape)  # deep copy
-        reshape_dims[channel_axis] = 1  # only one per channel for reshape
-        quantized_weights = np.asarray(quantized_per_channel_data_list[0]).reshape(reshape_dims)
-        for i in range(1, len(quantized_per_channel_data_list)):
-            channel_weights = np.asarray(quantized_per_channel_data_list[i]).reshape(reshape_dims)
-            quantized_weights = np.concatenate((quantized_weights, channel_weights), channel_axis)
-
-        q_weight_name = weight_name + TENSOR_NAME_QUANT_SUFFIX
-        zp_name = weight_name + "_zero_point"
-        scale_name = weight_name + "_scale"
-
-        quantized_value = QuantizedValue(
-            weight_name,
-            q_weight_name,
-            scale_name,
-            zp_name,
-            QuantizedValueType.Initializer,
-            None,
-        )
-        self.quantized_value_map[weight_name] = quantized_value
-
-        # Update packed weight, zero point, and scale initializers
-        zero_scale_shape = [initializer.dims[channel_axis]]
-        scale_initializer = onnx.helper.make_tensor(
-            scale_name, initializer.data_type, zero_scale_shape, np.hstack(scale_list).tolist()
-        )
-        zero_initializer = onnx.helper.make_tensor(
-            zp_name, weight_qType, zero_scale_shape, np.hstack(zero_point_list).tolist()
-        )
-
-        self.model.initializer_extend([scale_initializer, zero_initializer])
-
-        if not keep_float_weight:
-            quantized_weights = np.asarray(
-                quantized_weights,
-                dtype=onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[weight_qType],
-            ).reshape(initializer.dims)
-            q_weight_initializer = onnx.numpy_helper.from_array(quantized_weights, q_weight_name)
-            self.model.initializer_extend([q_weight_initializer])
-
-        return q_weight_name, zp_name, scale_name
-
     def _dequantize_value(self, value_name):
         """
         Given a value (input/output) which is quantized, add a DequantizeLinear node to dequantize
@@ -1332,9 +733,15 @@ def _dequantize_value(self, value_name):
         if (value_name in self.quantized_value_map) and (value_name not in self.generated_value_names):
             quantized_value = self.quantized_value_map[value_name]
             # Add DequantizeLinear Node for this input
+
             scale_init = find_by_name(quantized_value.scale_name, self.model.initializer())
-            # axis is not specified so scale_init must be a scalar.
-            assert onnx.numpy_helper.to_array(scale_init).size == 1
+
+            # When quantizing subgraphs, the graph `producer_name` is set to "onnx-quantizer"
+            # in the `quantize_subgraph` method. In that case the scale initializer may live
+            # on the top-level graph, so the scalar check below cannot be performed.
+            if self.model.model.producer_name != "onnx-quantizer" or scale_init is not None:
+                # axis is not specified so scale_init must be a scalar.
+                assert onnx.numpy_helper.to_array(scale_init).size == 1
 
             dqlinear_name = value_name + "_DequantizeLinear"
             dqlinear_node = self.model.find_node_by_name(dqlinear_name, self.new_nodes, self.model.graph())
@@ -1364,52 +771,3 @@ def _dequantize_outputs(self):
             dequantize_node = self._dequantize_value(output.name)
             if dequantize_node is not None:
                 self.new_nodes.append(dequantize_node)
-
-    def calculate_quantization_params(self):
-        if self.tensors_range is None:
-            return
-
-        # adjust tensor_ranges for input of Clip and Relu node
-        for node in self.model.nodes():
-            if node.op_type not in ["Clip", "Relu"]:
-                continue
-            if self.is_activation_symmetric:
-                continue
-            if not self.should_quantize_node(node):
-                continue
-            if len(self.model.input_name_to_nodes()[node.input[0]]) != 1:
-                continue
-            if node.input[0] not in self.tensors_range or node.output[0] not in self.tensors_range:
-                continue
-            td = self.tensors_range[node.output[0]]
-            if not isinstance(td, TensorData):
-                raise TypeError(f"Unexpected type {type(td)} for {node.output[0]!r}.")
-            self.tensors_range[node.input[0]] = td
-
-        quantization_params = {}
-        for tensor_name in self.tensors_range:
-            td = self.tensors_range[tensor_name]
-            if not isinstance(td, TensorData):
-                raise TypeError(f"Unexpected type {type(td)} for {tensor_name!r}.")
-
-            quant_overrides = self.get_per_tensor_quant_overrides(tensor_name)
-
-            quant_type = self.activation_qType
-            if "quant_type" in quant_overrides:
-                quant_type = quant_overrides["quant_type"].tensor_type
-
-            if "scale" in quant_overrides and "zero_point" in quant_overrides:
-                zero, scale = quant_overrides["zero_point"], quant_overrides["scale"]
-            elif quant_type == onnx.TensorProto.FLOAT8E4M3FN:
-                zero, scale = compute_scale_zp_float8(quant_type, td.avg_std[1])
-            else:
-                rmin = quant_overrides.get("rmin", td.range_value[0])
-                rmax = quant_overrides.get("rmax", td.range_value[1])
-                symmetric = quant_overrides.get("symmetric", self.is_activation_symmetric)
-                reduce_range = quant_overrides.get("reduce_range", False)
-                qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=reduce_range, symmetric=symmetric)
-                zero, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, self.min_real_range)
-
-            quantization_params[tensor_name] = QuantizationParams(zero_point=zero, scale=scale, quant_type=quant_type)
-
-        return quantization_params
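
The `DefaultTensorType` fallback added above is read from `extra_options`. A hedged sketch of passing it through the public static-quantization entry point, assuming the usual quantize_static signature, the QOperator format handled by ONNXQuantizer, and a user-supplied CalibrationDataReader named data_reader (not shown):

    import onnx
    from onnxruntime.quantization import QuantFormat, QuantType, quantize_static

    quantize_static(
        "model.onnx",        # hypothetical input path
        "model_quant.onnx",  # hypothetical output path
        calibration_data_reader=data_reader,
        quant_format=QuantFormat.QOperator,
        weight_type=QuantType.QInt8,
        # Used when shape inference cannot determine a tensor's element type,
        # e.g. when quantizing a model that is already (partially) quantized.
        extra_options={"DefaultTensorType": onnx.TensorProto.FLOAT},
    )
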
diff --git a/onnxruntime/python/tools/quantization/operators/concat.py b/onnxruntime/python/tools/quantization/operators/concat.py
index a4f359cf56847..57fcec9cd380b 100644
--- a/onnxruntime/python/tools/quantization/operators/concat.py
+++ b/onnxruntime/python/tools/quantization/operators/concat.py
@@ -30,7 +30,7 @@ def quantize(self):
             zero_point_names,
             scale_names,
             nodes,
-        ) = self.quantizer.quantize_activation(node, [*range(0, len(node.input))])
+        ) = self.quantizer.quantize_activation(node, [*range(len(node.input))])
         if not data_found or q_input_names is None:
             return super().quantize()
 
@@ -52,7 +52,7 @@ def quantize(self):
         qnode_name = node.name + "_quant" if node.name else ""
 
         qlconcat_inputs = [output_scale_name, output_zp_name]
-        for i in range(0, len(q_input_names)):
+        for i in range(len(q_input_names)):
             qlconcat_inputs.extend([q_input_names[i], scale_names[i], zero_point_names[i]])
         qlconcat_node = onnx.helper.make_node(
             "QLinearConcat", qlconcat_inputs, [quantized_output_value.q_name], qnode_name, **kwargs
diff --git a/onnxruntime/python/tools/quantization/operators/gemm.py b/onnxruntime/python/tools/quantization/operators/gemm.py
index 32fdb729635a8..d269c8fb47bd1 100644
--- a/onnxruntime/python/tools/quantization/operators/gemm.py
+++ b/onnxruntime/python/tools/quantization/operators/gemm.py
@@ -157,7 +157,5 @@ def quantize(self):
                 set_default_beta(self.node)
             else:
                 logging.warning(
-                    "Bias of Gemm node '{}' is not constant. Please exclude this node for better performance.".format(
-                        self.node.name
-                    )
+                    f"Bias of Gemm node '{self.node.name}' is not constant. Please exclude this node for better performance."
                 )
diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py
index 775a3e8b8b588..1875c552fab9c 100644
--- a/onnxruntime/python/tools/quantization/qdq_quantizer.py
+++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py
@@ -11,7 +11,7 @@
 from onnx import TensorProto
 from onnx import onnx_pb as onnx_proto
 
-from .onnx_quantizer import ONNXQuantizer
+from .base_quantizer import BaseQuantizer
 from .quant_utils import (
     DEQUANT_OP_NAME,
     QUANT_OP_NAME,
@@ -46,14 +46,12 @@ def __init__(self, tensor_type=QDQQuantTensorType.ACTIVATION, quant_para_provide
         self.data_type = data_type
 
 
-class QDQQuantizer(ONNXQuantizer):
+class QDQQuantizer(BaseQuantizer):
     def __init__(
         self,
         model,
         per_channel,
         reduce_range,
-        mode,
-        static,
         weight_qType,
         activation_qType,
         tensors_range,
@@ -62,13 +60,11 @@ def __init__(
         op_types_to_quantize,
         extra_options=None,
     ):
-        ONNXQuantizer.__init__(
+        BaseQuantizer.__init__(
             self,
             model,
             per_channel,
             reduce_range,
-            mode,
-            static,
             weight_qType,
             activation_qType,
             tensors_range,
@@ -116,7 +112,10 @@ def __init__(
         # if the activation or weight types are 16-bit integers.
         # TODO: Remove this override (and use only the 'UseQDQContribOps' option) if/when ONNX adds 16-bit support.
         int16_types = (TensorProto.UINT16, TensorProto.INT16)
-        if not self.qdq_op_domain and (self.activation_qType in int16_types or self.weight_qType in int16_types):
+        overrides_have_int16 = any(t in int16_types for t in self.tensor_quant_override_types)
+        if not self.qdq_op_domain and (
+            self.activation_qType in int16_types or self.weight_qType in int16_types or overrides_have_int16
+        ):
             logging.warning(
                 "ONNX QuantizeLinear and DequantizeLinear operators do not support 16-bit integer quantization types. "
                 f"The domain of QuantizeLinear and DequantizeLinear operators will be set to '{ms_domain}' to "
@@ -154,9 +153,7 @@ def _is_tensor_quantizable(self, tensor_name):
                 return True
         else:
             logging.warning(
-                "failed to infer the type of tensor: {}. Skip to quantize it. Please check if it is expected.".format(
-                    tensor_name
-                )
+                f"failed to infer the type of tensor: {tensor_name}. Skip to quantize it. Please check if it is expected."
             )
 
         return False
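
# Illustrative sketch (not part of the patch, hypothetical values): the hunk above widens the
# com.microsoft domain fallback so it also triggers when tensor-level overrides request 16-bit
# types, even if the default activation/weight types are 8-bit.
from onnx import TensorProto

int16_types = (TensorProto.UINT16, TensorProto.INT16)
activation_qType, weight_qType = TensorProto.UINT8, TensorProto.INT8
tensor_quant_override_types = {TensorProto.UINT16}  # e.g., one tensor overridden to uint16

overrides_have_int16 = any(t in int16_types for t in tensor_quant_override_types)
use_ms_domain = activation_qType in int16_types or weight_qType in int16_types or overrides_have_int16
print(use_ms_domain)  # True: Q/DQ ops would get the 'com.microsoft' domain
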
diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py
index 036f49b420734..131e55458fb86 100644
--- a/onnxruntime/python/tools/quantization/quant_utils.py
+++ b/onnxruntime/python/tools/quantization/quant_utils.py
@@ -276,7 +276,7 @@ def compute_scale_zp_float8(element_type, std):
             from onnx.reference.custom_element_types import float8e4m3fn
 
             zp_dtype = float8e4m3fn
-            all_values = [float8e4m3_to_float32(i) for i in range(0, 256)]
+            all_values = [float8e4m3_to_float32(i) for i in range(256)]
             values = numpy.array(
                 [f for f in all_values if not numpy.isnan(f) and not numpy.isinf(f)], dtype=numpy.float32
             )
@@ -530,7 +530,7 @@ def get_elem_index(elem_name, elem_list):
     Helper function to return index of an item in a node list
     """
     elem_idx = -1
-    for i in range(0, len(elem_list)):
+    for i in range(len(elem_list)):
         if elem_list[i] == elem_name:
             elem_idx = i
     return elem_idx
diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py
index 1bd2ef42151d0..9b0c15e4b4dde 100644
--- a/onnxruntime/python/tools/quantization/quantize.py
+++ b/onnxruntime/python/tools/quantization/quantize.py
@@ -479,7 +479,7 @@ def inc_dataloader():
         del dataloader
         model = sq.transform(extra_options.get("SmoothQuantAlpha", 0.5), extra_options.get("SmoothQuantFolding", True))
         sq_path = tempfile.TemporaryDirectory(prefix="ort.quant.")
-        model_input = Path(sq_path).joinpath("sq_model.onnx").as_posix()
+        model_input = Path(sq_path.name).joinpath("sq_model.onnx").as_posix()
         model.save(model_input)
         nodes_to_exclude.extend([i.name for i in model.model.graph.node if i.name not in orig_nodes])
         model = load_model_with_shape_infer(Path(model_input))  # use smooth quant model for calibration
@@ -523,8 +523,6 @@ def inc_dataloader():
             model,
             per_channel,
             reduce_range,
-            mode,
-            True,  # static
             weight_type,
             activation_type,
             tensors_range,
diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py
index 4b56bc1e8d828..8a911071864aa 100755
--- a/onnxruntime/python/tools/symbolic_shape_infer.py
+++ b/onnxruntime/python/tools/symbolic_shape_infer.py
@@ -282,7 +282,7 @@ def _add_suggested_merge(self, symbols, apply=False):
         # when nothing to map to, use the shorter one
         if map_to is None:
             if self.verbose_ > 0:
-                logger.warning("Potential unsafe merge between symbolic expressions: ({})".format(",".join(symbols)))
+                logger.warning("Potential unsafe merge between symbolic expressions: (%s)", ",".join(symbols))
             symbols_list = list(symbols)
             lens = [len(s) for s in symbols_list]
             map_to = symbols_list[lens.index(min(lens))]
@@ -335,10 +335,7 @@ def _merge_symbols(self, dims):
                     int_dim = is_int.index(1)
                     if self.verbose_ > 0:
                         logger.debug(
-                            "dim {} has been merged with value {}".format(
-                                unique_dims[:int_dim] + unique_dims[int_dim + 1 :],
-                                unique_dims[int_dim],
-                            )
+                            f"dim {unique_dims[:int_dim] + unique_dims[int_dim + 1 :]} has been merged with value {unique_dims[int_dim]}"
                         )
                     self._check_merged_dims(unique_dims, allow_broadcast=False)
                     return unique_dims[int_dim]
@@ -379,7 +376,7 @@ def _broadcast_shapes(self, shape1, shape2):
                     if self.auto_merge_:
                         self._add_suggested_merge([dim1, dim2], apply=True)
                     else:
-                        logger.warning("unsupported broadcast between " + str(dim1) + " " + str(dim2))
+                        logger.warning("unsupported broadcast between " + str(dim1) + " " + str(dim2))  # noqa: G003
             new_shape = [new_dim, *new_shape]
         return new_shape
 
@@ -663,12 +660,7 @@ def _new_symbolic_dim(self, prefix, dim):
 
     def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0):
         return self._new_symbolic_dim(
-            "{}{}_{}_o{}_".format(
-                node.op_type,
-                self.prefix_,
-                list(self.out_mp_.graph.node).index(node),
-                out_idx,
-            ),
+            f"{node.op_type}{self.prefix_}_{list(self.out_mp_.graph.node).index(node)}_o{out_idx}_",
             dim,
         )
 
@@ -1216,9 +1208,7 @@ def _infer_Loop(self, node):  # noqa: N802
         if need_second_infer:
             if self.verbose_ > 2:
                 logger.debug(
-                    "Rerun Loop: {}({}...), because of sequence in loop carried variables".format(
-                        node.name, node.output[0]
-                    )
+                    f"Rerun Loop: {node.name}({node.output[0]}...), because of sequence in loop carried variables"
                 )
             self._onnx_infer_subgraph(node, subgraph, inc_subgraph_id=False)
 
@@ -1843,7 +1833,7 @@ def handle_negative_index(index, bound):
             axes = self._try_get_value(node, 3)
             steps = self._try_get_value(node, 4)
             if axes is None and not (starts is None and ends is None):
-                axes = list(range(0, len(starts if starts is not None else ends)))
+                axes = list(range(len(starts if starts is not None else ends)))
             if steps is None and not (starts is None and ends is None):
                 steps = [1] * len(starts if starts is not None else ends)
             axes = as_list(axes, keep_none=True)
@@ -1940,8 +1930,17 @@ def _infer_SoftmaxCrossEntropyLoss(self, node):  # noqa: N802
     def _infer_Split_Common(self, node, make_value_info_func):  # noqa: N802
         input_sympy_shape = self._get_sympy_shape(node, 0)
         axis = handle_negative_axis(get_attribute(node, "axis", 0), len(input_sympy_shape))
-        split = get_attribute(node, "split")
-        if not split:
+        op_set = get_opset(self.out_mp_)
+
+        # Depending on the opset version, 'split' is provided either as an attribute or via the 2nd input
+        if op_set < 13:
+            split = get_attribute(node, "split")
+            assert self._try_get_value(node, 1) is None
+        else:
+            split = self._try_get_value(node, 1)
+            assert get_attribute(node, "split") is None
+
+        if split is None:
             num_outputs = len(node.output)
             split = [input_sympy_shape[axis] / sympy.Integer(num_outputs)] * num_outputs
             self._update_computed_dims(split)
@@ -2660,11 +2659,9 @@ def get_prereq(node):
                         break
 
             if self.verbose_ > 2:
-                logger.debug(node.op_type + ": " + node.name)
+                logger.debug(node.op_type + ": " + node.name)  # noqa: G003
                 for i, name in enumerate(node.input):
-                    logger.debug(
-                        "  Input {}: {} {}".format(i, name, "initializer" if name in self.initializers_ else "")
-                    )
+                    logger.debug("  Input %s: %s %s", i, name, "initializer" if name in self.initializers_ else "")
 
             # onnx automatically merge dims with value, i.e. Mul(['aaa', 'bbb'], [1000, 1]) -> [1000, 'bbb']
             # symbolic shape inference needs to apply merge of 'aaa' -> 1000 in this case
@@ -2713,7 +2710,7 @@ def get_prereq(node):
                             seq_cls_type = out_type.sequence_type.elem_type.WhichOneof("value")
                             if seq_cls_type == "tensor_type":
                                 logger.debug(
-                                    "  {}: sequence of {} {}".format(
+                                    "  {}: sequence of {} {}".format(  # noqa: G001
                                         node.output[i_o],
                                         str(get_shape_from_value_info(vi)),
                                         onnx.TensorProto.DataType.Name(
@@ -2731,14 +2728,10 @@ def get_prereq(node):
                 out_type_undefined = out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED
                 if self.verbose_ > 2:
                     logger.debug(
-                        "  {}: {} {}".format(
-                            node.output[i_o],
-                            str(out_shape),
-                            onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type),
-                        )
+                        f"  {node.output[i_o]}: {out_shape!s} {onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type)}"
                     )
                     if node.output[i_o] in self.sympy_data_:
-                        logger.debug("  Sympy Data: " + str(self.sympy_data_[node.output[i_o]]))
+                        logger.debug("  Sympy Data: " + str(self.sympy_data_[node.output[i_o]]))  # noqa: G003
 
                 # onnx >= 1.11.0, use unk__#index instead of None when the shape dim is uncertain
                 if (
@@ -2839,24 +2832,16 @@ def get_prereq(node):
                             if self.verbose_ > 0:
                                 if is_unknown_op:
                                     logger.debug(
-                                        "Possible unknown op: {} node: {}, guessing {} shape".format(
-                                            node.op_type, node.name, vi.name
-                                        )
+                                        f"Possible unknown op: {node.op_type} node: {node.name}, guessing {vi.name} shape"
                                     )
                                 if self.verbose_ > 2:
-                                    logger.debug(
-                                        "  {}: {} {}".format(
-                                            node.output[i_o],
-                                            str(new_shape),
-                                            vi.type.tensor_type.elem_type,
-                                        )
-                                    )
+                                    logger.debug(f"  {node.output[i_o]}: {new_shape!s} {vi.type.tensor_type.elem_type}")
 
                             self.run_ = True
                             continue  # continue the inference after guess, no need to stop as no merge is needed
 
                     if self.verbose_ > 0 or not self.auto_merge_ or out_type_undefined:
-                        logger.debug("Stopping at incomplete shape inference at " + node.op_type + ": " + node.name)
+                        logger.debug("Stopping at incomplete shape inference at %s: %s", node.op_type, node.name)
                         logger.debug("node inputs:")
                         for i in node.input:
                             if i in self.known_vi_:
@@ -2870,7 +2855,7 @@ def get_prereq(node):
                             else:
                                 logger.debug(f"not in known_vi_ for {o}")
                         if self.auto_merge_ and not out_type_undefined:
-                            logger.debug("Merging: " + str(self.suggested_merge_))
+                            logger.debug("Merging: " + str(self.suggested_merge_))  # noqa: G003
                     return False
 
         self.run_ = False
@@ -2955,9 +2940,9 @@ def parse_arguments():
 
 if __name__ == "__main__":
     args = parse_arguments()
-    logger.info("input model: " + args.input)
+    logger.info("input model: " + args.input)  # noqa: G003
     if args.output:
-        logger.info("output model " + args.output)
+        logger.info("output model " + args.output)  # noqa: G003
     logger.info("Doing symbolic shape inference...")
     out_mp = SymbolicShapeInference.infer_shapes(
         onnx.load(args.input),
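
# Illustrative sketch (not part of the patch): the _infer_Split_Common change above follows the
# ONNX spec, where split sizes are an attribute before opset 13 and an optional 2nd input from
# opset 13 onward. Minimal nodes for both forms (tensor names are placeholders):
from onnx import helper

# opset < 13: 'split' is an attribute
split_v11 = helper.make_node("Split", inputs=["X"], outputs=["A", "B"], axis=0, split=[2, 3])

# opset >= 13: 'split' is an optional second input (an int64 tensor, here called "split_sizes")
split_v13 = helper.make_node("Split", inputs=["X", "split_sizes"], outputs=["A", "B"], axis=0)
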
diff --git a/onnxruntime/python/tools/tensorrt/perf/benchmark.py b/onnxruntime/python/tools/tensorrt/perf/benchmark.py
index 20bb8a71dc35f..8af074f24acc9 100644
--- a/onnxruntime/python/tools/tensorrt/perf/benchmark.py
+++ b/onnxruntime/python/tools/tensorrt/perf/benchmark.py
@@ -790,7 +790,7 @@ def skip_ep(model_name, ep, model_to_fail_ep):
 
     # if ep in fail_ep_list and fail_ep_list[ep] == "runtime error":
     if ep in fail_ep_list:
-        logger.info("Skip testing " + model_name + " using " + ep + " since it has some issues.")
+        logger.info("Skip testing " + model_name + " using " + ep + " since it has some issues.")  # noqa: G003
         return True
 
     return False
@@ -925,7 +925,7 @@ def find_model_path(path):
 
     logger.info(target_model_path)
     if len(target_model_path) > 1:
-        logger.error("We expect to find only one model in " + path)
+        logger.error("We expect to find only one model in " + path)  # noqa: G003
         raise
 
     return target_model_path[0]
diff --git a/onnxruntime/python/tools/tensorrt/perf/benchmark_wrapper.py b/onnxruntime/python/tools/tensorrt/perf/benchmark_wrapper.py
index 93d41551c7121..f12d4599817b7 100644
--- a/onnxruntime/python/tools/tensorrt/perf/benchmark_wrapper.py
+++ b/onnxruntime/python/tools/tensorrt/perf/benchmark_wrapper.py
@@ -80,9 +80,9 @@ def main():
     benchmark = is_benchmark_mode(args.running_mode)  # noqa: F405
 
     for model, model_info in models.items():
-        logger.info("\n" + "=" * 40 + "=" * len(model))  # noqa: F405
-        logger.info("=" * 20 + model + "=" * 20)  # noqa: F405
-        logger.info("=" * 40 + "=" * len(model))  # noqa: F405
+        logger.info("\n" + "=" * 40 + "=" * len(model))  # noqa: F405, G003
+        logger.info("=" * 20 + model + "=" * 20)  # noqa: F405, G003
+        logger.info("=" * 40 + "=" * len(model))  # noqa: F405, G003
 
         model_info["model_name"] = model
 
diff --git a/onnxruntime/python/tools/tensorrt/perf/build/build_image.py b/onnxruntime/python/tools/tensorrt/perf/build/build_image.py
index b98aafc27579a..b95ad3c0a55ef 100644
--- a/onnxruntime/python/tools/tensorrt/perf/build/build_image.py
+++ b/onnxruntime/python/tools/tensorrt/perf/build/build_image.py
@@ -14,9 +14,10 @@
 from typing import List, Optional
 
 TRT_DOCKER_FILES = {
-    "8.4": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_6_tensorrt8_4",
-    "8.5": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_5",
-    "8.6": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6",
+    "8.4.cuda_11_6_cudnn_8": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_6_tensorrt8_4",
+    "8.5.cuda_11_8_cudnn_8": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_5",
+    "8.6.cuda_11_8_cudnn_8": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6",
+    "8.6.cuda_12_3_cudnn_9": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6",
     "BIN": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin",
 }
 
@@ -45,7 +46,7 @@ def get_common_docker_build_args(args: argparse.Namespace) -> List[str]:
     :return: A list of common 'docker build' arguments.
     """
 
-    return [
+    command = [
         "--no-cache",
         "-t",
         f"{args.image_name}",
@@ -54,6 +55,14 @@ def get_common_docker_build_args(args: argparse.Namespace) -> List[str]:
         "--build-arg",
         f"ONNXRUNTIME_BRANCH={args.branch}",
     ]
+    if args.use_tensorrt_oss_parser:
+        command.extend(
+            [
+                "--build-arg",
+                "PARSER_CONFIG=--use_tensorrt_oss_parser",
+            ]
+        )
+    return command
 
 
 def is_valid_ver_str(version: str, min_comps: int = 0, max_comps: int = 0) -> bool:
@@ -91,18 +100,11 @@ def docker_build_trt(args: argparse.Namespace):
     :param args: The arguments to this script.
     """
 
-    if not is_valid_ver_str(args.trt_version, min_comps=2, max_comps=4):
-        print(f"[ERROR]: Invalid TensorRT version '{args.trt_version}'", file=sys.stderr)
-        sys.exit(1)
-
-    vers_comps = args.trt_version.split(".")
-    trt_ver_key = f"{vers_comps[0]}.{vers_comps[1]}"
-
-    if trt_ver_key not in TRT_DOCKER_FILES:
+    if args.trt_version not in TRT_DOCKER_FILES:
         print(f"[ERROR]: TensorRT version '{args.trt_version}' is currently unsupported", file=sys.stderr)
         sys.exit(1)
 
-    docker_file = TRT_DOCKER_FILES[trt_ver_key]
+    docker_file = TRT_DOCKER_FILES[args.trt_version]
     docker_file_path = os.path.normpath(os.path.join(args.repo_path, docker_file))
 
     if not os.path.isfile(docker_file_path):
@@ -136,11 +138,7 @@ def docker_build_trt_bin(args: argparse.Namespace):
         sys.exit(1)
 
     if not is_valid_ver_str(args.tar_cuda_version, 2, 2):
-        print("[ERROR]: Must specify a valid CUDA version for binary TensorRT installs (e.g., 11.x)", file=sys.stderr)
-        sys.exit(1)
-
-    if not is_valid_ver_str(args.tar_cudnn_version, 2, 2):
-        print("[ERROR]: Must specify a valid cuDNN version for binary TensorRT installs (e.g., 8.x)", file=sys.stderr)
+        print("[ERROR]: Must specify a valid CUDA version for binary TensorRT installs (e.g., 12.4)", file=sys.stderr)
         sys.exit(1)
 
     if not os.path.isfile(docker_file_path):
@@ -162,8 +160,6 @@ def docker_build_trt_bin(args: argparse.Namespace):
             "--build-arg",
             f"TAR_CUDA_VERSION={args.tar_cuda_version}",
             "--build-arg",
-            f"TAR_CUDNN_VERSION={args.tar_cudnn_version}",
-            "--build-arg",
             f"TRT_BINS_DIR={args.trt_bins_dir}",
             "-f",
             f"{docker_file_path}",
@@ -187,7 +183,9 @@ def parse_arguments() -> argparse.Namespace:
     parser.add_argument("-r", "--repo_path", required=True, help="Path to the onnxruntime repository")
     parser.add_argument("-i", "--image_name", required=True, help="The resulting Docker image name")
     parser.add_argument("-b", "--branch", default="main", help="Name of the onnxruntime git branch to checkout")
-    parser.add_argument("-t", "--trt_version", default="8.4.1.5", help="TensorRT version (e.g., 8.4.1.5)")
+    parser.add_argument(
+        "-t", "--trt_version", default="8.6.cuda_11_8_cudnn_8", help="TensorRT version (e.g., 8.6.cuda_11_8_cudnn_8)"
+    )
     parser.add_argument("-a", "--cuda_arch", default="75", help="CUDA architecture (e.g., 75)")
 
     # Command-line options for installing TensorRT from binaries.
@@ -200,14 +198,15 @@ def parse_arguments() -> argparse.Namespace:
     parser.add_argument(
         "--tar_cuda_version",
         default="",
-        help="CUDA version (e.g., 11.8) used to find TensorRT EA binary tar.gz package",
+        help="CUDA version (e.g., 12.4) used to find TensorRT EA binary tar.gz package",
     )
+    parser.add_argument("--trt_bins_dir", default="", help="Directory containing TensorRT tar.gz package")
     parser.add_argument(
-        "--tar_cudnn_version",
-        default="",
-        help="CUDA version (e.g., 8.6) used to find TensorRT EA binary tar.gz package",
+        "--use_tensorrt_oss_parser",
+        action="store_true",
+        default=False,
+        help="Use TensorRT OSS Parser",
     )
-    parser.add_argument("--trt_bins_dir", default="", help="Directory containing TensorRT tar.gz package")
 
     return parser.parse_args()
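
# Illustrative sketch (not part of the patch, placeholder paths/values): with the changes above,
# --trt_version must exactly match a TRT_DOCKER_FILES key, and --use_tensorrt_oss_parser appends
# a PARSER_CONFIG build-arg. Rough shape of the selection logic:
TRT_DOCKER_FILES = {
    "8.6.cuda_11_8_cudnn_8": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6",
}

def preview_build_args(trt_version: str, use_tensorrt_oss_parser: bool) -> list:
    if trt_version not in TRT_DOCKER_FILES:
        raise SystemExit(f"TensorRT version '{trt_version}' is currently unsupported")
    args = ["-f", TRT_DOCKER_FILES[trt_version]]
    if use_tensorrt_oss_parser:
        args += ["--build-arg", "PARSER_CONFIG=--use_tensorrt_oss_parser"]
    return args

print(preview_build_args("8.6.cuda_11_8_cudnn_8", True))
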
 
diff --git a/onnxruntime/python/tools/tensorrt/perf/build/ort_build_latest.py b/onnxruntime/python/tools/tensorrt/perf/build/ort_build_latest.py
index 6e20071683d90..c7d4a7836132a 100755
--- a/onnxruntime/python/tools/tensorrt/perf/build/ort_build_latest.py
+++ b/onnxruntime/python/tools/tensorrt/perf/build/ort_build_latest.py
@@ -13,6 +13,12 @@ def parse_arguments():
     parser.add_argument("-b", "--branch", required=False, default="master", help="Github branch to test perf off of")
     parser.add_argument("-s", "--save", required=False, help="Directory to archive wheel file")
     parser.add_argument("-a", "--use_archived", required=False, help="Archived wheel file")
+    parser.add_argument(
+        "--use_tensorrt_oss_parser",
+        action="store_true",
+        default=False,
+        help="Use TensorRT OSS Parser",
+    )
     args = parser.parse_args()
     return args
 
@@ -35,14 +41,14 @@ def install_new_ort_wheel(ort_master_path):
 def main():
     args = parse_arguments()
 
-    cmake_tar = "cmake-3.18.4-Linux-x86_64.tar.gz"
+    cmake_tar = "cmake-3.28.3-linux-x86_64.tar.gz"
     if not os.path.exists(cmake_tar):
-        subprocess.run(["wget", "-c", "https://cmake.org/files/v3.18/" + cmake_tar], check=True)
+        subprocess.run(["wget", "-c", "https://cmake.org/files/v3.28/" + cmake_tar], check=True)
     tar = tarfile.open(cmake_tar)
     tar.extractall()
     tar.close()
 
-    os.environ["PATH"] = os.path.join(os.path.abspath("cmake-3.18.4-Linux-x86_64"), "bin") + ":" + os.environ["PATH"]
+    os.environ["PATH"] = os.path.join(os.path.abspath("cmake-3.28.3-linux-x86_64"), "bin") + ":" + os.environ["PATH"]
     os.environ["CUDACXX"] = os.path.join(args.cuda_home, "bin", "nvcc")
 
     ort_master_path = args.ort_master_path
@@ -57,24 +63,24 @@ def main():
         subprocess.run(["git", "fetch"], check=True)
         subprocess.run(["git", "checkout", args.branch], check=True)
         subprocess.run(["git", "pull", "origin", args.branch], check=True)
-        subprocess.run(
-            [
-                "./build.sh",
-                "--config",
-                "Release",
-                "--use_tensorrt",
-                "--tensorrt_home",
-                args.tensorrt_home,
-                "--cuda_home",
-                args.cuda_home,
-                "--cudnn",
-                "/usr/lib/x86_64-linux-gnu",
-                "--build_wheel",
-                "--skip_tests",
-                "--parallel",
-            ],
-            check=True,
-        )
+        command = [
+            "./build.sh",
+            "--config",
+            "Release",
+            "--use_tensorrt",
+            "--tensorrt_home",
+            args.tensorrt_home,
+            "--cuda_home",
+            args.cuda_home,
+            "--cudnn",
+            "/usr/lib/x86_64-linux-gnu",
+            "--build_wheel",
+            "--skip_tests",
+            "--parallel",
+        ]
+        if args.use_tensorrt_oss_parser:
+            command.append("--use_tensorrt_oss_parser")
+        subprocess.run(command, check=True)
 
         ort_wheel_file = install_new_ort_wheel(ort_master_path)
 
diff --git a/onnxruntime/python/tools/tensorrt/perf/mem_test/run.sh b/onnxruntime/python/tools/tensorrt/perf/mem_test/run.sh
index dd53fe6127462..2cfdd39bc96aa 100755
--- a/onnxruntime/python/tools/tensorrt/perf/mem_test/run.sh
+++ b/onnxruntime/python/tools/tensorrt/perf/mem_test/run.sh
@@ -4,13 +4,14 @@
 
 set -x
 
-while getopts p:o:l:s: parameter
+while getopts p:o:l:s:c: parameter
 do case "${parameter}"
 in
 p) WORKSPACE=${OPTARG};;
 o) ORT_BINARY_PATH=${OPTARG};;
 l) BUILD_ORT_LATEST=${OPTARG};;
 s) ORT_SOURCE=${OPTARG};;
+c) CONCURRENCY=${OPTARG};;
 esac
 done
 
@@ -104,6 +105,26 @@ fi
 
 mv valgrind.log result
 
+# Concurrency Test
+FRCNN_FOLDER="/data/ep-perf-models/onnx-zoo-models/FasterRCNN-10/"
+
+mkdir FasterRCNN-10/
+cp -r ${FRCNN_FOLDER}/test_data_set_0 ${FRCNN_FOLDER}/faster_rcnn_R_50_FPN_1x.onnx ./FasterRCNN-10/
+
+# replicate test inputs
+for (( i=1; i<CONCURRENCY; i++ )); do
+    cp -r "./FasterRCNN-10/test_data_set_0/" "./FasterRCNN-10/test_data_set_$i/"
+done
+
+pip install onnx requests packaging
+python ${ORT_SOURCE}/onnxruntime/python/tools/symbolic_shape_infer.py \
+    --input="./FasterRCNN-10/faster_rcnn_R_50_FPN_1x.onnx" \
+    --output="./FasterRCNN-10/faster_rcnn_R_50_FPN_1x.onnx" \
+    --auto_merge
+
+${ORT_SOURCE}/build/Linux/Release/onnx_test_runner -e tensorrt -c ${CONCURRENCY} -r 100 ./FasterRCNN-10/ > concurrency_test.log 2>&1
+mv concurrency_test.log result
+
 # Run AddressSanitizer 
 ASAN_OPTIONS=${ASAN_OPTIONS} ./onnx_memtest
 
diff --git a/onnxruntime/python/tools/tensorrt/perf/mem_test/run_mem_test_docker.sh b/onnxruntime/python/tools/tensorrt/perf/mem_test/run_mem_test_docker.sh
index 4e94c63ee6c25..a355e4cf5d365 100755
--- a/onnxruntime/python/tools/tensorrt/perf/mem_test/run_mem_test_docker.sh
+++ b/onnxruntime/python/tools/tensorrt/perf/mem_test/run_mem_test_docker.sh
@@ -3,13 +3,14 @@
 set -x
 
 # Parse Arguments
-while getopts w:d:p:l: parameter
+while getopts w:d:p:l:c: parameter
 do case "${parameter}"
 in 
 w) WORKSPACE=${OPTARG};; # workspace folder of onnxruntime
 d) DOCKER_IMAGE=${OPTARG};; # docker image:"trt-ep-mem-test" docker image is already pre-built on perf machine
 p) MEM_TEST_DIR=${OPTARG};; # mem test dir
 l) BUILD_ORT_LATEST=${OPTARG};; # whether to build latest ORT
+c) CONCURRENCY=${OPTARG};;
 esac
 done 
 
@@ -24,4 +25,4 @@ then
     BUILD_ORT_LATEST="true"
 fi
 
-docker run --rm --gpus all -v $MEM_TEST_DIR:$DOCKER_MEM_TEST_DIR -v /data/ep-perf-models:/data/ep-perf-models $DOCKER_IMAGE /bin/bash $DOCKER_MEM_TEST_DIR'run.sh' -p $DOCKER_MEM_TEST_DIR -o $DOCKER_ORT_LIBS -s $DOCKER_ORT_SOURCE -l $BUILD_ORT_LATEST
+docker run --rm --gpus all -v $MEM_TEST_DIR:$DOCKER_MEM_TEST_DIR -v /data/ep-perf-models:/data/ep-perf-models $DOCKER_IMAGE /bin/bash $DOCKER_MEM_TEST_DIR'run.sh' -p $DOCKER_MEM_TEST_DIR -o $DOCKER_ORT_LIBS -s $DOCKER_ORT_SOURCE -l $BUILD_ORT_LATEST -c $CONCURRENCY
diff --git a/onnxruntime/python/tools/tensorrt/perf/post.py b/onnxruntime/python/tools/tensorrt/perf/post.py
index 0f5614bd5160f..df389ad572596 100644
--- a/onnxruntime/python/tools/tensorrt/perf/post.py
+++ b/onnxruntime/python/tools/tensorrt/perf/post.py
@@ -3,6 +3,7 @@
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
 import argparse
+import csv
 import datetime
 import os
 import sys
@@ -56,6 +57,7 @@ def parse_arguments():
     parser.add_argument("-b", "--branch", help="Branch", required=True)
     parser.add_argument("--kusto_conn", help="Kusto connection URL", required=True)
     parser.add_argument("--database", help="Database name", required=True)
+    parser.add_argument("--use_tensorrt_oss_parser", help="Use TensorRT OSS parser", required=False)
     parser.add_argument(
         "-d",
         "--commit_datetime",
@@ -370,7 +372,7 @@ def write_table(
     ingest_client.ingest_from_dataframe(table, ingestion_properties=ingestion_props)
 
 
-def get_identifier(commit_datetime, commit_hash, trt_version, branch):
+def get_identifier(commit_datetime, commit_hash, trt_version, branch, use_tensorrt_oss_parser):
     """
     Returns an identifier that associates uploaded data with an ORT commit/date/branch and a TensorRT version.
 
@@ -383,7 +385,23 @@ def get_identifier(commit_datetime, commit_hash, trt_version, branch):
     """
 
     date = str(commit_datetime.date())  # extract date only
-    return date + "_" + commit_hash + "_" + trt_version + "_" + branch
+    if use_tensorrt_oss_parser:
+        current_dir = os.path.dirname(os.path.abspath(__file__))
+        root_dir = os.path.abspath(os.path.join(current_dir, "../../../../.."))
+        deps_txt_path = os.path.join(root_dir, "cmake", "deps.txt")
+        commit_head = ""
+        with open(deps_txt_path) as file:
+            for line in file:
+                parts = line.split(";")
+                if parts[0] == "onnx_tensorrt":
+                    url = parts[1]
+                    commit = url.split("/")[-1]
+                    commit_head = commit[:6]
+                    break
+        parser = f"oss_{commit_head}"
+    else:
+        parser = "builtin"
+    return "_".join([date, commit_hash, trt_version, parser, branch])
 
 
 def main():
@@ -396,14 +414,17 @@ def main():
     # connect to database
     kcsb_ingest = KustoConnectionStringBuilder.with_az_cli_authentication(args.kusto_conn)
     ingest_client = QueuedIngestClient(kcsb_ingest)
-    identifier = get_identifier(args.commit_datetime, args.commit_hash, args.trt_version, args.branch)
+    identifier = get_identifier(
+        args.commit_datetime, args.commit_hash, args.trt_version, args.branch, args.use_tensorrt_oss_parser
+    )
     upload_time = datetime.datetime.now(tz=datetime.timezone.utc).replace(microsecond=0)
 
     try:
+        # Load EP Perf test results from /result
         result_file = args.report_folder
-
-        folders = os.listdir(result_file)
-        os.chdir(result_file)
+        result_perf_test_path = os.path.join(result_file, "result")
+        folders = os.listdir(result_perf_test_path)
+        os.chdir(result_perf_test_path)
 
         tables = [
             fail_name,
@@ -426,13 +447,13 @@ def main():
         for model_group in folders:
             os.chdir(model_group)
             csv_filenames = os.listdir()
-            for csv in csv_filenames:
-                table = pd.read_csv(csv)
-                if session_name in csv:
+            for csv_file in csv_filenames:
+                table = pd.read_csv(csv_file)
+                if session_name in csv_file:
                     table_results[session_name] = pd.concat(
                         [table_results[session_name], get_session(table, model_group)], ignore_index=True
                     )
-                elif specs_name in csv:
+                elif specs_name in csv_file:
                     table_results[specs_name] = pd.concat(
                         [
                             table_results[specs_name],
@@ -440,12 +461,12 @@ def main():
                         ],
                         ignore_index=True,
                     )
-                elif fail_name in csv:
+                elif fail_name in csv_file:
                     table_results[fail_name] = pd.concat(
                         [table_results[fail_name], get_failures(table, model_group)],
                         ignore_index=True,
                     )
-                elif latency_name in csv:
+                elif latency_name in csv_file:
                     table_results[memory_name] = pd.concat(
                         [table_results[memory_name], get_memory(table, model_group)],
                         ignore_index=True,
@@ -455,11 +476,11 @@ def main():
                         [table_results[latency_name], get_latency(table, model_group)],
                         ignore_index=True,
                     )
-                elif status_name in csv:
+                elif status_name in csv_file:
                     table_results[status_name] = pd.concat(
                         [table_results[status_name], get_status(table, model_group)], ignore_index=True
                     )
-                elif op_metrics_name in csv:
+                elif op_metrics_name in csv_file:
                     table = table.assign(Group=model_group)
                     table_results[op_metrics_name] = pd.concat(
                         [table_results[op_metrics_name], table], ignore_index=True
@@ -493,6 +514,43 @@ def main():
                 args.commit_datetime,
             )
 
+        # Load concurrency test results
+        result_mem_test_path = os.path.join(result_file, "result_mem_test")
+        os.chdir(result_mem_test_path)
+        log_path = "concurrency_test.log"
+        if os.path.exists(log_path):
+            print("Generating concurrency test report")
+            with open(log_path) as log_file:
+                log_content = log_file.read()
+
+            failed_cases_section = log_content.split("Failed Test Cases:")[1]
+
+            # passed = 1 if no failed test cases
+            if failed_cases_section.strip() == "":
+                passed = 1
+            else:
+                passed = 0
+
+            csv_path = "concurrency_test.csv"
+            with open(csv_path, "w", newline="") as csv_file:
+                csv_writer = csv.writer(csv_file)
+                csv_writer.writerow(["Passed", "Log"])
+                csv_writer.writerow([passed, log_content])
+
+            db_table_name = "ep_concurrencytest_record"
+            table = pd.read_csv(csv_path)
+            write_table(
+                ingest_client,
+                args.database,
+                table,
+                db_table_name,
+                upload_time,
+                identifier,
+                args.branch,
+                args.commit_hash,
+                args.commit_datetime,
+            )
+
     except BaseException as e:
         print(str(e))
         sys.exit(1)
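
# Illustrative sketch (not part of the patch, made-up values): the reworked get_identifier above
# joins date, commit hash, TRT version, parser tag and branch with underscores.
import datetime

date = str(datetime.datetime(2024, 3, 1, 12, 0).date())
parser_tag = "oss_a1b2c3"  # or "builtin" when --use_tensorrt_oss_parser is not set
identifier = "_".join([date, "abc1234", "8.6.cuda_11_8_cudnn_8", parser_tag, "main"])
print(identifier)  # 2024-03-01_abc1234_8.6.cuda_11_8_cudnn_8_oss_a1b2c3_main
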
diff --git a/onnxruntime/python/tools/transformers/benchmark.py b/onnxruntime/python/tools/transformers/benchmark.py
index 89f9947688583..9baafbbfff0e3 100644
--- a/onnxruntime/python/tools/transformers/benchmark.py
+++ b/onnxruntime/python/tools/transformers/benchmark.py
@@ -802,7 +802,7 @@ def main():
         try:
             os.mkdir(args.cache_dir)
         except OSError:
-            logger.error("Creation of the directory %s failed" % args.cache_dir)
+            logger.error("Creation of the directory %s failed" % args.cache_dir)  # noqa: G002
 
     enable_torch = "torch" in args.engines
     enable_torch2 = "torch2" in args.engines
@@ -921,7 +921,7 @@ def main():
                     args,
                 )
             except Exception:
-                logger.error("Exception", exc_info=True)
+                logger.exception("Exception")
 
     time_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
     if model_fusion_statistics:
diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py
index c7d93470a729e..66f7a63447764 100644
--- a/onnxruntime/python/tools/transformers/benchmark_helper.py
+++ b/onnxruntime/python/tools/transformers/benchmark_helper.py
@@ -142,7 +142,7 @@ def create_onnxruntime_session(
 
         session = onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=providers)
     except Exception:
-        logger.error("Exception", exc_info=True)
+        logger.error("Exception", exc_info=True)  # noqa: G201
 
     return session
 
@@ -589,7 +589,7 @@ def measure_memory(is_gpu, func, monitor_type="cuda", start_memory=None):
             if max_usage is None:
                 return None
 
-            print(f"GPU memory usage: before={memory_before_test}  peak={max_usage}")
+            logger.info(f"GPU memory usage: before={memory_before_test}  peak={max_usage}")
             if len(memory_before_test) >= 1 and len(max_usage) >= 1 and len(memory_before_test) == len(max_usage):
                 # When there are multiple GPUs, we will check the one with maximum usage.
                 max_used = 0
@@ -620,7 +620,7 @@ def measure_memory(is_gpu, func, monitor_type="cuda", start_memory=None):
             monitor.keep_measuring = False
             max_usage = mem_thread.result()
 
-        print(f"CPU memory usage: before={memory_before_test:.1f} MB, peak={max_usage:.1f} MB")
+        logger.info(f"CPU memory usage: before={memory_before_test:.1f} MB, peak={max_usage:.1f} MB")
         return max_usage - memory_before_test
 
 
diff --git a/onnxruntime/python/tools/transformers/bert_perf_test.py b/onnxruntime/python/tools/transformers/bert_perf_test.py
index 9c743a83819c3..17c5d3602bb3b 100644
--- a/onnxruntime/python/tools/transformers/bert_perf_test.py
+++ b/onnxruntime/python/tools/transformers/bert_perf_test.py
@@ -232,9 +232,9 @@ def onnxruntime_inference(session, all_inputs, output_names):
 def to_string(model_path, session, test_setting):
     sess_options = session.get_session_options()
     option = f"model={os.path.basename(model_path)},"
-    option += "graph_optimization_level={},intra_op_num_threads={},".format(
-        sess_options.graph_optimization_level, sess_options.intra_op_num_threads
-    ).replace("GraphOptimizationLevel.ORT_", "")
+    option += f"graph_optimization_level={sess_options.graph_optimization_level},intra_op_num_threads={sess_options.intra_op_num_threads},".replace(
+        "GraphOptimizationLevel.ORT_", ""
+    )
 
     option += f"batch_size={test_setting.batch_size},sequence_length={test_setting.sequence_length},"
     option += f"test_cases={test_setting.test_cases},test_times={test_setting.test_times},"
diff --git a/onnxruntime/python/tools/transformers/compare_bert_results.py b/onnxruntime/python/tools/transformers/compare_bert_results.py
index 61e4c97c75c8c..0c5125e74c8a4 100644
--- a/onnxruntime/python/tools/transformers/compare_bert_results.py
+++ b/onnxruntime/python/tools/transformers/compare_bert_results.py
@@ -59,16 +59,10 @@ def compare(baseline_results, treatment_results, verbose, rtol=1e-1, atol=1e-3):
                         print(f"abs_diff={abs_diff}")
 
     if diff_count == 0:
-        print(
-            "100% passed for {} random inputs given thresholds (rtol={}, atol={}).".format(
-                len(baseline_results), rtol, atol
-            )
-        )
+        print(f"100% passed for {len(baseline_results)} random inputs given thresholds (rtol={rtol}, atol={atol}).")
     else:
         print(
-            "WARNING: {} out of {} results NOT passed for thresholds (rtol={}, atol={}).".format(
-                diff_count, len(baseline_results), rtol, atol
-            )
+            f"WARNING: {diff_count} out of {len(baseline_results)} results NOT passed for thresholds (rtol={rtol}, atol={atol})."
         )
 
     print(f"maximum absolute difference={max_abs_diff}")
@@ -117,11 +111,7 @@ def run_test(
         baseline_model, all_inputs, use_gpu, disable_optimization=True
     )
     if verbose:
-        print(
-            "baseline average latency (all optimizations disabled): {} ms".format(
-                statistics.mean(baseline_latency) * 1000
-            )
-        )
+        print(f"baseline average latency (all optimizations disabled): {statistics.mean(baseline_latency) * 1000} ms")
 
     if output_dir is not None:
         for i, inputs in enumerate(all_inputs):
diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py
index a2cdd17e19fa5..894e11275056e 100644
--- a/onnxruntime/python/tools/transformers/convert_generation.py
+++ b/onnxruntime/python/tools/transformers/convert_generation.py
@@ -1273,7 +1273,7 @@ def find_past_seq_len_usage(subg: GraphProto):
 
 
 def replace_mha_with_gqa(
-    model: OnnxModel, attn_mask: str, kv_num_heads: int = 0, world_size: int = 1, window_size: int = 0
+    model: OnnxModel, attn_mask: str, kv_num_heads: int = 0, world_size: int = 1, window_size: int = -1
 ):
     # Insert attention_mask subgraph to calculate shared inputs for all GroupQueryAttention nodes
     #
@@ -1339,31 +1339,163 @@ def replace_mha_with_gqa(
     )
 
     # Replace MultiHeadAttention with GroupQueryAttention
+    #
+    # When replacing, fuse the following subgraph:
+    #
+    #                 root_input
+    #               /     |      \
+    #         MatMul    MatMul    MatMul
+    #           |         |         |
+    #          Add       Add       Add      (optional Adds)
+    #           |         |         |
+    #         RotEmb    RotEmb      |
+    #            \        |        /
+    #             MultiHeadAttention
+    #
+    # to this new subgraph:
+    #
+    #                 root_input
+    #                     |
+    #                PackedMatMul           (if possible)
+    #                     |
+    #                 PackedAdd             (if possible)
+    #                     |
+    #             GroupQueryAttention
+    #
+
     mha_nodes = list(filter(lambda node: node.op_type == "MultiHeadAttention", model.model.graph.node))
-    for node in mha_nodes:
-        num_heads_mha = 0
+    for idx, node in enumerate(mha_nodes):
+        # Detect Q path to MHA
+        q_path_1 = model.match_parent_path(node, ["RotaryEmbedding", "Add", "MatMul"], [0, 0, 0])
+        q_path_2 = model.match_parent_path(node, ["RotaryEmbedding", "MatMul"], [0, 0])
+
+        q_rotary, q_add, q_matmul = None, None, None
+        if q_path_1 is not None:
+            q_rotary, q_add, q_matmul = q_path_1
+        elif q_path_2 is not None:
+            q_rotary, q_matmul = q_path_2
+
+        # Detect K path to MHA
+        k_path_1 = model.match_parent_path(node, ["RotaryEmbedding", "Add", "MatMul"], [1, 0, 0])
+        k_path_2 = model.match_parent_path(node, ["RotaryEmbedding", "MatMul"], [1, 0])
+
+        k_rotary, k_add, k_matmul = None, None, None
+        if k_path_1 is not None:
+            k_rotary, k_add, k_matmul = k_path_1
+        elif k_path_2 is not None:
+            k_rotary, k_matmul = k_path_2
+
+        # Detect V path to MHA
+        v_path_1 = model.match_parent_path(node, ["Add", "MatMul"], [2, 0])
+        v_path_2 = model.match_parent_path(node, ["MatMul"], [2])
+
+        v_add, v_matmul = None, None
+        if v_path_1 is not None:
+            v_add, v_matmul = v_path_1
+        elif v_path_2 is not None:
+            v_matmul = v_path_2[0]
+
+        # Get `interleaved` attribute from RotaryEmbedding
+        interleaved = 0
+        if q_rotary is not None and k_rotary is not None:
+            for att in q_rotary.attribute:
+                if att.name == "interleaved":
+                    interleaved = att.i
+
+        # Get `num_heads` attribute from MHA
+        num_heads = 0
         for att in node.attribute:
             if att.name == "num_heads":
-                num_heads_mha = att.i
+                num_heads = att.i
+
+        # Check if root_input to Q/K/V paths is the same
+        root_input_is_same = q_matmul.input[0] == k_matmul.input[0] and k_matmul.input[0] == v_matmul.input[0]
+
+        # Check if Q/K/V paths all have bias or all don't have bias
+        all_paths_have_bias = q_add is not None and k_add is not None and v_add is not None
+        all_paths_have_no_bias = q_add is None and k_add is None and v_add is None
+
+        # Make PackedMatMul node if possible
+        q_input_to_attention, k_input_to_attention, v_input_to_attention = "", "", ""
+        if root_input_is_same and (all_paths_have_bias or all_paths_have_no_bias):
+            qw = NumpyHelper.to_array(model.get_initializer(q_matmul.input[1]))
+            kw = NumpyHelper.to_array(model.get_initializer(k_matmul.input[1]))
+            vw = NumpyHelper.to_array(model.get_initializer(v_matmul.input[1]))
+
+            dim = qw.shape[-1]
+            qkv_weight = np.stack((qw, kw, vw), axis=1).reshape(dim, 3 * dim)
+            qkv_weight = onnx.numpy_helper.from_array(qkv_weight, name=f"QKV_Weight_{idx}")
+            model.add_initializer(qkv_weight)
+
+            packed_matmul_node = onnx.helper.make_node(
+                "MatMul",
+                inputs=[q_matmul.input[0], qkv_weight.name],
+                outputs=[f"{qkv_weight.name}_output"],
+                name=model.create_node_name("MatMul"),
+            )
+            model.model.graph.node.extend([packed_matmul_node])
+            model.model.graph.node.remove(q_matmul)
+            model.model.graph.node.remove(k_matmul)
+            model.model.graph.node.remove(v_matmul)
+            q_input_to_attention = packed_matmul_node.output[0]
+
+            # Make PackedAdd node if possible
+            if all_paths_have_bias:
+                qb = NumpyHelper.to_array(model.get_initializer(q_add.input[1]))
+                kb = NumpyHelper.to_array(model.get_initializer(k_add.input[1]))
+                vb = NumpyHelper.to_array(model.get_initializer(v_add.input[1]))
+
+                dim = qb.shape[-1]
+                qkv_bias = np.stack((qb, kb, vb), axis=0).reshape(3 * dim)
+                qkv_bias = onnx.numpy_helper.from_array(qkv_bias, name=f"QKV_Bias_{idx}")
+                model.add_initializer(qkv_bias)
+                packed_add_node = onnx.helper.make_node(
+                    "Add",
+                    inputs=[packed_matmul_node.output[0], qkv_bias.name],
+                    outputs=[f"{qkv_bias.name}_output"],
+                )
+                model.model.graph.node.extend([packed_add_node])
+                model.model.graph.node.remove(q_add)
+                model.model.graph.node.remove(k_add)
+                model.model.graph.node.remove(v_add)
+                q_input_to_attention = packed_add_node.output[0]
+
+        else:
+            q_input_to_attention = q_matmul.output[0]
+            k_input_to_attention = k_matmul.output[0]
+            v_input_to_attention = v_matmul.output[0]
+
+        # Make GQA node
         gqa_node = onnx.helper.make_node(
             "GroupQueryAttention",
             inputs=[
-                node.input[0],  # query
-                node.input[1],  # key
-                node.input[2],  # value
+                q_input_to_attention,  # query
+                k_input_to_attention,  # key
+                v_input_to_attention,  # value
                 node.input[6],  # past_key
                 node.input[7],  # past_value
-                "seqlens_k",  # seqlens_k (for attention_mask)
-                "total_seq_len",  # total_seq_len (for attention_mask)
+                seqlen_k_cast_node.output[0],  # seqlens_k (for attention mask)
+                total_seqlen_cast_node.output[0],  # total_seq_len (for attention mask)
+                q_rotary.input[2] if q_rotary is not None else "",  # cos_cache (for rotary embeddings)
+                q_rotary.input[3] if q_rotary is not None else "",  # sin_cache (for rotary embeddings)
             ],
             outputs=node.output,
             name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"),
             domain="com.microsoft",
-            num_heads=num_heads_mha // world_size,
-            kv_num_heads=num_heads_mha // world_size if kv_num_heads == 0 else kv_num_heads // world_size,
+            num_heads=num_heads // world_size,
+            kv_num_heads=num_heads // world_size if kv_num_heads == 0 else kv_num_heads // world_size,
+            local_window_size=window_size,
+            do_rotary=int(q_rotary is not None and k_rotary is not None),
+            rotary_interleaved=interleaved,
         )
         model.model.graph.node.remove(node)
         model.model.graph.node.extend([gqa_node])
+
+        if q_rotary is not None:
+            model.model.graph.node.remove(q_rotary)
+        if k_rotary is not None:
+            model.model.graph.node.remove(k_rotary)
+
     return model
 
 
diff --git a/onnxruntime/python/tools/transformers/float16.py b/onnxruntime/python/tools/transformers/float16.py
index 48c79b1d5fa0f..2398bb9d6031b 100644
--- a/onnxruntime/python/tools/transformers/float16.py
+++ b/onnxruntime/python/tools/transformers/float16.py
@@ -411,9 +411,7 @@ def convert_float_to_float16(
             value_info_list.append(make_value_info_from_tensor(value.initializer))
             if value.fp32_nodes and not force_fp16_initializers:
                 logger.info(
-                    "initializer is used by both fp32 and fp16 nodes. Consider add these nodes to block list:{}".format(
-                        value.fp16_nodes
-                    )
+                    f"initializer is used by both fp32 and fp16 nodes. Consider add these nodes to block list:{value.fp16_nodes}"
                 )
 
     # Some operators have data type fixed as float for some input. Add a float16 to float cast for those inputs.
diff --git a/onnxruntime/python/tools/transformers/fusion_attention_unet.py b/onnxruntime/python/tools/transformers/fusion_attention_unet.py
index 9a353e7e2d675..048c13cdb1e2c 100644
--- a/onnxruntime/python/tools/transformers/fusion_attention_unet.py
+++ b/onnxruntime/python/tools/transformers/fusion_attention_unet.py
@@ -373,9 +373,7 @@ def create_attention_node(
             else "MultiHeadAttention ({})".format(
                 "self attention with packed qkv"
                 if self.enable_packed_qkv
-                else "cross attention with packed kv"
-                if self.enable_packed_kv
-                else "cross attention"
+                else "cross attention with packed kv" if self.enable_packed_kv else "cross attention"
             )
         )
         self.increase_counter(counter_name)
@@ -843,9 +841,7 @@ def create_attention_node_lora(
             else "MultiHeadAttention ({})".format(
                 "self attention with packed qkv"
                 if self.enable_packed_qkv
-                else "cross attention with packed kv"
-                if self.enable_packed_kv
-                else "cross attention"
+                else "cross attention with packed kv" if self.enable_packed_kv else "cross attention"
             )
         )
         self.increase_counter(counter_name)
diff --git a/onnxruntime/python/tools/transformers/fusion_embedlayer.py b/onnxruntime/python/tools/transformers/fusion_embedlayer.py
index 42156d9123383..70ff57f0626e1 100644
--- a/onnxruntime/python/tools/transformers/fusion_embedlayer.py
+++ b/onnxruntime/python/tools/transformers/fusion_embedlayer.py
@@ -345,18 +345,13 @@ def check_embedding(self, word_embedding_gather, segment_embedding_gather, posit
                 and input_ids_shape[1] == position_ids_shape[1]
             ):
                 logger.info(
-                    "Cannot fuse EmbedLayerNormalization: input_ids and position_ids not matched in 2nd dimension: {} vs {}".format(
-                        input_ids_shape, position_ids_shape
-                    )
+                    f"Cannot fuse EmbedLayerNormalization: input_ids and position_ids not matched in 2nd dimension: {input_ids_shape} vs {position_ids_shape}"
                 )
                 return False
 
             if segment_ids and not self.shape_infer.compare_shape(input_ids, segment_ids):
                 logger.info(
-                    "Cannot fuse EmbedLayerNormalization: input_ids and segment_ids does not have same shape: {} != {}".format(
-                        input_ids_shape,
-                        self.shape_infer.get_edge_shape(segment_ids),
-                    )
+                    f"Cannot fuse EmbedLayerNormalization: input_ids and segment_ids does not have same shape: {input_ids_shape} != {self.shape_infer.get_edge_shape(segment_ids)}"
                 )
                 return False
 
diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py
index 4c43e4487bfb1..edac1989e4e9e 100644
--- a/onnxruntime/python/tools/transformers/fusion_options.py
+++ b/onnxruntime/python/tools/transformers/fusion_options.py
@@ -29,6 +29,13 @@ class AttentionOpType(Enum):
     def __str__(self):
         return self.value
 
+    # Override __eq__ and __hash__ so members compare and hash by their underlying string value
+    def __hash__(self):
+        return hash(self.value)
+
+    def __eq__(self, other):
+        return other.value == self.value
+
 
 class FusionOptions:
     """Options of fusion in graph optimization"""
diff --git a/onnxruntime/python/tools/transformers/fusion_qordered_gelu.py b/onnxruntime/python/tools/transformers/fusion_qordered_gelu.py
index 6c44bb11e24dc..5f395b364eb6f 100644
--- a/onnxruntime/python/tools/transformers/fusion_qordered_gelu.py
+++ b/onnxruntime/python/tools/transformers/fusion_qordered_gelu.py
@@ -75,9 +75,11 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
 
         if not self.model.is_safe_to_fuse_nodes(
             subgraph_nodes,
-            [node.output[0], downstream_quantize_node.output[0]]
-            if downstream_shape_node is not None
-            else downstream_quantize_node.output,
+            (
+                [node.output[0], downstream_quantize_node.output[0]]
+                if downstream_shape_node is not None
+                else downstream_quantize_node.output
+            ),
             input_name_to_nodes,
             output_name_to_node,
         ):
diff --git a/onnxruntime/python/tools/transformers/fusion_qordered_layernorm.py b/onnxruntime/python/tools/transformers/fusion_qordered_layernorm.py
index cf2b357721757..5ec6dadc1e677 100644
--- a/onnxruntime/python/tools/transformers/fusion_qordered_layernorm.py
+++ b/onnxruntime/python/tools/transformers/fusion_qordered_layernorm.py
@@ -77,9 +77,11 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
 
         if not self.model.is_safe_to_fuse_nodes(
             subgraph_nodes,
-            [node.output[0], downstream_quantize_node.output[0]]
-            if downstream_shape_node is not None
-            else downstream_quantize_node.output,
+            (
+                [node.output[0], downstream_quantize_node.output[0]]
+                if downstream_shape_node is not None
+                else downstream_quantize_node.output
+            ),
             input_name_to_nodes,
             output_name_to_node,
         ):
diff --git a/onnxruntime/python/tools/transformers/io_binding_helper.py b/onnxruntime/python/tools/transformers/io_binding_helper.py
index 50703b9c17e03..58a49525b9199 100644
--- a/onnxruntime/python/tools/transformers/io_binding_helper.py
+++ b/onnxruntime/python/tools/transformers/io_binding_helper.py
@@ -1,3 +1,4 @@
+import copy
 import logging
 from collections import OrderedDict
 from typing import Any, Dict, List, Tuple, Union
@@ -5,7 +6,7 @@
 import numpy
 import torch
 
-from onnxruntime import InferenceSession
+from onnxruntime import InferenceSession, RunOptions
 
 logger = logging.getLogger(__name__)
 
@@ -227,7 +228,6 @@ def __del__(self):
         del self.input_tensors
         del self.output_tensors
         del self.io_binding
-        del self.ort_session
 
     def allocate_buffers(self, shape_dict: Dict[str, Union[Tuple[int], List[int]]]):
         """Allocate tensors for I/O Binding"""
@@ -276,7 +276,7 @@ def allocate_buffers(self, shape_dict: Dict[str, Union[Tuple[int], List[int]]]):
                     tensor.data_ptr(),
                 )
 
-    def infer(self, feed_dict: Dict[str, torch.Tensor]):
+    def infer(self, feed_dict: Dict[str, torch.Tensor], run_options: RunOptions = None, synchronize: bool = False):
         """Bind input tensors and run inference"""
         for name, tensor in feed_dict.items():
             assert isinstance(tensor, torch.Tensor) and tensor.is_contiguous()
@@ -285,16 +285,7 @@ def infer(self, feed_dict: Dict[str, torch.Tensor]):
                     assert self.input_tensors[name].nelement() == tensor.nelement()
                     assert self.input_tensors[name].dtype == tensor.dtype
                     assert tensor.device.type == "cuda"
-                    # Please install cuda-python package with a version corresponding to CUDA in your machine.
-                    from cuda import cudart
-
-                    # Update input tensor inplace since cuda graph requires input and output has fixed memory address.
-                    cudart.cudaMemcpy(
-                        self.input_tensors[name].data_ptr(),
-                        tensor.data_ptr(),
-                        tensor.element_size() * tensor.nelement(),
-                        cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice,
-                    )
+                    self.input_tensors[name].copy_(tensor)
                 else:
                     self.io_binding.bind_input(
                         name,
@@ -305,14 +296,115 @@ def infer(self, feed_dict: Dict[str, torch.Tensor]):
                         tensor.data_ptr(),
                     )
 
-        self.ort_session.run_with_iobinding(self.io_binding)
+        # Synchronization is not needed in most cases unless different streams are used or inputs/outputs are on CPU.
+        if synchronize:
+            self.io_binding.synchronize_inputs()
+            self.ort_session.run_with_iobinding(self.io_binding, run_options)
+            self.io_binding.synchronize_outputs()
+        else:
+            self.ort_session.run_with_iobinding(self.io_binding, run_options)
 
         return self.output_tensors
 
     @staticmethod
-    def get_cuda_provider_options(device_id: int, enable_cuda_graph: bool) -> Dict[str, Any]:
-        return {
+    def get_cuda_provider_options(device_id: int, enable_cuda_graph: bool, stream: int = 0) -> Dict[str, Any]:
+        options = {
             "device_id": device_id,
             "arena_extend_strategy": "kSameAsRequested",
             "enable_cuda_graph": enable_cuda_graph,
         }
+
+        # stream is the address of a CUDA stream; 0 means the default stream.
+        if stream != 0:
+            options["user_compute_stream"] = str(stream)
+
+        return options
+
+
+class GpuBinding(CudaSession):
+    def __init__(
+        self,
+        ort_session: InferenceSession,
+        device: torch.device,
+        shape_dict: Dict[str, Union[Tuple[int], List[int]]],
+        enable_gpu_graph: bool = False,
+        gpu_graph_id: int = -1,
+        stream: int = 0,
+    ):
+        super().__init__(ort_session, device, enable_gpu_graph)
+        self.allocate_buffers(shape_dict)
+        self.gpu_graph_id = gpu_graph_id
+        # For cuda graph, keep a copy of shape_dict to check whether the shape is the same at inference time.
+        self.shape_dict = copy.deepcopy(shape_dict) if enable_gpu_graph else None
+        self.stream = stream
+        # The gpu graph id of the last run. It will be saved to image metadata.
+        self.last_run_gpu_graph_id = None
+
+    def get_run_options(self, disable_cuda_graph_in_run: bool = False) -> RunOptions:
+        options = RunOptions()
+
+        gpu_graph_id = -1 if disable_cuda_graph_in_run else self.gpu_graph_id
+
+        options.add_run_config_entry("gpu_graph_id", str(gpu_graph_id))
+
+        self.last_run_gpu_graph_id = gpu_graph_id
+
+        return options
+
+    def infer(self, feed_dict: Dict[str, torch.Tensor], disable_cuda_graph_in_run: bool = False):
+        run_options = self.get_run_options(disable_cuda_graph_in_run)
+
+        if self.stream:
+            run_options.add_run_config_entry("disable_synchronize_execution_providers", "1")
+
+        return super().infer(feed_dict, run_options)
+
+
+class GpuBindingManager:
+    """A manager for I/O bindings that support multiple CUDA Graphs.
+    A cuda graph is reused for the same input shape, and a new cuda graph is added automatically for each new input shape.
+    """
+
+    def __init__(self, ort_session: InferenceSession, device: torch.device, stream: int = 0, max_cuda_graphs: int = 1):
+        self.ort_session = ort_session
+        self.device = device
+
+        # Bindings that support cuda graphs. A binding can still disable the cuda graph for a specific run.
+        self.graph_bindings = []
+
+        # Binding that does not use a cuda graph.
+        self.no_graph_binding = None
+
+        self.stream = stream
+
+        self.max_cuda_graphs = max_cuda_graphs
+
+    def get_binding(
+        self,
+        shape_dict: Dict[str, Union[Tuple[int], List[int]]],
+        use_cuda_graph: bool = False,
+    ) -> GpuBinding:
+        for gpu_graph_binding in self.graph_bindings:
+            # Found a cuda graph that was captured with the same input shape
+            if gpu_graph_binding.shape_dict == shape_dict:
+                return gpu_graph_binding
+
+        # Return a binding without cuda graph when the maximum number of cuda graphs has been reached or cuda graph is not requested.
+        if len(self.graph_bindings) >= self.max_cuda_graphs or (not use_cuda_graph):
+            if self.no_graph_binding is None:
+                self.no_graph_binding = GpuBinding(self.ort_session, self.device, shape_dict, stream=self.stream)
+            else:
+                self.no_graph_binding.allocate_buffers(shape_dict)
+            return self.no_graph_binding
+
+        # This is a new input shape, so create a new cuda graph.
+        gpu_graph_binding = GpuBinding(
+            self.ort_session,
+            self.device,
+            shape_dict,
+            enable_gpu_graph=True,
+            gpu_graph_id=len(self.graph_bindings),
+            stream=self.stream,
+        )
+        self.graph_bindings.append(gpu_graph_binding)
+        return gpu_graph_binding
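
A hypothetical usage sketch (not part of the diff) tying together `get_cuda_provider_options`, `GpuBinding`, and `GpuBindingManager` added above. The model path, input/output names, and shapes are placeholders, and the import path assumes the transformers tools are importable as `onnxruntime.transformers`.

```
# Sketch only: run a CUDA-graph-backed session through the new GpuBindingManager.
import torch
import onnxruntime as ort
from onnxruntime.transformers.io_binding_helper import CudaSession, GpuBindingManager

device = torch.device("cuda", 0)
stream = torch.cuda.current_stream(device).cuda_stream  # address of the current CUDA stream

provider_options = CudaSession.get_cuda_provider_options(device_id=0, enable_cuda_graph=True, stream=stream)
session = ort.InferenceSession(
    "model.onnx",  # placeholder path
    providers=[("CUDAExecutionProvider", provider_options), "CPUExecutionProvider"],
)

# One cuda graph is kept per input shape, up to max_cuda_graphs.
manager = GpuBindingManager(session, device, stream=stream, max_cuda_graphs=2)

shape_dict = {"input": (1, 3, 224, 224), "output": (1, 1000)}  # placeholder I/O shapes
binding = manager.get_binding(shape_dict, use_cuda_graph=True)

feed = {"input": torch.randn(1, 3, 224, 224, device=device)}
outputs = binding.infer(feed)  # returns the preallocated output tensors bound by the manager
```

With `enable_cuda_graph=True`, `GpuBinding.get_run_options` passes a `gpu_graph_id` run config entry so ONNX Runtime can capture and replay one graph per input shape.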
diff --git a/onnxruntime/python/tools/transformers/models/gpt2/benchmark_gpt2.py b/onnxruntime/python/tools/transformers/models/gpt2/benchmark_gpt2.py
index e48f0adc832c5..6d6a057574a17 100644
--- a/onnxruntime/python/tools/transformers/models/gpt2/benchmark_gpt2.py
+++ b/onnxruntime/python/tools/transformers/models/gpt2/benchmark_gpt2.py
@@ -400,7 +400,7 @@ def main(args):
                         }
                         csv_writer.writerow(row)
                     except Exception:
-                        logger.error("Exception", exc_info=True)
+                        logger.error("Exception", exc_info=True)  # noqa: G201
                         return None
 
     logger.info(f"Results are saved to file {csv_filename}")
diff --git a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py
index e01585ae84163..9153193a4974a 100644
--- a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py
+++ b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py
@@ -630,7 +630,7 @@ def pytorch_inference(model, inputs: Gpt2Inputs, total_runs: int = 0):
                 latency.append(time.time() - start)
 
         average_latency = sum(latency) * 1000 / len(latency)
-        logger.debug("PyTorch inference time = {} ms".format(format(average_latency, ".2f")))
+        logger.debug("PyTorch inference time = {} ms".format(format(average_latency, ".2f")))  # noqa: G001
 
         return outputs, average_latency
 
@@ -662,7 +662,7 @@ def onnxruntime_inference(ort_session, inputs: Gpt2Inputs, total_runs: int = 0):
             latency.append(time.time() - start)
 
         average_latency = sum(latency) * 1000 / len(latency)
-        logger.debug("OnnxRuntime Inference time = {} ms".format(format(average_latency, ".2f")))
+        logger.debug("OnnxRuntime Inference time = {} ms".format(format(average_latency, ".2f")))  # noqa: G001
 
         return ort_outputs, average_latency
 
@@ -741,7 +741,7 @@ def onnxruntime_inference_with_binded_io(
             latency.append(time.time() - start)
 
         average_latency = sum(latency) * 1000 / len(latency)
-        logger.debug("OnnxRuntime with IO binding inference time = {} ms".format(format(average_latency, ".2f")))
+        logger.debug("OnnxRuntime with IO binding inference time = %.2f ms", average_latency)
 
         return ort_outputs, average_latency
 
diff --git a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_parity.py b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_parity.py
index 4823f0d5874dd..b039f1351b1d0 100644
--- a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_parity.py
+++ b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_parity.py
@@ -179,7 +179,7 @@ def print_wins(wins, rows, test_name):
         for row in rows:
             if row["run_id"] == key:
                 logger.info(
-                    "{:02d}: WINs={:02d}, run_id={}, latency={:5.2f}, top1_match={:.4f}, size={}_MB, experiment={}, {}".format(
+                    "{:02d}: WINs={:02d}, run_id={}, latency={:5.2f}, top1_match={:.4f}, size={}_MB, experiment={}, {}".format(  # noqa: G001
                         rank,
                         value,
                         key,
diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md
index f9552e02d74b9..2e8cd3e1ac7f9 100644
--- a/onnxruntime/python/tools/transformers/models/llama/README.md
+++ b/onnxruntime/python/tools/transformers/models/llama/README.md
@@ -1,7 +1,14 @@
 # Contents
  - [LLaMA-2](#llama-2)
+   - [Prerequisites](#prerequisites)
    - [Exporting LLaMA-2](#exporting-llama-2)
+   - [Examples of Exporting LLaMA-2](#examples-of-exporting-llama-2)
+   - [Parity Checking LLaMA-2](#parity-checking-llama-2)
    - [Benchmarking LLaMA-2](#benchmark-llama-2)
+     - [Variants](#variants)
+     - [Benchmark All](#benchmark-all)
+     - [Benchmark E2E](#benchmark-e2e)
+   - [E2E Inference with LLaMA-2](#e2e-inference-with-llama-2)
  - [Mistral](#mistral)
    - [Exporting Mistral](#exporting-mistral)
    - [Optimizing and Quantizing Mistral](#optimizing-and-quantizing-mistral)
@@ -229,6 +236,55 @@ $ ./build.sh --config Release --use_cuda --cuda_home /usr/local/cuda-12.2 --cudn
 $ CUDA_VISIBLE_DEVICES=0,1,2,3 bash convert_70b_model.sh 4 -m meta-llama/Llama-2-70b-hf --output llama2-70b-distributed --precision fp16 --execution_provider cuda --use_gqa
 ```
 
+## Parity Checking LLaMA-2
+
+Here are some examples of how you can use the parity checker to verify your LLaMA-2 ONNX model.
+
+1. Merged ONNX model, FP32 CPU
+```
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.llama_parity \
+    --model_name meta-llama/Llama-2-7b-hf \
+    --onnx_model_path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp32.onnx \
+    --merged \
+    --execution_provider cpu \
+    --precision fp32 \
+    --cache_dir ./model_cache
+```
+
+2. Merged ONNX model, FP32 CUDA
+```
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.llama_parity \
+    --model_name meta-llama/Llama-2-7b-hf \
+    --onnx_model_path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp32.onnx \
+    --merged \
+    --execution_provider cuda \
+    --precision fp32 \
+    --cache_dir ./model_cache
+```
+
+3. Merged ONNX model, FP16 CUDA
+```
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.llama_parity \
+    --model_name meta-llama/Llama-2-7b-hf \
+    --onnx_model_path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp32.onnx \
+    --merged \
+    --execution_provider cuda \
+    --precision fp16 \
+    --cache_dir ./model_cache
+```
+
+4. Merged ONNX model, FP16 CUDA with GroupQueryAttention
+```
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.llama_parity \
+    --model_name meta-llama/Llama-2-7b-hf \
+    --onnx_model_path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp32.onnx \
+    --merged \
+    --use_gqa \
+    --execution_provider cuda \
+    --precision fp16 \
+    --cache_dir ./model_cache
+```
+
 ## Benchmark LLaMA-2
 
 Here are some examples of how you can benchmark LLaMA-2.
@@ -240,6 +296,7 @@ Here are some examples of how you can benchmark LLaMA-2.
 CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
     --benchmark-type hf-pt-eager \
     --model-name meta-llama/Llama-2-7b-hf \
+    --cache-dir ./model_cache \
     --precision fp32 \
     --batch-sizes "1 2" \
     --sequence-lengths "8 16" \
@@ -252,6 +309,7 @@ CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
 CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
     --benchmark-type hf-pt-compile \
     --model-name meta-llama/Llama-2-7b-hf \
+    --cache-dir ./model_cache \
     --precision fp16 \
     --batch-sizes "1 2" \
     --sequence-lengths "8 16" \
@@ -265,6 +323,7 @@ CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
     --benchmark-type hf-ort \
     --hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \
     --model-name meta-llama/Llama-2-7b-hf \
+    --cache-dir ./model_cache \
     --precision fp32 \
     --batch-sizes "1 2" \
     --sequence-lengths "8 16" \
@@ -278,6 +337,7 @@ CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
     --benchmark-type hf-ort \
     --hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \
     --model-name meta-llama/Llama-2-7b-hf \
+    --cache-dir ./model_cache \
     --precision fp16 \
     --batch-sizes "1 2" \
     --sequence-lengths "8 16" \
@@ -291,6 +351,7 @@ CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
     --benchmark-type ort-msft \
     --ort-model-path ./llama-2-onnx/7B_float32/ONNX/LlamaV2_7B_float32.onnx \
     --model-name meta-llama/Llama-2-7b-hf \
+    --cache-dir ./model_cache \
     --precision fp32 \
     --batch-sizes "1 2" \
     --sequence-lengths "8 16" \
@@ -303,6 +364,7 @@ CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
     --benchmark-type ort-msft \
     --ort-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \
     --model-name meta-llama/Llama-2-7b-hf \
+    --cache-dir ./model_cache \
     --precision fp16 \
     --batch-sizes "1 2" \
     --sequence-lengths "8 16" \
@@ -315,6 +377,7 @@ CUDA_VISIBLE_DEVICES=1 python3 -m models.llama.benchmark \
     --benchmark-type ort-convert-to-onnx \
     --ort-model-path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp32.onnx \
     --model-name meta-llama/Llama-2-7b-hf \
+    --cache-dir ./model_cache \
     --precision fp32 \
     --batch-sizes "1 2" \
     --sequence-lengths "8 16" \
@@ -327,6 +390,7 @@ CUDA_VISIBLE_DEVICES=4 python3 -m models.llama.benchmark \
     --benchmark-type ort-convert-to-onnx \
     --ort-model-path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp16.onnx \
     --model-name meta-llama/Llama-2-7b-hf \
+    --cache-dir ./model_cache \
     --precision fp16 \
     --batch-sizes "1 2" \
     --sequence-lengths "8 16" \
@@ -339,6 +403,7 @@ CUDA_VISIBLE_DEVICES=4,5,6,7 bash benchmark_70b_model.sh 4 \
     --benchmark-type ort-convert-to-onnx \
     --ort-model-path ./llama2-70b-dis/rank_{}_Llama-2-70b-hf_decoder_merged_model_fp16.onnx \
     --model-name meta-llama/Llama-2-70b-hf \
+    --cache-dir ./model_cache \
     --precision fp16 \
     --device cuda \
     --warmup-runs 5 \
@@ -357,6 +422,7 @@ CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark_all \
     --ort-convert-to-onnx-model-path ./llama2-7b-fp16/Llama-2-7b-hf_decoder_merged_model_fp16.onnx \
     --ort-msft-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \
     --model-name meta-llama/Llama-2-7b-hf \
+    --cache-dir ./model_cache \
     --precision fp16 \
     --batch-sizes "1 2" \
     --sequence-lengths "8 16" \
@@ -366,6 +432,72 @@ CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark_all \
     --timeout 60  # number of minutes before moving to the next benchmark
 ```
 
+### Benchmark E2E
+You can use `benchmark_e2e.py` to benchmark the full end-to-end scenario and automatically store the results in a CSV file. This tool uses `argmax` for sampling to standardize the benchmarking process.
+
+1. PyTorch without `torch.compile`, FP32
+```
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark_e2e \
+    --benchmark-type pt-eager \
+    --model-name meta-llama/Llama-2-7b-hf \
+    --cache-dir ./model_cache \
+    --prompts-file ./models/llama/prompts.json \
+    --precision fp32 \
+    --batch-sizes "1 2" \
+    --prompt-lengths "16 64" \
+    --device cpu \
+    --auth
+```
+
+2. PyTorch with `torch.compile`, FP16
+```
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark_e2e \
+    --benchmark-type pt-compile \
+    --model-name meta-llama/Llama-2-7b-hf \
+    --cache-dir ./model_cache \
+    --prompts-file ./models/llama/prompts.json \
+    --precision fp16 \
+    --batch-sizes "1 2" \
+    --prompt-lengths "16 64" \
+    --device cuda \
+    --auth
+```
+
+3. ONNX Runtime with `convert_to_onnx`, FP32
+```
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark_e2e \
+    --benchmark-type ort \
+    --model-name meta-llama/Llama-2-7b-hf \
+    --cache-dir ./model_cache \
+    --onnx-model-path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp32.onnx \
+    --prompts-file ./models/llama/prompts.json \
+    --precision fp32 \
+    --batch-sizes "1 2" \
+    --prompt-lengths "16 64" \
+    --device cpu \
+    --auth
+```
+
+4. ONNX Runtime with `convert_to_onnx`, FP16
+```
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark_e2e \
+    --benchmark-type ort \
+    --model-name meta-llama/Llama-2-7b-hf \
+    --cache-dir ./model_cache \
+    --onnx-model-path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp32.onnx \
+    --prompts-file ./models/llama/prompts.json \
+    --precision fp16 \
+    --batch-sizes "1 2" \
+    --prompt-lengths "16 64" \
+    --device cuda \
+    --use_buffer_share \
+    --auth
+```
+
+## E2E Inference with LLaMA-2
+
+For end-to-end inference, please visit the [ONNX Runtime Inference Examples folder](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/python/models/llama) for a step-by-step walkthrough, code examples, and performance metrics.
+
 # Mistral
 
 ## Introduction
diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark.py b/onnxruntime/python/tools/transformers/models/llama/benchmark.py
index a53dead77dea6..6184298c471ac 100644
--- a/onnxruntime/python/tools/transformers/models/llama/benchmark.py
+++ b/onnxruntime/python/tools/transformers/models/llama/benchmark.py
@@ -1,3 +1,8 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation.  All rights reserved.
+# Licensed under the MIT License.  See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
 import argparse
 import datetime
 import gc
@@ -14,11 +19,12 @@
 from benchmark_helper import measure_memory, setup_logger
 from dist_settings import get_rank, get_size
 from llama_inputs import (
-    add_io_bindings,
+    add_io_bindings_as_ortvalues,
     get_merged_sample_with_past_kv_inputs,
     get_msft_sample_inputs,
     get_sample_inputs,
     get_sample_with_past_kv_inputs,
+    verify_ort_inputs,
 )
 from optimum.onnxruntime import ORTModelForCausalLM
 from torch.profiler import ProfilerActivity, profile, record_function
@@ -55,11 +61,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
     max_seq_len = (
         2048
         if args.benchmark_type == "ort-msft"
-        else 16384
-        if "codellama" in temp_name
-        else 4096
-        if "llama2" in temp_name
-        else 2048
+        else 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048
     )
 
     if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
@@ -203,6 +205,7 @@ def get_model(args: argparse.Namespace):
             torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
             use_auth_token=args.auth,
             use_cache=True,
+            cache_dir=args.cache_dir,
         ).to(args.target_device)
         end_time = time.time()
 
@@ -243,7 +246,7 @@ def get_model(args: argparse.Namespace):
             decoder_file_name=decoder_file_name,
             decoder_with_past_file_name=decoder_with_past_file_name,
             use_auth_token=args.auth,
-            use_io_binding=(args.device != "cpu"),
+            use_io_binding=True,  # Large perf gain even for cpu due to avoiding output copy.
             use_merged=(True if decoder_file_name == "model.onnx" else None),
             provider=provider,
             provider_options=provider_options,
@@ -278,21 +281,25 @@ def time_fn(args, fn, inputs):
         outputs = fn(inputs)
         logger.info(outputs)
 
-    input_sync = (  # noqa: E731
-        lambda *kwargs: args.io_binding.synchronize_inputs()
+    input_sync = lambda *kwargs: (  # noqa: E731
+        args.io_binding.synchronize_inputs()
         if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}  # ORT synchronize
-        else lambda *kwargs: torch.cuda.synchronize()
-        if args.device != "cpu" and torch.cuda.is_available()  # PyTorch synchronize
-        else lambda *kwargs: None  # no-op function
-    )
+        else (
+            torch.cuda.synchronize()
+            if args.device != "cpu" and torch.cuda.is_available()  # PyTorch synchronize
+            else None  # no-op
+        )
+    )
 
-    output_sync = (  # noqa: E731
-        lambda *kwargs: args.io_binding.synchronize_outputs()
+    output_sync = lambda *kwargs: (  # noqa: E731
+        args.io_binding.synchronize_outputs()
         if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}  # ORT synchronize
-        else lambda *kwargs: torch.cuda.synchronize()
-        if args.device != "cpu" and torch.cuda.is_available()  # PyTorch synchronize
-        else lambda *kwargs: None  # no-op function
-    )
+        else (
+            torch.cuda.synchronize()
+            if args.device != "cpu" and torch.cuda.is_available()  # PyTorch synchronize
+            else None  # no-op
+        )
+    )
 
     for _ in warmup_range:
         input_sync()
@@ -444,24 +451,12 @@ def get_logits(inputs):
 
 def run_ort_inference(args, init_inputs, iter_inputs, model):
     def prepare_ort_inputs(inputs, kv_cache_ortvalues):
-        # Check that all model inputs will be provided
-        model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs()))
-        user_inputs = set(inputs.keys())
-        missing_inputs = model_inputs - user_inputs
-        if len(missing_inputs):
-            logger.error(f"The following model inputs are missing: {missing_inputs}")
-            raise Exception("There are missing inputs to the model. Please add them and try again.")
-
-        # Remove unnecessary inputs from model inputs
-        unnecessary_inputs = user_inputs - model_inputs
-        if len(unnecessary_inputs):
-            for unnecessary_input in unnecessary_inputs:
-                logger.info(f"Removing unnecessary input '{unnecessary_input}' from user provided inputs")
-                del inputs[unnecessary_input]
+        # Verify model inputs
+        inputs = verify_ort_inputs(model, inputs)
 
         # Add IO bindings for non-CPU execution providers
         if args.device != "cpu":
-            io_binding, kv_cache_ortvalues = add_io_bindings(
+            io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues(
                 model, inputs, args.device, int(args.rank), args.use_gqa, kv_cache_ortvalues
             )
             setattr(args, "io_binding", io_binding)  # noqa: B010
@@ -612,6 +607,13 @@ def get_args(rank=0):
     parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
     parser.add_argument("--verbose", default=False, action="store_true")
     parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files")
+    parser.add_argument(
+        "--cache-dir",
+        type=str,
+        required=True,
+        default="./model_cache",
+        help="Cache dir where Hugging Face files are stored",
+    )
 
     args = parser.parse_args()
 
@@ -662,8 +664,8 @@ def main():
 
     args.rank = rank
     args.world_size = world_size
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
-    config = AutoConfig.from_pretrained(args.model_name)
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=args.cache_dir)
+    config = AutoConfig.from_pretrained(args.model_name, cache_dir=args.cache_dir)
     target_device = f"cuda:{args.rank}" if args.device != "cpu" else args.device
     use_fp16 = args.precision == "fp16"
 
diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py
index c6d550d47cf4c..2433ae3d9b5ee 100644
--- a/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py
+++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py
@@ -1,3 +1,8 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation.  All rights reserved.
+# Licensed under the MIT License.  See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
 import argparse
 import datetime
 import json
@@ -78,6 +83,13 @@ def get_args():
         help="Path to ONNX model from convert_to_onnx",
     )
 
+    parser.add_argument(
+        "--cache-dir",
+        type=str,
+        default="./model_cache",
+        help="Cache dir where Hugging Face files are stored",
+    )
+
     parser.add_argument(
         "--model-name",
         type=str,
@@ -332,6 +344,8 @@ def main():
             str(args.num_runs),
             "--log-folder",
             args.log_folder,
+            "--cache-dir",
+            args.cache_dir,
             "--auth",
         ]
         logger.info("Benchmark PyTorch without torch.compile")
@@ -362,6 +376,8 @@ def main():
             str(args.num_runs),
             "--log-folder",
             args.log_folder,
+            "--cache-dir",
+            args.cache_dir,
             "--auth",
         ]
         logger.info("Benchmark PyTorch with torch.compile")
@@ -394,6 +410,8 @@ def main():
             str(args.num_runs),
             "--log-folder",
             args.log_folder,
+            "--cache-dir",
+            args.cache_dir,
             "--auth",
         ]
         logger.info("Benchmark Optimum + ONNX Runtime")
@@ -426,6 +444,8 @@ def main():
             str(args.num_runs),
             "--log-folder",
             args.log_folder,
+            "--cache-dir",
+            args.cache_dir,
         ]
         logger.info("Benchmark Microsoft model in ONNX Runtime")
         results = benchmark(args, benchmark_cmd, "ort-msft")
@@ -457,6 +477,8 @@ def main():
             str(args.num_runs),
             "--log-folder",
             args.log_folder,
+            "--cache-dir",
+            args.cache_dir,
         ]
         logger.info("Benchmark convert_to_onnx model in ONNX Runtime")
         results = benchmark(args, benchmark_cmd, "onnxruntime")
diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
new file mode 100644
index 0000000000000..4d0d2e68e8983
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
@@ -0,0 +1,554 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation.  All rights reserved.
+# Licensed under the MIT License.  See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+# This is an end-to-end benchmarking script for the Hugging Face LLaMA-2 model.
+#
+# Prerequisites:
+# 1) Install `huggingface-cli`:
+#
+# $ pip install huggingface_hub
+#
+# 2) Authenticate with Hugging Face's CLI:
+#
+# $ huggingface-cli login
+#
+# 3) Accept Meta's license in Hugging Face to access the models at https://huggingface.co/meta-llama/
+#
+# 4) Install the latest ONNX Runtime version
+#
+# $ pip install onnxruntime-gpu
+
+from __future__ import annotations
+
+import argparse
+import datetime
+import gc
+import itertools
+import json
+import logging
+import os
+import textwrap
+import time
+
+import numpy as np
+import pandas as pd
+import torch
+from benchmark_helper import setup_logger
+from llama_inputs import add_io_bindings_as_tensors, get_initial_inputs_and_outputs
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+
+import onnxruntime as ort
+
+logger = logging.getLogger(__name__)
+
+
+def get_model(args):
+    if args.benchmark_type in {"pt-eager", "pt-compile"}:
+        model = AutoModelForCausalLM.from_pretrained(
+            args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
+            cache_dir=args.cache_dir,
+            torch_dtype=args.torch_dtype,
+            use_auth_token=args.auth,
+            use_cache=True,
+        ).to(args.target_device)
+        model.eval()
+
+        if args.benchmark_type == "pt-compile":
+            model = torch.compile(model)
+
+    else:
+        sess_options = ort.SessionOptions()
+        ep = (
+            ("CUDAExecutionProvider", {"device_id": args.device_id})
+            if args.device == "cuda"
+            else "CPUExecutionProvider"
+        )
+        model = ort.InferenceSession(args.onnx_model_path, sess_options=sess_options, providers=[ep])
+
+    return model
+
+
+def run_inference(args, model, runs, inputs, outputs):
+    if args.benchmark_type == "pt-compile":
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+    # Synchronize inputs
+    io_binding = None
+    if args.benchmark_type in {"pt-eager", "pt-compile"}:
+        if args.device != "cpu":
+            torch.cuda.synchronize(args.target_device)
+    else:
+        io_binding = add_io_bindings_as_tensors(model, inputs, outputs, args.use_fp16, args.use_buffer_share)
+        io_binding.synchronize_inputs()
+
+    # Run inference
+    start = time.perf_counter()
+    for _ in range(runs):
+        if args.benchmark_type in {"pt-eager", "pt-compile"}:
+            with torch.no_grad():
+                outputs = model(**inputs)
+                if args.device != "cpu":
+                    torch.cuda.synchronize(args.target_device)
+        else:
+            model.run_with_iobinding(io_binding)
+            io_binding.synchronize_outputs()
+
+    end = time.perf_counter()
+    avg = (end - start) / runs
+    return avg, outputs
+
+
+def prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt):
+    clear_cache()
+    inputs, outputs = get_initial_inputs_and_outputs(
+        config, tokenizer, prompt_length, prompt, args.target_device, args.use_fp16, args.use_buffer_share, args.engine
+    )
+    _, outputs = run_inference(args, model, args.warmup_runs, inputs, outputs)
+    return inputs, outputs
+
+
+def clear_cache():
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
+def save_results(results, filename, gen_length):
+    df = pd.DataFrame(
+        results,
+        columns=[
+            "Batch Size",
+            "Prompt Length",
+            "Prompt Processing Latency (ms)",
+            "Prompt Processing Throughput (tps)",
+            "Sampling Latency (ms)",
+            "Sampling Throughput (tps)",
+            "First Token Generated Latency (ms)",
+            "First Token Generated Throughput (tps)",
+            f"Average Latency of First {gen_length // 2} Tokens Generated (ms)",
+            f"Average Throughput of First {gen_length // 2} Tokens Generated (tps)",
+            f"Average Latency of First {gen_length} Tokens Generated (ms)",
+            f"Average Throughput of First {gen_length} Tokens Generated (tps)",
+            "Wall-Clock Latency (s)",
+            "Wall-Clock Throughput (tps)",
+        ],
+    )
+
+    df.to_csv(filename, index=False)
+    logger.info(f"Results saved in {filename}!")
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "-bt",
+        "--benchmark-type",
+        type=str,
+        required=True,
+        choices=["pt-eager", "pt-compile", "ort"],
+    )
+
+    parser.add_argument(
+        "-m",
+        "--model-name",
+        type=str,
+        required=False,
+        help="Hugging Face name of model (e.g. 'meta-llama/Llama-2-7b-hf')",
+    )
+
+    parser.add_argument(
+        "-a",
+        "--auth",
+        default=False,
+        action="store_true",
+        help="Use Hugging Face authentication token to access model",
+    )
+
+    parser.add_argument(
+        "-c",
+        "--cache-dir",
+        type=str,
+        default=os.path.join(".", "model_cache"),
+        help="Path to directory containing all Hugging Face files (e.g. config, tokenizer, PyTorch model). Use when loading model as `AutoModel.from_pretrained(model_name, cache_dir=cache_dir)`.",
+    )
+
+    parser.add_argument(
+        "--hf-dir-path",
+        type=str,
+        default="",
+        help="Path to directory containing all Hugging Face files (e.g. config, tokenizer, PyTorch model). Use when loading model as `AutoModel.from_pretrained(folder_path)`.",
+    )
+
+    parser.add_argument(
+        "-o",
+        "--onnx-model-path",
+        required=False,
+        help="Path to ONNX model",
+    )
+
+    parser.add_argument(
+        "-f",
+        "--prompts-file",
+        required=True,
+        default=os.path.join(".", "models", "llama", "prompts.json"),
+        help="JSON file containing entries in the format 'prompt length: prompt' where prompt length = tokenized length of prompt",
+    )
+
+    parser.add_argument(
+        "--use_buffer_share",
+        default=False,
+        action="store_true",
+        help="Use when GroupQueryAttention (GQA) is in ONNX model",
+    )
+
+    parser.add_argument(
+        "--anomaly-filtering",
+        default=False,
+        action="store_true",
+        help="Use this flag to filter out anomalous accelerator times for generated tokens. \
+              This may give more accurate latency and throughput metrics for generated tokens. \
+              Wall-clock metrics still include the anomalous times.",
+    )
+
+    parser.add_argument(
+        "-b",
+        "--batch-sizes",
+        default="1 2",
+    )
+
+    parser.add_argument(
+        "-s",
+        "--prompt-lengths",
+        default="32 64 128 256 512",
+    )
+
+    parser.add_argument(
+        "-p",
+        "--precision",
+        required=True,
+        type=str,
+        default="fp32",
+        choices=["int4", "int8", "fp16", "fp32"],
+        help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
+    )
+
+    parser.add_argument(
+        "-g",
+        "--generation-length",
+        type=int,
+        default=256,
+        help="Number of new tokens to generate",
+    )
+
+    parser.add_argument(
+        "-d",
+        "--device",
+        type=str,
+        default="cuda" if torch.cuda.is_available() else "cpu",
+        choices=["cpu", "cuda"],
+    )
+
+    parser.add_argument("-id", "--device-id", type=int, default=0)
+    parser.add_argument("-w", "--warmup-runs", type=int, default=5)
+    parser.add_argument("-n", "--num-runs", type=int, default=100)
+    parser.add_argument("--seed", type=int, default=2)
+
+    args = parser.parse_args()
+
+    # Set seed properties
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+
+    # Set runtime properties
+    if "ort" in args.benchmark_type:
+        setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider")  # noqa: B010
+        if args.execution_provider == "CUDAExecutionProvider":
+            args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
+
+    # Check that paths have been specified for any benchmarking with ORT
+    if args.benchmark_type == "ort":
+        assert args.onnx_model_path, "Please specify a path to `--onnx-model-path`"
+
+    args.batch_sizes = args.batch_sizes.split(" ")
+    args.prompt_lengths = args.prompt_lengths.split(" ")
+
+    # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
+    args.precision = (
+        "fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16"
+    )
+
+    target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device
+    torch_dtype = torch.float16 if args.precision == "fp16" else torch.float32
+    engine = "ort" if args.benchmark_type == "ort" else "pt"
+    setattr(args, "target_device", target_device)  # noqa: B010
+    setattr(args, "torch_dtype", torch_dtype)  # noqa: B010
+    setattr(args, "engine", engine)  # noqa: B010
+    setattr(args, "use_fp16", args.precision == "fp16")  # noqa: B010
+
+    return args
+
+
+def main():
+    args = get_args()
+    setup_logger(False)
+    logger.info(args.__dict__)
+
+    # Get prompts and prompt sizes
+    size_to_prompt = None
+    with open(args.prompts_file) as f:
+        size_to_prompt = json.load(f, object_hook=lambda d: {int(k): v for k, v in d.items()})
+
+    # Get config, tokenizer, and model
+    config = AutoConfig.from_pretrained(
+        args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
+        cache_dir=args.cache_dir,
+        use_auth_token=args.auth,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
+        cache_dir=args.cache_dir,
+        use_auth_token=args.auth,
+    )
+    model = get_model(args)
+
+    all_csv_metrics = []
+    for batch_size, prompt_length in itertools.product(args.batch_sizes, args.prompt_lengths):
+        batch_size, prompt_length = int(batch_size), int(prompt_length)  # noqa: PLW2901
+        logger.info(f"Running batch size = {batch_size}, prompt length = {prompt_length}")
+        clear_cache()
+        max_length = prompt_length + args.generation_length
+
+        if prompt_length not in size_to_prompt:
+            raise NotImplementedError(
+                textwrap.dedent(
+                    f"""
+                                A prompt of size {prompt_length} was not found in '{args.prompts_file}'. There are a couple of solutions to fix this.
+                                1) You can change one of the keys in '{args.prompts_file}' to be {prompt_length}.
+                                    If {prompt_length} < actual prompt's length, the benchmark E2E tool will repeat the first word in the prompt until {prompt_length} = actual prompt's length.
+                                    If {prompt_length} > actual prompt's length, the benchmark E2E tool will automatically trim the actual prompt's length so that {prompt_length} = actual prompt's length.
+                                2) You can add a new key-value entry in '{args.prompts_file}' of the form '{prompt_length}': 'your prompt goes here'.
+                """
+                )
+            )
+        prompt = [size_to_prompt[prompt_length]] * batch_size
+        csv_metrics = [batch_size, prompt_length]
+
+        try:
+            # Measure prompt processing
+            logger.info("Measuring prompt processing...")
+            inputs, outputs = prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt)
+            accelerator_prompt_latency_s, outputs = run_inference(args, model, args.num_runs, inputs, outputs)
+
+            # Calculate prompt metrics
+            accelerator_prompt_latency_ms = accelerator_prompt_latency_s * 1000
+            accelerator_prompt_thrpt = batch_size * (prompt_length / accelerator_prompt_latency_s)
+            logger.info(f"Average Latency of Prompt Processing: {accelerator_prompt_latency_ms} ms")
+            logger.info(
+                f"Average Throughput of Prompt Processing: {batch_size * (prompt_length / accelerator_prompt_latency_s)} tps"
+            )
+            csv_metrics.extend([accelerator_prompt_latency_ms, accelerator_prompt_thrpt])
+
+            # Measure token generation
+            logger.info("Measuring token generation...")
+            clear_cache()
+            inputs, outputs = prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt)
+
+            all_token_ids = inputs["input_ids"].clone()
+            current_length = all_token_ids.shape[-1]
+            num_heads = config.num_key_value_heads
+            head_size = (
+                config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
+            )
+
+            has_eos = torch.zeros(batch_size, device=args.target_device, dtype=torch.bool)
+
+            # 0th entry will have prompt accelerator time, 1st entry onwards will have token generation accelerator time
+            accelerator_times = []
+            sampling_times = []  # cost to sample after each model run
+
+            wall_clock_start_time = time.perf_counter()
+            while current_length <= max_length:
+                # Run inference
+                accelerator_time_latency_s, outputs = run_inference(args, model, 1, inputs, outputs)
+                accelerator_times.append(accelerator_time_latency_s)
+
+                # Sample with argmax (greedy search)
+                sampling_start_time = time.perf_counter()
+                if outputs["logits"].shape[1] > 1:
+                    prompt_end_indices = inputs["attention_mask"].sum(1) - 1
+                    idxs = (
+                        prompt_end_indices.unsqueeze(dim=1)
+                        .repeat(1, config.vocab_size)
+                        .view(batch_size, 1, config.vocab_size)
+                    )
+                    next_token_logits = torch.gather(outputs["logits"], 1, idxs).squeeze()
+                else:
+                    next_token_logits = outputs["logits"][:, -1, :]
+                next_tokens = torch.argmax(next_token_logits, dim=-1)
+
+                # Check if we previously reached EOS token id or if generated token id is EOS token id
+                has_eos = has_eos | (next_tokens == tokenizer.eos_token_id)
+
+                # Determine which new tokens to add to list of all token ids
+                # Add EOS token ids for batch entries that ended early (ragged batching scenario where some batch entries ended early and some haven't)
+                tokens_to_add = next_tokens.masked_fill(has_eos, tokenizer.eos_token_id).reshape([batch_size, 1])
+                sampling_end_time = time.perf_counter()
+                sampling_times.append(sampling_end_time - sampling_start_time)
+
+                all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1)
+
+                # Return early if all batch entries have reached EOS token id
+                current_length += 1
+                if torch.all(has_eos) or current_length > max_length:
+                    break
+
+                # Update inputs for next inference run
+                inputs["input_ids"] = tokens_to_add
+                inputs["attention_mask"] = torch.cat(
+                    [inputs["attention_mask"], (~has_eos).to(torch.int64).reshape(batch_size, 1)], 1
+                )
+                inputs["position_ids"] = (
+                    None
+                    if "position_ids" not in inputs
+                    else torch.max(inputs["position_ids"], dim=1)[0].reshape(batch_size, 1) + 1
+                )
+
+                # Set logits to zeros for next inference run and re-use memory buffer
+                if outputs["logits"].shape[1] != 1:
+                    outputs["logits"] = outputs["logits"][:, :1, :].contiguous()
+                outputs["logits"].zero_()
+
+                # Update KV caches for next inference run
+                if args.engine == "pt":
+                    # Update KV caches for PyTorch
+                    inputs["past_key_values"] = outputs["past_key_values"]
+                elif not args.use_buffer_share:
+                    # Update KV caches for ONNX Runtime if buffer sharing is not used
+                    for i in range(config.num_hidden_layers):
+                        inputs[f"past_key_values.{i}.key"] = outputs[f"present.{i}.key"]
+                        inputs[f"past_key_values.{i}.value"] = outputs[f"present.{i}.value"]
+
+                    new_sequence_length = inputs["attention_mask"].shape[1]
+                    for i in range(config.num_hidden_layers):
+                        present_key = torch.zeros(
+                            batch_size,
+                            num_heads,
+                            new_sequence_length,
+                            head_size,
+                            device=args.target_device,
+                            dtype=args.torch_dtype,
+                        )
+                        present_value = torch.zeros(
+                            batch_size,
+                            num_heads,
+                            new_sequence_length,
+                            head_size,
+                            device=args.target_device,
+                            dtype=args.torch_dtype,
+                        )
+                        outputs.update(
+                            {
+                                f"present.{i}.key": present_key.contiguous(),
+                                f"present.{i}.value": present_value.contiguous(),
+                            }
+                        )
+
+            wall_clock_end_time = time.perf_counter()
+
+            # Filter out any anomaly accelerator times (e.g. for `torch.compile`)
+            accelerator_times.pop(0)  # Remove prompt processing time
+            if args.anomaly_filtering:
+                anomaly_threshold_factor = 10
+                min_time_s = min(accelerator_times)
+                orig_size = len(accelerator_times)
+                accelerator_times = list(
+                    filter(lambda acc_time: acc_time < anomaly_threshold_factor * min_time_s, accelerator_times)
+                )
+                new_size = len(accelerator_times)
+                logger.info(
+                    f"Filtered out {orig_size - new_size} anomaly accelerator times that are {anomaly_threshold_factor}x greater than {min_time_s * 1000} ms..."
+                )
+
+            #######################################################
+            # Calculate sampling and first token generated metrics
+            #######################################################
+
+            # Calculate sampling metrics
+            avg_sampling_latency_s = sum(sampling_times) / len(sampling_times)
+            avg_sampling_latency_ms = avg_sampling_latency_s * 1000
+            avg_sampling_thrpt = batch_size * (1 / avg_sampling_latency_s)
+            logger.info(f"Average Latency of Sampling: {avg_sampling_latency_ms} ms")
+            logger.info(f"Average Throughput of Sampling: {avg_sampling_thrpt} tps")
+
+            # Calculate first token generated metrics
+            first_token_latency_s = accelerator_times[0]
+            first_token_latency_ms = first_token_latency_s * 1000
+            first_token_thrpt = batch_size * (1 / first_token_latency_s)
+            logger.info(f"Latency of First Token Generated: {first_token_latency_ms} ms")
+            logger.info(f"Throughput of First Token Generated: {first_token_thrpt} tps")
+
+            ####################################################
+            # Calculate first `halfway` token generated metrics
+            ####################################################
+
+            halfway = args.generation_length // 2
+            halfway_token_latency_s = sum(accelerator_times[:halfway]) / len(accelerator_times[:halfway])
+            halfway_token_latency_ms = halfway_token_latency_s * 1000
+            halfway_token_thrpt = batch_size * (1 / halfway_token_latency_s)
+            logger.info(f"Average Latency of First {halfway} Tokens Generated: {halfway_token_latency_ms} ms")
+            logger.info(f"Average Throughput of First {halfway} Tokens Generated: {halfway_token_thrpt} tps")
+
+            #########################################
+            # Calculate all tokens generated metrics
+            #########################################
+
+            all_token_latency_s = sum(accelerator_times) / len(accelerator_times)
+            all_token_latency_ms = all_token_latency_s * 1000
+            all_token_thrpt = batch_size * (1 / all_token_latency_s)
+            logger.info(
+                f"Average Latency of First {args.generation_length} Tokens Generated: {all_token_latency_ms} ms"
+            )
+            logger.info(f"Average Throughput of First {args.generation_length} Tokens Generated: {all_token_thrpt} tps")
+
+            ###############################
+            # Calculate wall clock metrics
+            ###############################
+
+            wall_clock_latency_s = wall_clock_end_time - wall_clock_start_time
+            wall_clock_thrpt = batch_size * ((prompt_length + args.generation_length) / wall_clock_latency_s)
+            logger.info(f"Wall-Clock Latency: {wall_clock_latency_s} s")
+            logger.info(
+                f"Wall-Clock Throughput: {batch_size * ((prompt_length + args.generation_length) / wall_clock_latency_s)} tps"
+            )
+
+            # Add metrics to CSV
+            logger.info("Adding results to CSV")
+            csv_metrics.extend(
+                [
+                    avg_sampling_latency_ms,
+                    avg_sampling_thrpt,
+                    first_token_latency_ms,
+                    first_token_thrpt,
+                    halfway_token_latency_ms,
+                    halfway_token_thrpt,
+                    all_token_latency_ms,
+                    all_token_thrpt,
+                    wall_clock_latency_s,
+                    wall_clock_thrpt,
+                ]
+            )
+            all_csv_metrics.append(csv_metrics)
+
+        except:  # noqa: E722
+            logger.info(f"Could not benchmark at batch size = {batch_size}, prompt length = {prompt_length}")
+
+    filename = f"benchmark_{args.engine}_e2e_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.csv"
+    save_results(all_csv_metrics, filename, args.generation_length)
+
+
+if __name__ == "__main__":
+    main()
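
An illustrative sketch (not part of the diff) of the `--prompts-file` format consumed by `benchmark_e2e.py` above: a JSON object mapping tokenized prompt lengths to prompt strings, with keys converted back to integers by the `object_hook` in `main()`. The prompt texts are placeholders; per the error message in `main()`, prompts are repeated or trimmed so that their tokenized length matches the key.

```
# Sketch only: write a prompts.json usable with --prompts-file (contents are placeholders).
import json

size_to_prompt = {
    "16": "Briefly explain the difference between latency and throughput.",
    "64": "Explain how key-value caching speeds up autoregressive decoding in "
          "transformer language models, and why buffer sharing reduces memory copies.",
}

with open("prompts.json", "w") as f:
    json.dump(size_to_prompt, f, indent=2)
```

The file can then be passed as `--prompts-file ./prompts.json` together with matching `--prompt-lengths "16 64"`.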
diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
index c9ff384a4c856..b649f7ab65049 100644
--- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
@@ -1,3 +1,8 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation.  All rights reserved.
+# Licensed under the MIT License.  See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
 from __future__ import annotations
 
 import argparse
@@ -944,9 +949,11 @@ def main():
                             ort_quantization.quantize_dynamic(
                                 fp32_path,
                                 int8_path,
-                                op_types_to_quantize=["MatMul", "Gemm", "Gather"]
-                                if args.quantize_embedding_layer
-                                else ["MatMul", "Gemm"],
+                                op_types_to_quantize=(
+                                    ["MatMul", "Gemm", "Gather"]
+                                    if args.quantize_embedding_layer
+                                    else ["MatMul", "Gemm"]
+                                ),
                                 per_channel=args.quantize_per_channel,
                                 reduce_range=args.quantize_reduce_range,
                                 use_external_data_format=True,
diff --git a/onnxruntime/python/tools/transformers/models/llama/dist_settings.py b/onnxruntime/python/tools/transformers/models/llama/dist_settings.py
index 72192ce8d8c63..3b53f60758b27 100644
--- a/onnxruntime/python/tools/transformers/models/llama/dist_settings.py
+++ b/onnxruntime/python/tools/transformers/models/llama/dist_settings.py
@@ -1,3 +1,8 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation.  All rights reserved.
+# Licensed under the MIT License.  See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
 import os
 
 import torch.distributed as dist
diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
index a329b73259dda..5aed55c12f38f 100644
--- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
+++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
@@ -1,8 +1,13 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation.  All rights reserved.
+# Licensed under the MIT License.  See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
 from __future__ import annotations
 
 import numpy as np
 import torch
-from transformers import AutoConfig
+from transformers import AutoConfig, AutoTokenizer
 
 from onnxruntime import InferenceSession, OrtValue
 
@@ -222,7 +227,8 @@ def get_msft_sample_inputs(
 # Create past_key_values
 # Each is of shape (batch_size, num_heads, past_sequence_length, head_size)
 def get_past_kv_inputs(config: AutoConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1):
-    num_heads, head_size = config.num_key_value_heads // world_size, config.hidden_size // config.num_attention_heads
+    num_heads = config.num_key_value_heads // world_size
+    head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
     torch_dtype = torch.float16 if use_fp16 else torch.float32
     past_kv = [
         (
@@ -268,6 +274,8 @@ def convert_inputs_for_ort(
     return ort_inputs
 
 
+# Re-allocate KV caches from (batch_size, num_heads, past_sequence_length, head_size) to
+# (batch_size, num_heads, max_sequence_length, head_size) for past-present buffer sharing
 def enable_past_present_share_buffer(ort_inputs: dict, past_seq_len: int, max_seq_len: int):
     for k, v in ort_inputs.items():
         # Allocate new buffers with max_sequence_length for GQA
@@ -280,13 +288,41 @@ def enable_past_present_share_buffer(ort_inputs: dict, past_seq_len: int, max_se
     return ort_inputs
 
 
-# Add IO bindings for execution providers
-def add_io_bindings(
+# Verify ONNX Runtime inputs with model
+def verify_ort_inputs(model: InferenceSession, ort_inputs: dict):
+    # Check that all model inputs will be provided
+    model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs()))
+    user_inputs = set(ort_inputs.keys())
+    missing_inputs = model_inputs - user_inputs
+    if len(missing_inputs):
+        print(f"The following model inputs are missing: {missing_inputs}")
+        raise Exception("There are missing inputs to the model. Please add them and try again.")
+
+    # Remove unnecessary inputs from model inputs
+    unnecessary_inputs = user_inputs - model_inputs
+    if len(unnecessary_inputs):
+        for unnecessary_input in unnecessary_inputs:
+            print(f"Removing unnecessary input '{unnecessary_input}' from user provided inputs")
+            del ort_inputs[unnecessary_input]
+
+    return ort_inputs
+
+
+# Add IO bindings for execution providers using OrtValue
+# Use when you need to run inference once or twice to save memory
+def add_io_bindings_as_ortvalues(
     model: InferenceSession, ort_inputs: dict, device: str, device_id: int, use_gqa: bool, kv_cache_ortvalues: dict
 ):
     io_binding = model.io_binding()
 
+    model_inputs = set(map(lambda i: i.name, model.get_inputs()))
     for k, v in ort_inputs.items():
+        # Use this check to handle scenarios such as INT4 CUDA and FP16 CUDA models with
+        # GQA + RotaryEmbedding fusion where `position_ids` is removed as an ONNX model input
+        # but `position_ids` is used as a PyTorch model input
+        if k not in model_inputs:
+            continue
+
         # Bind OrtValue inputs to device
         if use_gqa and ("cache" in k or "past_key_values" in k):
             if k not in kv_cache_ortvalues:
@@ -310,3 +346,163 @@ def add_io_bindings(
             io_binding.bind_output(name, device_type=device, device_id=device_id)
 
     return io_binding, kv_cache_ortvalues
+
+
+# Add IO bindings for execution providers using PyTorch tensors
+# Use when you need to run inference many times
+def add_io_bindings_as_tensors(
+    model: InferenceSession, inputs: dict, outputs: dict, use_fp16: bool, use_buffer_share: bool
+):
+    # Verify model inputs
+    inputs = verify_ort_inputs(model, inputs)
+
+    device = None
+    pt_to_np = {
+        "torch.int32": np.int32,
+        "torch.int64": np.int64,
+        "torch.float16": np.float16,
+        "torch.float32": np.float32,
+    }
+
+    # Bind inputs/outputs to IO binding
+    io_binding = model.io_binding()
+    for k, v in inputs.items():
+        io_binding.bind_input(
+            name=k,
+            device_type=v.device.type,
+            device_id=0 if v.device.type == "cpu" else v.device.index,
+            element_type=pt_to_np[repr(v.dtype)],
+            shape=tuple(v.shape),
+            buffer_ptr=v.data_ptr(),
+        )
+        device = v.device
+
+    for output in model.get_outputs():
+        name = output.name
+        if use_buffer_share and "present" in name:
+            # Bind KV cache outputs to KV cache inputs
+            v = inputs[name.replace("present", "past_key_values")]
+            io_binding.bind_output(
+                name=name,
+                device_type=v.device.type,
+                device_id=v.device.index,
+                element_type=np.float16,
+                shape=tuple(v.shape),
+                buffer_ptr=v.data_ptr(),
+            )
+        else:
+            v = outputs[name]
+            io_binding.bind_output(
+                name=name,
+                device_type=device.type,
+                device_id=0 if device.type == "cpu" else device.index,
+                element_type=(np.float16 if use_fp16 else np.float32),
+                shape=tuple(v.shape),
+                buffer_ptr=v.data_ptr(),
+            )
+
+    return io_binding
+
+
+# Get actual inputs when using real data (instead of sample data) and initialize outputs
+def get_initial_inputs_and_outputs(
+    config: AutoConfig,
+    tokenizer: AutoTokenizer,
+    requested_length: int,
+    prompt: list[str],
+    device: torch.device,
+    use_fp16: bool,
+    use_buffer_share: bool,
+    engine: str,
+):
+    tokenizer.pad_token = "[PAD]"
+    encodings_dict = tokenizer.batch_encode_plus(prompt, padding=True)
+    torch_dtype = torch.float16 if use_fp16 else torch.float32
+
+    # input_ids:      pad token id is 0
+    # attention_mask: pad token id is 0
+    # position_ids:   pad token id is 1
+    input_ids = torch.tensor(encodings_dict["input_ids"], device=device, dtype=torch.int64)
+    attention_mask = torch.tensor(encodings_dict["attention_mask"], device=device, dtype=torch.int64)
+    position_ids = get_position_ids(attention_mask, use_past_kv=False)
+
+    # Check if tokenized prompt length matches the requested prompt length
+    tokenized_length = input_ids.shape[-1]
+    if tokenized_length > requested_length:
+        # Shorten the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
+        input_ids = input_ids[:, :requested_length]
+        attention_mask = attention_mask[:, :requested_length]
+        position_ids = get_position_ids(attention_mask, use_past_kv=False)
+    elif tokenized_length < requested_length:
+        # Lengthen the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
+        input_ids_first_col = input_ids[:, 0].unsqueeze(0).T
+        attention_mask_first_col = attention_mask[:, 0].unsqueeze(0).T
+        for _ in range(requested_length - tokenized_length):
+            input_ids = torch.hstack((input_ids_first_col, input_ids))
+            attention_mask = torch.hstack((attention_mask_first_col, attention_mask))
+        position_ids = get_position_ids(attention_mask, use_past_kv=False)
+
+    tokenized_length = input_ids.shape[-1]
+    assert tokenized_length == requested_length
+
+    # Create inputs
+    inputs = {
+        "input_ids": input_ids.contiguous() if engine == "ort" else input_ids,
+        "attention_mask": attention_mask.contiguous() if engine == "ort" else attention_mask,
+        "position_ids": position_ids.contiguous() if engine == "ort" else position_ids,
+    }
+    if engine != "ort":
+        inputs["past_key_values"] = []
+
+    # Get shape of KV cache inputs
+    batch_size, sequence_length = input_ids.shape
+    max_sequence_length = config.max_position_embeddings
+    num_heads = config.num_key_value_heads
+    head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
+
+    # Create KV cache inputs
+    for i in range(config.num_hidden_layers):
+        past_key = torch.zeros(
+            batch_size,
+            num_heads,
+            max_sequence_length if use_buffer_share else 0,
+            head_size,
+            device=device,
+            dtype=torch_dtype,
+        )
+        past_value = torch.zeros(
+            batch_size,
+            num_heads,
+            max_sequence_length if use_buffer_share else 0,
+            head_size,
+            device=device,
+            dtype=torch_dtype,
+        )
+        if engine == "ort":
+            inputs.update(
+                {
+                    f"past_key_values.{i}.key": past_key.contiguous(),
+                    f"past_key_values.{i}.value": past_value.contiguous(),
+                }
+            )
+        else:
+            inputs["past_key_values"].append((past_key, past_value))
+
+    outputs = None
+    if engine == "ort":
+        # Create outputs
+        logits = torch.zeros(batch_size, sequence_length, config.vocab_size, device=device, dtype=torch_dtype)
+        outputs = {"logits": logits.contiguous()}
+        if not use_buffer_share:
+            for i in range(config.num_hidden_layers):
+                present_key = torch.zeros(
+                    batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype
+                )
+                present_value = torch.zeros(
+                    batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype
+                )
+                outputs.update(
+                    {f"present.{i}.key": present_key.contiguous(), f"present.{i}.value": present_value.contiguous()}
+                )
+
+    return inputs, outputs
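
Taken together, the additions above give llama_inputs.py a tensor-based IO-binding path intended for repeated decoding runs: get_initial_inputs_and_outputs builds padded prompt tensors plus KV-cache and logits buffers, and add_io_bindings_as_tensors binds them directly to the session. A rough usage sketch under assumed arguments (the model id, ONNX path, prompt, and parameter values are placeholders, not taken from the PR):

# Illustrative sketch only; paths, model id, and parameters are placeholders.
import torch
from transformers import AutoConfig, AutoTokenizer
from onnxruntime import InferenceSession
from llama_inputs import add_io_bindings_as_tensors, get_initial_inputs_and_outputs

model_id = "meta-llama/Llama-2-7b-hf"   # placeholder
onnx_path = "llama2-7b-fp16-gqa.onnx"   # placeholder
device = torch.device("cuda:0")

config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
session = InferenceSession(onnx_path, providers=["CUDAExecutionProvider"])

# Build padded prompt inputs plus max-length KV caches and a logits buffer.
inputs, outputs = get_initial_inputs_and_outputs(
    config,
    tokenizer,
    requested_length=256,
    prompt=["What is ONNX Runtime?"],
    device=device,
    use_fp16=True,
    use_buffer_share=True,
    engine="ort",
)

# Bind the PyTorch tensors directly; with buffer sharing, present KV outputs
# reuse the past KV input buffers, so only logits needs a separate output tensor.
io_binding = add_io_bindings_as_tensors(session, inputs, outputs, use_fp16=True, use_buffer_share=True)
session.run_with_iobinding(io_binding)
logits = outputs["logits"]
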
diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
index f41a90208c51b..9cbc9af7fe9b5 100644
--- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
+++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
@@ -1,3 +1,8 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation.  All rights reserved.
+# Licensed under the MIT License.  See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
 from __future__ import annotations
 
 import argparse
@@ -10,7 +15,7 @@
 from benchmark_helper import setup_logger
 from dist_settings import get_rank, get_size
 from llama_inputs import (
-    add_io_bindings,
+    add_io_bindings_as_ortvalues,
     convert_inputs_for_ort,
     get_merged_sample_with_past_kv_inputs,
     get_sample_inputs,
@@ -123,7 +128,7 @@ def verify_parity(
 
     # Add IO bindings for non-CPU execution providers
     if args.execution_provider != "cpu":
-        io_binding, kv_cache_ortvalues = add_io_bindings(
+        io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues(
             ort_model,
             inputs,
             args.execution_provider,
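
The rename above is mechanical, but for context, the OrtValue-based helper is driven the same way as before: it returns an IO binding plus the KV-cache OrtValues to carry into the next call. A hedged sketch of that call pattern (the wrapper function below is illustrative, not code from the PR, and the device arguments are placeholders):

# Illustrative only: one inference pass over an OrtValue-backed binding.
from onnxruntime import InferenceSession
from llama_inputs import add_io_bindings_as_ortvalues

def run_once_with_ortvalue_bindings(ort_model: InferenceSession, inputs: dict, kv_cache_ortvalues: dict):
    io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues(
        ort_model,
        inputs,
        "cuda",        # device type (placeholder)
        0,             # device id (placeholder)
        use_gqa=True,  # reuse KV-cache OrtValues across iterations for GQA models
        kv_cache_ortvalues=kv_cache_ortvalues,
    )
    io_binding.synchronize_inputs()
    ort_model.run_with_iobinding(io_binding)
    io_binding.synchronize_outputs()
    return io_binding.copy_outputs_to_cpu(), kv_cache_ortvalues
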
diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_torch.py b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py
index 89b459c80beec..d570e2d7ee086 100644
--- a/onnxruntime/python/tools/transformers/models/llama/llama_torch.py
+++ b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py
@@ -1,3 +1,8 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation.  All rights reserved.
+# Licensed under the MIT License.  See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
 import logging
 import os
 
diff --git a/onnxruntime/python/tools/transformers/models/llama/prompts.json b/onnxruntime/python/tools/transformers/models/llama/prompts.json
new file mode 100644
index 0000000000000..5d8fae99dbc7e
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/models/llama/prompts.json
@@ -0,0 +1,11 @@
+{
+    "16": "How are astronauts launched into space quickly on those rockets? ",
+    "64": "Today, we will learn how to bake a chocolate cake. First, you need to have all of the ingredients to bake. Otherwise, the chocolate cake won't be tasty. You will also need a large baking pan to hold the batter. ",
+    "256": "Risk Management and Insurance (RMI) is a field that focuses on the identification, assessment and financial mitigation of risk. It's about insurance but also more than that. For example, insurance companies look at risk factors such as age, gender and medical history to determine how much they will charge for life insurance coverage. However, RMI is not just about buying insurance (although it is a big part of this). It is also about taking steps to reduce the likelihood that something bad happens in the first place. For example, you may think twice before crossing a busy road if there's a high risk of being hit by a car or getting injured. In addition to insurance companies and financial services firms, RMI professionals work with individuals (customers), businesses and other entities (clients). Their job is to identify potential risks and help mitigate them before they become problems for their clients. This can include helping people prepare financially for unexpected events like losing a job or being injured in an accident, as well as assisting businesses with managing risk exposure from things like natural disasters or cyber attacks. Insurance companies use RMI to ",
+    "1024": "Risk Management and Insurance (RMI) is a field that focuses on the identification, assessment and financial mitigation of risk. It's about insurance but also more than that. For example, insurance companies look at risk factors such as age, gender and medical history to determine how much they will charge for life insurance coverage. However, RMI is not just about buying insurance (although it is a big part of this). It is also about taking steps to reduce the likelihood that something bad happens in the first place. For example, you may think twice before crossing a busy road if there's a high risk of being hit by a car or getting injured. In addition to insurance companies and financial services firms, RMI professionals work with individuals (customers), businesses and other entities (clients). Their job is to identify potential risks and help mitigate them before they become problems for their clients. This can include helping people prepare financially for unexpected events like losing a job or being injured in an accident, as well as assisting businesses with managing risk exposure from things like natural disasters or cyber attacks. Insurance companies use RMI to assess the level of risk associated with potential customers and determine how much they should charge them for coverage. For example, if you are a healthy 25-year old male who doesn't smoke and has never been in an accident, your insurance premiums will likely be lower than those of someone else who fits into one or more of these categories (or all three). Risk Management & Insurance is the process by which you can protect yourself from financial loss. It's about taking control of your money and making sure that it's safe, secure and accessible to you when you need it most. The first step in risk management is understanding what risks are important to you as an individual or a family member who may depend on the income generated by these investments for their livelihood. Once you have identified these key risk factors, then we can help identify how best to manage them through various strategies such as setting up automatic payments into savings accounts so that money is always available when needed most; setting aside emergency funds in case something unexpected happens (e.g., illness); investing wisely so that returns outpace inflation over time; diversifying portfolios by adding stocks and bonds which will help reduce volatility while still providing growth potential through dividends/interest payments over longer periods of time than if invested solely into one type of asset class alone etc. The field of risk management and insurance is growing rapidly, as more people become aware of the potential dangers that can arise from an unforeseen event or accident. As a result, there are many different careers within this field that you may want to consider if you're interested in working with risks and helping others protect themselves from them.One common career path in risk management is as an insurance agent/broker. This person would work for an insurance company or brokerage firm, selling policies to clients who need coverage against things like car accidents or home damage caused by natural disasters such as fires or floods. Insurance agents typically work on commission (i.e., they receive a percentage of every sale). This is important because it means that the more successful an agent is at selling policies, the higher his/her income will be. 
Another career option within risk management is working for an insurance company itself rather than as an external broker or salesperson. In this case, you'd help manage claims made by policyholders who have been injured through no fault of their own (for example after being hit by another driver). You can also work in risk analysis, a field that involves analyzing the potential risks associated with various investments and projects. This is done to determine whether or not an opportunity has enough upside to justify taking on any related risks. In addition, you might also be responsible for developing strategies to minimize those risks so they don't result in big losses if something goes wrong down the road. If your goal is to work as a broker or agent, then there are some prerequisites that will need to be met before beginning this career path: You must have an associate's degree from an accredited college; pass an exam administered by state regulators (the Series 6) and/or complete additional training offered by professional organizations such as NAFA, which stands for National Association of Financial Advisors. After meeting these requirements, you'll then need to find employment at one or more insurance companies where they offer positions that allow new hires some flexibility when starting out their careers.Risk management and insurance is a broad field that includes many different types of jobs. ",
+    "2048": "Artificial Intelligence (AI) is a transformative technology that has the potential to revolutionize society in many ways. AI can be used to enhance the accuracy and efficiency of decision-making, improve lives through new apps and services, and solve some of the thorny policy problems of climate change, infrastructure, and healthcare. In this essay, I will discuss some of the ways AI can benefit society. One of the most significant benefits of AI is its ability to improve healthcare. AI can assist doctors, nurses, and other healthcare professionals in making better diagnoses and faster decisions on a course of treatment, based on the large amount of data that currently exists. AI allows doctors to pinpoint effective drugs that may have otherwise been overlooked and can identify higher-risk individuals before any human can. AI can also help relieve the burden on healthcare professionals by taking care of routine data collection and filing, freeing up time for other higher-value activities. Another area where AI can benefit society is in the fight against climate change. AI can be used to analyze vast amounts of data, identify patterns, and provide accurate predictions. It can help us forecast what further spread of pandemics is going to look like, and track their development around the world. AI can also help us predict the impact of climate change on our planet and develop strategies to mitigate its effects. For example, AI can be used to optimize energy consumption, reduce waste, and improve the efficiency of transportation systems. AI can also benefit society by improving education. AI-powered educational tools can help students learn more effectively by providing personalized learning experiences tailored to their individual needs. AI can also help teachers by automating routine tasks such as grading and providing feedback on student work. This can free up time for teachers to focus on more important tasks such as lesson planning and student engagement. AI can also benefit society by improving public safety. AI-powered surveillance systems can help law enforcement agencies detect and prevent crime more effectively. AI can also be used to analyze social media data to identify potential threats and prevent them before they occur. For example, AI can be used to detect hate speech and other forms of online harassment, which can help prevent cyberbullying and other forms of online abuse. Finally, AI can benefit society by improving the economy. AI can help businesses become more efficient by automating routine tasks and providing insights into customer behavior. This can help businesses make better decisions and improve their bottom line. AI can also help create new jobs by enabling the development of new products and services that were previously impossible. In conclusion, AI has the potential to benefit society in many ways. From improving healthcare and education to fighting climate change and improving public safety, AI can help us solve some of the most pressing problems facing our world today. As we continue to develop and refine this transformative technology, it is important that we do so in an ethical and responsible manner, ensuring that the benefits of AI are shared by all members of society. AI has been a topic of discussion for many years, and while it has brought many benefits to society, there are also concerns about its impact. In this essay, I will discuss some of the reasons why AI may not help society. Firstly, AI can be biased. 
AI systems are designed by humans, and they can be infused with the biases of their creators. This can lead to discrimination against certain groups of people and can perpetuate existing inequalities in society. Additionally, AI can lack transparency, making it difficult to understand how decisions are being made. This can lead to mistrust of AI systems and can hinder their adoption. Secondly, AI can be used to automate jobs, which can lead to unemployment. While AI can increase productivity and efficiency, it can also lead to job displacement, particularly in industries that rely heavily on manual labor. This can have a negative impact on individuals and communities, particularly those that are already marginalized. Thirdly, AI can be used to create fake content, such as deepfakes, which can be used to spread misinformation and propaganda. This can have serious consequences for democracy and can undermine trust in institutions. Fourthly, AI can be used to create autonomous weapons, which can have devastating consequences. These weapons can make decisions without human intervention, which can lead to unintended consequences and can be difficult to control. Fifthly, AI can be used to create surveillance systems that infringe on privacy rights. These systems can be used to monitor individuals without their knowledge or consent, which can have serious consequences for civil liberties. In conclusion, while AI has many potential benefits, there are also concerns about its impact on society. It is important to consider these concerns and to ensure that AI is developed and used in a responsible and ethical manner. Within AI, there are also many subfields. Reinforcement learning is a type of machine learning algorithm that focuses on training models to make decisions in an environment in order to maximize a reward. This is typically done through trial and error, as the algorithm receives feedback in the form of rewards or punishments for its actions. Reinforcement learning has many potential benefits for society, some of which are discussed below. Firstly, reinforcement learning can be used to improve industrial automation and robotics. By training robots to learn from their own experiences, they can gain the skills necessary to perform complex tasks without human intervention. This can lead to increased efficiency and productivity in industries such as manufacturing and logistics. Secondly, reinforcement learning can be used to optimize traffic control systems. By training models to make real-time decisions based on traffic patterns and other data, traffic flow can be improved, reducing congestion and travel times. Thirdly, reinforcement learning can be used to improve healthcare. By training models to make decisions based on patient data, doctors can make more accurate diagnoses and develop more effective treatment plans. This can lead to better health outcomes for patients and can reduce healthcare costs. Fourthly, reinforcement learning can be used to improve education. By training models to adapt to individual student needs, personalized learning experiences can be created that are tailored to each student\u2019s strengths and weaknesses. This can lead to improved academic performance and can help to close the achievement gap. Finally, reinforcement learning can be used to improve environmental sustainability. 
By training models to make decisions based on environmental data, such as weather patterns and pollution levels, more effective policies can be developed to reduce carbon emissions and protect natural resources. In conclusion, reinforcement learning has many potential benefits for society. By training models to make decisions based on feedback from their environment, we can create more efficient and effective systems in a wide range of fields. However, it is important to consider the ethical implications of these technologies and to ensure that they are developed and used in a responsible and ethical manner. Multi-modal models are another type of machine learning that can process and find relationships between different types of data, such as images, video, audio, and text. They have the potential to revolutionize many aspects of our lives, from healthcare to transportation to education. In this essay, I will discuss how multi-modal models can help society in various ways. One of the most significant benefits of multi-modal models is their ability to transform unstructured data into structured data that can be analyzed. For example, a company could use a multi-modal model to extract data from images or PDFs of invoices or receipts. This would enable them to analyze the data more efficiently and make better-informed decisions. Another benefit of multi-modal models is their ability to cater to various learning styles. Blended and multi-modal learning can reach people who benefit from different learning styles. By understanding their individual learning styles, employees can leverage resources that are compatible with how they process information most effectively. Multi-modal models can also help improve healthcare. For example, they can be used to analyze medical images and identify patterns that might be difficult for human doctors to detect. This can lead to earlier diagnoses and more effective treatments. In addition, multi-modal models can help improve transportation. For example, they can be used to analyze traffic patterns and optimize traffic flow. This can help reduce congestion and improve safety on the roads. Finally, multi-modal models can help improve education. For example, they can be used to create personalized learning experiences for students based on their individual learning styles. This can help students learn more effectively and efficiently. In conclusion, multi-modal models have the potential to help society in many ways. They can transform unstructured data into structured data, cater to various learning styles, improve healthcare, transportation, and education. However, like any new technology, it is important to approach it with caution and consider the potential risks and benefits. I hope this essay has provided some insight into the potential benefits of multi-modal models. Throughout this essay, I have demonstrated the numerous benefits that artificial intelligence will bring to our society. I have also shown some examples of various categories within artificial intelligence that have varying purposes. It is important to consider that each category has its own purpose and has its own pros and cons to it. In conclusion, we must use AI responsibly. ",
+    "3840": "Artificial Intelligence (AI) is a transformative technology that has the potential to revolutionize society in many ways. AI can be used to enhance the accuracy and efficiency of decision-making, improve lives through new apps and services, and solve some of the thorny policy problems of climate change, infrastructure, and healthcare. In this essay, I will discuss some of the ways AI can benefit society. One of the most significant benefits of AI is its ability to improve healthcare. AI can assist doctors, nurses, and other healthcare professionals in making better diagnoses and faster decisions on a course of treatment, based on the large amount of data that currently exists. AI allows doctors to pinpoint effective drugs that may have otherwise been overlooked and can identify higher-risk individuals before any human can. AI can also help relieve the burden on healthcare professionals by taking care of routine data collection and filing, freeing up time for other higher-value activities. Another area where AI can benefit society is in the fight against climate change. AI can be used to analyze vast amounts of data, identify patterns, and provide accurate predictions. It can help us forecast what further spread of pandemics is going to look like, and track their development around the world. AI can also help us predict the impact of climate change on our planet and develop strategies to mitigate its effects. For example, AI can be used to optimize energy consumption, reduce waste, and improve the efficiency of transportation systems. AI can also benefit society by improving education. AI-powered educational tools can help students learn more effectively by providing personalized learning experiences tailored to their individual needs. AI can also help teachers by automating routine tasks such as grading and providing feedback on student work. This can free up time for teachers to focus on more important tasks such as lesson planning and student engagement. AI can also benefit society by improving public safety. AI-powered surveillance systems can help law enforcement agencies detect and prevent crime more effectively. AI can also be used to analyze social media data to identify potential threats and prevent them before they occur. For example, AI can be used to detect hate speech and other forms of online harassment, which can help prevent cyberbullying and other forms of online abuse. Finally, AI can benefit society by improving the economy. AI can help businesses become more efficient by automating routine tasks and providing insights into customer behavior. This can help businesses make better decisions and improve their bottom line. AI can also help create new jobs by enabling the development of new products and services that were previously impossible. In conclusion, AI has the potential to benefit society in many ways. From improving healthcare and education to fighting climate change and improving public safety, AI can help us solve some of the most pressing problems facing our world today. As we continue to develop and refine this transformative technology, it is important that we do so in an ethical and responsible manner, ensuring that the benefits of AI are shared by all members of society. AI has been a topic of discussion for many years, and while it has brought many benefits to society, there are also concerns about its impact. In this essay, I will discuss some of the reasons why AI may not help society. Firstly, AI can be biased. 
AI systems are designed by humans, and they can be infused with the biases of their creators. This can lead to discrimination against certain groups of people and can perpetuate existing inequalities in society. Additionally, AI can lack transparency, making it difficult to understand how decisions are being made. This can lead to mistrust of AI systems and can hinder their adoption. Secondly, AI can be used to automate jobs, which can lead to unemployment. While AI can increase productivity and efficiency, it can also lead to job displacement, particularly in industries that rely heavily on manual labor. This can have a negative impact on individuals and communities, particularly those that are already marginalized. Thirdly, AI can be used to create fake content, such as deepfakes, which can be used to spread misinformation and propaganda. This can have serious consequences for democracy and can undermine trust in institutions. Fourthly, AI can be used to create autonomous weapons, which can have devastating consequences. These weapons can make decisions without human intervention, which can lead to unintended consequences and can be difficult to control. Fifthly, AI can be used to create surveillance systems that infringe on privacy rights. These systems can be used to monitor individuals without their knowledge or consent, which can have serious consequences for civil liberties. In conclusion, while AI has many potential benefits, there are also concerns about its impact on society. It is important to consider these concerns and to ensure that AI is developed and used in a responsible and ethical manner. Within AI, there are also many subfields. Reinforcement learning is a type of machine learning algorithm that focuses on training models to make decisions in an environment in order to maximize a reward. This is typically done through trial and error, as the algorithm receives feedback in the form of rewards or punishments for its actions. Reinforcement learning has many potential benefits for society, some of which are discussed below. Firstly, reinforcement learning can be used to improve industrial automation and robotics. By training robots to learn from their own experiences, they can gain the skills necessary to perform complex tasks without human intervention. This can lead to increased efficiency and productivity in industries such as manufacturing and logistics. Secondly, reinforcement learning can be used to optimize traffic control systems. By training models to make real-time decisions based on traffic patterns and other data, traffic flow can be improved, reducing congestion and travel times. Thirdly, reinforcement learning can be used to improve healthcare. By training models to make decisions based on patient data, doctors can make more accurate diagnoses and develop more effective treatment plans. This can lead to better health outcomes for patients and can reduce healthcare costs. Fourthly, reinforcement learning can be used to improve education. By training models to adapt to individual student needs, personalized learning experiences can be created that are tailored to each student\u2019s strengths and weaknesses. This can lead to improved academic performance and can help to close the achievement gap. Finally, reinforcement learning can be used to improve environmental sustainability. 
By training models to make decisions based on environmental data, such as weather patterns and pollution levels, more effective policies can be developed to reduce carbon emissions and protect natural resources. In conclusion, reinforcement learning has many potential benefits for society. By training models to make decisions based on feedback from their environment, we can create more efficient and effective systems in a wide range of fields. However, it is important to consider the ethical implications of these technologies and to ensure that they are developed and used in a responsible and ethical manner. Multi-modal models are another type of machine learning that can process and find relationships between different types of data, such as images, video, audio, and text. They have the potential to revolutionize many aspects of our lives, from healthcare to transportation to education. In this essay, I will discuss how multi-modal models can help society in various ways. One of the most significant benefits of multi-modal models is their ability to transform unstructured data into structured data that can be analyzed. For example, a company could use a multi-modal model to extract data from images or PDFs of invoices or receipts. This would enable them to analyze the data more efficiently and make better-informed decisions. Another benefit of multi-modal models is their ability to cater to various learning styles. Blended and multi-modal learning can reach people who benefit from different learning styles. By understanding their individual learning styles, employees can leverage resources that are compatible with how they process information most effectively. Multi-modal models can also help improve healthcare. For example, they can be used to analyze medical images and identify patterns that might be difficult for human doctors to detect. This can lead to earlier diagnoses and more effective treatments. In addition, multi-modal models can help improve transportation. For example, they can be used to analyze traffic patterns and optimize traffic flow. This can help reduce congestion and improve safety on the roads. Finally, multi-modal models can help improve education. For example, they can be used to create personalized learning experiences for students based on their individual learning styles. This can help students learn more effectively and efficiently. In conclusion, multi-modal models have the potential to help society in many ways. They can transform unstructured data into structured data, cater to various learning styles, improve healthcare, transportation, and education. However, like any new technology, it is important to approach it with caution and consider the potential risks and benefits. I hope this essay has provided some insight into the potential benefits of multi-modal models. Semi-supervised learning is a type of machine learning that falls in between supervised and unsupervised learning. It is a method that uses a small amount of labeled data and a large amount of unlabeled data to train a model. The goal of semi-supervised learning is to learn a function that can accurately predict the output variable based on the input variables, similar to supervised learning. However, unlike supervised learning, the algorithm is trained on a dataset that contains both labeled and unlabeled data. Semi-supervised learning is particularly useful when there is a large amount of unlabeled data available, but it\u2019s too expensive or difficult to label all of it. 
The primary advantage of semi-supervised learning is that it can reduce the amount of annotated data used. This is particularly useful when labeled data is scarce or expensive to obtain. By using a small amount of labeled data and a large amount of unlabeled data, semi-supervised learning algorithms can learn from both types of data and improve their accuracy. Semi-supervised learning algorithms are also capable of consolidating overfitting tendencies, which is a common problem in supervised learning. Another advantage of semi-supervised learning is that it is versatile. It can be applied in various situations, from image recognition to crawlers. For example, in text classification, the goal is to classify a given text into one or more predefined categories. Semi-supervised learning can be used to train a text classification model using a small amount of labeled data and a large amount of unlabeled text data. In image classification, the goal is to classify a given image into one or more predefined categories. Semi-supervised learning can be used to train an image classification model using a small amount of labeled data and a large amount of unlabeled image data. In anomaly detection, the goal is to detect patterns or observations that are unusual or different from the norm. Semi-supervised learning can be used to detect anomalies using a small amount of labeled data and a large amount of unlabeled data. Semi-supervised learning algorithms are also stable and simple. They have high efficiency and can be used to improve the performance and generalization of models. However, semi-supervised learning algorithms also have some disadvantages. One of the main disadvantages is that they require a large amount of unlabeled data to be effective. If there is not enough unlabeled data available, the algorithm may not be able to learn effectively. Additionally, semi-supervised learning algorithms can be sensitive to the quality of the labeled data. If the labeled data is noisy or incorrect, the algorithm may not be able to learn effectively. In conclusion, semi-supervised learning is a powerful tool that can be used to improve the accuracy and generalization of machine learning models. It is particularly useful when labeled data is scarce or expensive to obtain. Semi-supervised learning algorithms can learn from both labeled and unlabeled data, which makes them versatile and capable of consolidating overfitting tendencies. However, semi-supervised learning algorithms also have some disadvantages, such as requiring a large amount of unlabeled data to be effective and being sensitive to the quality of the labeled data. Despite these disadvantages, semi-supervised learning is a valuable technique that can be used to improve the performance of machine learning models. Supervised learning is a type of machine learning that involves training a model on labeled data. The goal of supervised learning is to learn a function that can accurately predict the output variable based on the input variables. Supervised learning is widely used in various fields, including image recognition, speech recognition, natural language processing, and more. One of the primary advantages of supervised learning is that it allows for accurate predictions. Supervised learning models can provide highly accurate predictions or classifications when trained on a diverse and representative dataset. 
This makes supervised learning particularly useful in situations where accuracy is critical, such as in medical diagnosis or fraud detection. Another advantage of supervised learning is that it is easy to understand and implement. Supervised learning algorithms are relatively simple and can be implemented using a variety of programming languages and libraries. This makes it accessible to a wide range of developers and data scientists. Supervised learning is also versatile. It can be applied to a wide range of problem domains, making it a flexible approach for various industries and applications. For example, in image classification, the goal is to classify a given image into one or more predefined categories. Supervised learning can be used to train an image classification model using a labeled dataset of images and their corresponding categories. In speech recognition, the goal is to transcribe spoken words into text. Supervised learning can be used to train a speech recognition model using a labeled dataset of audio recordings and their corresponding transcriptions. Supervised learning algorithms are also capable of handling missing data. If there is missing data in the labeled dataset, supervised learning algorithms can still learn from the available data and make accurate predictions. This is particularly useful in situations where data is incomplete or noisy. However, supervised learning algorithms also have some disadvantages. One of the main disadvantages is that they require a large amount of labeled data to be effective. If there is not enough labeled data available, the algorithm may not be able to learn effectively. Additionally, supervised learning algorithms can be sensitive to the quality of the labeled data. If the labeled data is noisy or incorrect, the algorithm may not be able to learn effectively. In conclusion, supervised learning is a powerful tool that can be used to make accurate predictions and classifications. It is easy to understand and implement, and it is versatile enough to be applied to a wide range of problem domains. However, supervised learning algorithms also have some disadvantages, such as requiring a large amount of labeled data to be effective and being sensitive to the quality of the labeled data. Despite these disadvantages, supervised learning is a valuable technique that can be used to improve the performance of machine learning models. Unsupervised learning is a type of machine learning that involves training a model on unlabeled data. The goal of unsupervised learning is to learn the underlying structure of the data, without any prior knowledge of the output variable. Unsupervised learning is widely used in various fields, including image recognition, natural language processing, and more. One of the primary advantages of unsupervised learning is that it can handle large amounts of unlabeled and unstructured data. This makes unsupervised learning particularly useful in situations where labeled data is scarce or expensive to obtain. By using unsupervised learning algorithms, we can learn from the available data and make accurate predictions. Another advantage of unsupervised learning is that it can identify previously undetected patterns in data. Unsupervised learning algorithms can be used to cluster data points into groups based on their similarities. This can be useful in various applications, such as customer segmentation, anomaly detection, and more. Unsupervised learning algorithms are also capable of dimensionality reduction. 
This is particularly useful when dealing with high-dimensional data, such as images or text. By reducing the dimensionality of the data, unsupervised learning algorithms can improve the efficiency and accuracy of the model. Unsupervised learning algorithms are also capable of feature learning. Feature learning is the process of automatically learning features from the input data. This can be useful in various applications, such as image recognition, where the algorithm can learn features such as edges, corners, and more. However, unsupervised learning algorithms also have some disadvantages. One of the main disadvantages is that they require a large amount of unlabeled data to be effective. If there is not enough unlabeled data available, the algorithm may not be able to learn effectively. Additionally, unsupervised learning algorithms can be sensitive to the quality of the data. If the data is noisy or incorrect, the algorithm may not be able to learn effectively. As you can see, artificial intelligence (AI) is a wide-ranging field that encompasses various sub-fields. Some of the sub-fields that we have previously discussed include reinforcement learning, multi-modal learning, semi-supervised learning, supervised learning, unsupervised learning, and much more. There are also many application domains for artificial intelligence (AI) that can utilize it. Throughout this essay, I have demonstrated the numerous benefits that artificial intelligence (AI) will bring to our society. I have also shown some examples of various categories within artificial intelligence that have varying purposes. It is important to consider that each category has its own purpose and has its own pros and cons to it. What do you think artificial intelligence will bring to our society? Will it be used in a responsible manner? ",
+    "4096": "In the heart of Eldoria, where ancient forests whispered secrets and rivers sang forgotten melodies, lay the Enchanted Labyrinth. Its walls, adorned with shimmering runes, concealed a portal to realms unknown. Few dared to venture inside, for the labyrinth was said to twist time and reality. Evelyn, a curious young mage, stood before the labyrinth's entrance. Her emerald eyes sparkled with determination. She clutched a cracked map, its ink fading like memories lost to the wind. Legends spoke of a treasure hidden deep within - a relic capable of granting any wish. As Evelyn stepped across the threshold, the air thickened. The walls shifted, rearranging themselves. She followed the faint glow of her lantern, each step echoing through eternity. Shadows danced, whispering forgotten names. Was this a dream or a nightmare? Deeper into the labyrinth, Evelyn encountered Aelar, the Guardian of Time. His silver hair flowed like moonlight, and his eyes held the weight of centuries. Aelar barred her path, his staff crackling with energy. 'Seeker,' he intoned, 'answer my riddle, and the way shall open.' Evelyn's heart raced. 'Ask, Guardian.' 'What has roots as old as time, yet dances with the wind?' She pondered, memories of her grandmother's tales flooding her mind. 'A tree,' she replied. Aelar smiled, and the walls shifted once more. 'Proceed, Seeker.' The labyrinth twisted, revealing a moonlit grove. Trees hummed ancient lullabies, and fireflies wove constellations in the air. At the center stood a weeping willow, its branches brushing the ground like a grieving widow's veil. Evelyn approached, her fingers tracing the bark. 'Why do you weep?' The willow's voice, soft as falling petals, answered, 'I guard the Tear of Eternity.' Evelyn's breath caught. The Tear - a gem said to hold memories of lost civilizations. She plucked it from a low branch, its facets reflecting forgotten faces. As Evelyn pressed onward, the labyrinth tightened its grip. She faced illusions - lovers lost, friends betrayed. Doubt gnawed at her resolve. Was the treasure worth the cost? At the labyrinth's heart, she found a mirror. Her reflection wavered, revealing her deepest desire: her sister, Lysandra, who vanished years ago. Tears blurred the glass. 'Speak your wish,' the mirror whispered. Evelyn's voice trembled. 'Bring Lysandra back.' The mirror shattered, and reality fractured. Lysandra stepped through, eyes wide with wonder. 'Evelyn?' Lysandra's return came at a cost - the labyrinth demanded balance. For every wish granted, a memory faded. Evelyn watched as her childhood laughter dissolved like mist. Together, they exited the labyrinth, the Tear pulsing in Evelyn's palm. She gazed at her sister, both joy and sorrow in her eyes. 'Was it worth it?' Lysandra asked. Evelyn smiled. 'In Eldoria, every choice we make becomes a story. And ours, dear sister, is woven in stardust and sacrifice.' And so, the Enchanted Labyrinth whispered its final secret: Wishes are threads, and memories their loom. In the land of Aetherfall, where mist-clad mountains touched the heavens and rivers whispered forgotten spells, a prophecy echoed through time. It spoke of the Starstone, a gem said to hold the universe's secrets - the key to creation and destruction. Eldric, a humble blacksmith with eyes like storm clouds, stumbled upon an ancient map. Its ink had faded, but the constellations remained. Guided by fate, he set forth, leaving his forge behind. 
Eldric's journey led him to the Whispering Forest, where trees conversed in hushed tones. Their leaves whispered of hidden paths and treacherous guardians. Eldric's heart pounded as he stepped into the shadows. There, he met Lyria, a forest nymph with silver hair and eyes like moonlit pools. She guarded the first clue - a riddle etched into a petal: 'In the heart of the forest, where time bends, seek the Wellspring of Echoes. There, the Starstone awaits.' Eldric followed Lyria's guidance. The Wellspring lay within a moon-kissed glade. Its waters shimmered, reflecting memories of lost lovers, ancient battles, and forgotten oaths. Eldric dipped his hand, and the riddle unfolded: 'To find the Starstone, seek the Three Keys: the tear of a fallen star, the breath of a dragon, and the song of a forgotten bard.' Eldric climbed the Stardust Peaks, where fallen stars lay embedded in the rock. Each tear held a fragment of cosmic sorrow. He found one - a sapphire gem pulsing with celestial fire. But it was guarded by Drakor, the last of the star dragons. Drakor's scales shimmered like galaxies. His eyes held eons of wisdom. 'Why seek the Tear, mortal?' 'To save Aetherfall,' Eldric replied. 'To restore balance.' Drakor nodded, and with a breath, he shattered the gem. Eldric caught the falling tear - a shard of eternity. Next, Eldric sailed to the Isle of Shadows, where the void whispered secrets. There, he faced Nyxia, the ancient shadow dragon. Her wings spanned continents, and her breath could devour stars. 'Why seek my breath?' Nyxia hissed. 'To awaken the Starstone,' Eldric said. 'To mend the rifts.' Nyxia's eyes glowed. She exhaled - a stream of darkness. Eldric captured it in a crystal vial - the Breath of the Void. The final key lay in the Bard's Hollow, where echoes of lost melodies lingered. Eldric met Silvan, a ghostly minstrel who strummed a lute of moonwood. 'Sing,' Silvan urged. 'The Song of the Forgotten.' Eldric sang of battles, love, and sacrifice. The hollow trembled, and from the mist, a spectral harp appeared. Its strings hummed - the Song of Ages. Eldric plucked the notes, and they merged into a silver key - the Song of the Forgotten. At the Nexus of Worlds, Eldric assembled the keys - the Tear, the Breath, and the Song. The ground quaked, and the Starstone emerged - a gem of cosmic hues. Its light wove reality, mending fractures in Aetherfall. But the prophecy held a twist: the Starstone demanded a choice. Eldric could use it to reshape the world or sacrifice it to heal the void. He gazed at Lyria, Drakor, Nyxia, and Silvan - their fates intertwined. With a heavy heart, he whispered, 'Balance.' And so, the Starstone shattered, its fragments seeding new constellations. Eldric returned to his forge, but his hammer now shaped more than iron - it forged destiny. Lyria, the Forest Nymph Lyria, with her silver hair and eyes like moonlit pools, remained in the Whispering Forest. She became its guardian, weaving spells to protect the ancient trees. Her laughter echoed through the glades, and travelers whispered of a nymph who danced with moonbeams. Lyria's heart held a secret - the memory of Eldric's touch, the warmth of their shared quest. She tended to the Wellspring of Echoes, ensuring its waters flowed through time, carrying whispers of forgotten tales. Drakor, the Last Star Dragon Drakor, the last of the star dragons, retreated to the highest peak of the Stardust Peaks. There, he curled his immense form around the shattered Tear of the Fallen. 
His scales absorbed its cosmic fire, and he became a living constellation - a beacon for lost souls. Drakor's breath no longer consumed stars; instead, it birthed new constellations. Travelers gazed at the night sky, seeking guidance in his patterns. Drakor's eyes held both sorrow and hope, for he knew that balance required sacrifice. Nyxia, the Ancient Shadow Dragon Nyxia, with wings spanning continents, chose a different path. She descended to the Isle of Shadows, where the void whispered secrets. There, she guarded the Abyss of Remembrance - a rift between worlds. Nyxia's breath no longer devoured stars; it sealed the rifts. She became a bridge, allowing souls to traverse realms. Those who sought lost loved ones or glimpses of forgotten memories found solace in her shadowed embrace. Nyxia's eyes held the weight of choices made and unmade, and she vowed to keep the balance intact. Silvan, the Ghostly Minstrel Silvan, the spectral minstrel, wandered the Bard's Hollow. His lute of moonwood sang melodies of love, loss, and courage. Silvan's song echoed through time, touching hearts across Aetherfall. He became the keeper of memories - the forgotten bard who whispered forgotten names. When travelers stumbled upon the hollow, Silvan strummed his lute, and their own stories surfaced. He wove their experiences into the Song of Ages, ensuring that no tale would fade into oblivion. Silvan's translucent form danced in moonlight, a bridge between the living and the departed. Eldric, the Blacksmith As for Eldric, the humble blacksmith, he returned to his forge in the village of Hearthstone. His hammer now shaped more than iron - it forged destiny. Eldric crafted talismans from the Tear of the Fallen, the Breath of the Void, and the Song of the Forgotten. These talismans healed rifts, mended broken hearts, and ignited hope. Eldric's eyes held the wisdom of realms explored, and he knew that Aetherfall's balance rested on the choices of ordinary souls. He continued to tell the tale of the Starstone, passing it down through generations, ensuring that the magic endured. And so, dear reader, the threads of fate intertwined - a forest nymph, a star dragon, a shadow, and a minstrel - all bound by the echoes of a forgotten song. The Chronicles of the Celestial Weaver In the forgotten village of Astralis, where the night sky wept silver tears, lived a young girl named Elara. Her eyes held the secrets of constellations, and her fingers danced like stardust. But Astralis suffered - a curse had befallen the heavens. The stars dimmed, their brilliance fading. Elara's grandmother, Lyris, whispered of an ancient prophecy: 'When the stars falter, seek the Celestial Weaver.' Elara vowed to unravel the mystery and save her village. Guided by Lyris's map, Elara ventured into the Veiled Forest, where moonlight wove through ancient oaks. There, she met Silas, the enigmatic weaver. His loom hummed with cosmic threads - the Loom of Eternity. 'Seek the lost constellations,' Silas said. 'Weave them anew.' Elara's heart raced. She plucked a silver thread - the remnants of Orion - and began to weave. The loom responded, stars rekindling. But the cost was memory - Elara forgot her childhood laughter. Elara's journey spanned realms: The Nebula Caves: She retrieved the Pleiades, their sisterhood echoing through time. The Comet's Trail: She chased Halley's Comet, capturing its fiery tail. The Abyss of Lyra: There, Vega's song echoed - a melody of love and longing. Each constellation restored, Elara's memories faded. 
She forgot her first kiss, her mother's lullabies. Yet Astralis glimmered - the stars brightened. In the Celestial Citadel, Elara faced Draco, the fallen dragon. His scales bore scars - the price of rebellion. He guarded the final constellation - the Serpent. 'Why weave the stars?' Draco hissed. 'They betrayed me.' Elara's fingers trembled. 'To save my village.' Draco's eyes softened. 'We were once kin. We'll share this memory.' As Elara wove the Serpent, she glimpsed Draco's love for Lyris - their forbidden bond. The constellation blazed, and Elara remembered both love and sacrifice. Back in Astralis, the stars blazed anew. Villagers rejoiced, but Elara's memories were fragile threads. Lyris embraced her. 'You've woven fate,' Lyris said. 'But the Loom demands balance.' Elara faced Silas. 'What price?' He smiled - a constellation of wrinkles. 'Your memories or the stars.' Elara hesitated. She remembered her grandmother's stories, her stolen kisses. She chose the stars. Elara became the new Celestial Weaver. Her memories - her life - wove into the cosmos. Astralis thrived, but Elara forgot her name, her laughter, her love. Lyris whispered, 'Weavers are forgotten, but their constellations endure.' And so, Elara wove - the forgotten girl who stitched eternity. Elara, now the Celestial Weaver, wove constellations with threads of memory. Astralis thrived - the villagers danced under starlit skies, unaware of their forgotten histories. Lyris watched her granddaughter, her eyes both proud and sorrowful. 'Elara,' Lyris whispered, 'the Loom demands more than memories.' Elara's fingers trembled. She glimpsed her own reflection in the cosmic threads - the girl who once dreamed of love and laughter. But now, her past was a constellation of faded stars. Silas, the former weaver, lingered in the shadows. His form blurred - a specter between realms. He spoke of the Whispering Veil, a boundary separating memory from oblivion. Beyond it lay forgotten worlds, lost loves, and forbidden truths. 'Cross the Veil,' Silas urged. 'Retrieve what was sacrificed.' Elara hesitated. She yearned for her stolen memories - the taste of strawberries, the warmth of a lover's touch. But the Veil was treacherous - a labyrinth of half-remembered echoes. Elara stepped into the Veil. Its mist clung to her skin, whispering secrets. She glimpsed fragments of her past - a stolen kiss, a tear shed for a fallen friend. The path forked: The Garden of Remembrance: Blooming with forgotten faces, this garden promised reunion. Elara could reclaim her lost memories, but at a cost - the stars would dim once more. The Abyss of Oblivion: A chasm of emptiness. Here, Elara could sever her ties to Astralis, becoming a true Celestial Weaver. The stars would blaze forever, but her existence would be a threadless void. Elara hesitated. She remembered Lyris's lullabies, Silas's enigmatic smile, and Draco's love for her grandmother. She yearned for her stolen laughter - the taste of strawberries, the warmth of a lover's touch. But the stars - Astralis - called to her. The village thrived, its people dancing under constellations she had rekindled. Elara's choice would echo across eternity. She faced the Veil's center - a mirror reflecting her fragmented self. Her fingers trembled. 'Balance,' she whispered. And so, Elara wove anew. She plucked threads from the Garden of Remembrance, reclaiming stolen moments. The stars dimmed, but Astralis glowed with forgotten love. Silas nodded. 'You've chosen well, Weaver.' 
Elara's memories returned - the taste of strawberries, the warmth of a lover's touch. She kissed Lyris's forehead, whispered Draco's name, and stepped back into Astralis. The stars blazed - the legacy of a girl who stitched eternity. Short stories like these are great to listen to and read because they allow us to explore our creative minds and broaden our imaginations. They also inspire us to learn from others and can become culturally impactful. The themes of these stories can also dive deep into philosophical questions and raise awareness of important issues. The plots of these stories are sometimes based on real-life events and can have a deep emotional impact.",
+    "7936": "The Effects of Airplanes: A Closer Look Airplanes have revolutionized the way we travel, connect, and explore the world. From short domestic flights to transcontinental journeys, these metal birds have become an integral part of our lives. However, their impact extends beyond convenience and adventure. Let's delve into the effects of airplanes from various angles. Environmental Impact Fuel Consumption and Emissions Airplanes consume vast amounts of fuel during flight. For instance, a Boeing 747, with a gas tank capacity of 63,500 gallons, burns approximately five gallons of jet fuel per mile traveled. On a 4,000-mile flight, this translates to 20,000 gallons of fuel. However, when we consider the number of passengers (around 400), the fuel efficiency per traveler is surprisingly better than that of cars. A Honda Civic, which gets 30 miles per gallon, would need 133 gallons of fuel for the same distance. Even an RV, which moves just seven miles on a gallon of gasoline, would require about 285 gallons per traveler. Greenhouse Gas Emissions Airplanes emit greenhouse gases directly into the upper atmosphere, where they can linger longer and cause more damage than the same gases at lower altitudes. While air travel contributes to climate change, it's essential to recognize that other forms of transportation, such as cars and ships, also emit greenhouse gases. The challenge lies in finding ways to reduce aviation emissions without compromising connectivity and mobility. Ozone Depletion and Contrails Planes affect the concentration of other gases and pollutants in the atmosphere. They lead to a short-term increase in ozone (O3) but a long-term decrease. Contrails - those white streaks left behind by planes - can contribute to cloud formation and impact local weather patterns. Balancing the benefits of air travel with environmental concerns remains a critical challenge. Human Health Implications Jet Lag and Sleep Disruption Frequent flyers are no strangers to jet lag. Crossing time zones disrupts our circadian rhythms, affecting sleep patterns, mood, and overall well-being. Pilots, flight attendants, and passengers alike experience the effects of rapid travel across time zones. Dehydration and Blood Pressure Changes The low humidity in airplane cabins can lead to dehydration. Additionally, changes in cabin pressure affect blood pressure, especially during takeoff and landing. Staying hydrated and moving around during long flights can mitigate these effects. Risk of Contagious Diseases Airplanes put passengers in close proximity to one another. Recirculated air, shared surfaces, and confined spaces create an environment conducive to the spread of infections. While airlines take precautions, travelers should remain vigilant, especially during flu seasons. The Perspective Shift: Seeing Earth from Above Beyond the environmental and health impacts, airplanes have transformed our worldview. Before the Wright brothers' epochal breakthrough, humans were grounded, limited to terrestrial views. The advent of flight not only boosted our power of movement but also enhanced our vision. From above, we witness the curvature of the Earth, the vastness of oceans, and the intricate patterns of landscapes. Airplanes have made us global citizens, connecting us to distant lands and cultures. In conclusion, airplanes are a double-edged sword. They offer unparalleled mobility and exploration but come with environmental consequences and health considerations. 
As we continue to innovate and improve aviation technology, let's strive for a balance - a world where we soar through the skies while safeguarding our planet and well-being. Economic Impact Air Travel Industry The aviation industry is a significant contributor to the global economy. Airlines, airports, manufacturers, and associated services generate substantial revenue and employment. Air travel facilitates international trade, tourism, and business interactions. However, it also faces challenges such as fuel price fluctuations, competition, and regulatory complexities. Supply Chain and Cargo Transport Airplanes play a crucial role in transporting goods across continents. High-value and time-sensitive cargo, including perishable items, pharmaceuticals, and electronics, rely on air freight. The efficiency of supply chains owes much to the speed and reach of airplanes. Tourism and Local Economies Tourism heavily depends on air travel. Popular destinations thrive due to the influx of visitors arriving by plane. Local economies benefit from tourism-related activities, including hospitality, restaurants, and souvenir shops. Conversely, overreliance on tourism can strain natural resources and cultural heritage. Technological Advancements Aerospace Engineering The development of airplanes has driven advancements in aerospace engineering. Innovations in materials, aerodynamics, and propulsion systems have led to more efficient and safer aircraft. Research in areas like supersonic flight, electric planes, and autonomous drones continues to shape the industry. Navigation and Communication Airplanes rely on sophisticated navigation systems, including GPS, radar, and inertial guidance. These technologies enhance safety, accuracy, and efficiency. Communication networks allow pilots to stay connected with air traffic control, other planes, and ground stations. Social and Cultural Effects Global Connectivity Airplanes have transformed our perception of distance. What once took weeks by ship or months by land can now be accomplished in hours. Families separated by oceans reunite, students study abroad, and cultural exchange flourishes. The world feels smaller, and our interconnectedness grows. Iconic Symbols Airplanes evoke a sense of wonder and adventure. The iconic silhouettes of jumbo jets, fighter planes, and vintage biplanes symbolize human achievement and exploration. Airshows, aviation museums, and historical flights celebrate this legacy. Challenges and Future Prospects Sustainability The aviation industry faces the challenge of reducing its environmental impact. Researchers explore alternative fuels, electric propulsion, and lightweight materials. Balancing growth with sustainability remains critical. Airspace Congestion As air travel becomes more accessible, airspace congestion intensifies. Efficient air traffic management, improved routes, and next-generation air traffic control systems are essential to prevent gridlock. Security and Safety Ensuring the safety of passengers, crew, and cargo remains paramount. Rigorous security protocols, maintenance standards, and emergency preparedness are vital. In conclusion, airplanes are more than mere vessels of transportation. They shape economies, connect cultures, and inspire innovation. As we soar into the future, let's navigate the skies responsibly, appreciating both the marvels and challenges of flight. 
The Effects of Space Travel on the Human Body Space travel, with its awe-inspiring vistas and boundless possibilities, has captivated humanity for decades. However, venturing beyond our home planet comes with a price - a price paid not only in technological challenges but also in the toll it takes on the human body. Let us explore the effects of space travel, from radiation exposure to altered gravity, and how astronauts adapt to these extreme conditions. Space Radiation: A Silent Threat Radiation Exposure On Earth, our protective magnetic field and atmosphere shield us from the majority of space radiation. However, in space, astronauts face direct exposure to cosmic rays and solar particles. These high-energy particles can penetrate the body, damaging cells and DNA. An increased risk of cancer and degenerative diseases, such as heart disease and cataracts, has been observed in human populations exposed to radiation on Earth. In space, health risks from radiation are mainly driven by long-term impacts. Altered Gravity: A Weighty Matter Microgravity and Muscle Atrophy Astronauts aboard the International Space Station (ISS) experience microgravity, where their bodies float freely. While this weightlessness allows for breathtaking experiments and observations, it wreaks havoc on muscles and bones. Without the constant pull of gravity, muscles weaken, and bones lose density. Astronauts must engage in rigorous exercise routines to counteract muscle atrophy and maintain bone health. Fluid Redistribution and Swollen Faces In microgravity, bodily fluids shift upward, causing facial puffiness and fluid retention. Astronauts often joke about their 'moon faces.' This fluid redistribution can also affect vision, leading to a condition known as spaceflight-associated neuro-ocular syndrome (SANS). Isolation and Confinement: The Mental Strain Psychological Challenges Space missions involve prolonged isolation and confinement. Astronauts live in tight quarters, cut off from the natural world. The absence of familiar sights, sounds, and smells can lead to feelings of loneliness and anxiety. Coping mechanisms, communication with loved ones, and psychological support are crucial to maintaining mental well-being. Distance from Earth: A Cosmic Solitude Emotional Impact The vastness of space can evoke existential thoughts. Astronauts gaze back at Earth - a tiny blue dot suspended in the cosmic void - and grapple with their insignificance. The emotional weight of being far from home, family, and friends can be profound. Hostile and Closed Environments: Surviving in the Void Spacecraft Living Conditions Spacecraft are marvels of engineering, but they are also confined capsules. Astronauts adapt to tight spaces, recycled air, and limited privacy. The constant hum of machinery and the absence of natural light can wear on their senses. Risk of Infection In closed environments, microbes thrive. Astronauts must maintain strict hygiene to prevent infections. The immune system faces unique challenges, especially during extended missions. The Resilience of Astronauts Adaptation and Innovation Astronauts are remarkable in their ability to adapt. They learn to navigate microgravity, perform complex tasks, and troubleshoot technical glitches. Their resilience drives innovation, leading to better spacecraft design and life support systems. The Twins Study: Scott and Mark Kelly Scott Kelly and his identical twin brother, Mark Kelly, participated in the unique Twins Study. 
Scott spent nearly a year aboard the ISS, while Mark remained on Earth. By comparing their physiological and psychological changes, researchers gained valuable insights into the effects of space travel. Looking Ahead: Mars and Beyond Challenges for Deep Space Missions As we plan for Mars missions and beyond, we face the RIDGE of space travel: Space Radiation: Shielding astronauts from cosmic rays. Isolation and Confinement: Maintaining mental health during long journeys. Distance from Earth: Coping with cosmic solitude. Gravity Fields: Addressing muscle and bone health. Hostile/Closed Environments: Ensuring safety and hygiene. In conclusion, space travel is a delicate balance between exploration and preservation. As we venture farther into the cosmos, we must safeguard both our scientific curiosity and the well-being of those who dare to explore the final frontier. The Environmental Impact of Airplanes and Spaceships Airplanes and spaceships have transformed the way we explore our planet and beyond. However, their operations come with significant environmental consequences. Let's delve into the effects of these flying machines on our delicate ecosystem. Climate Change Air travel is a major contributor to climate change due to greenhouse gas emissions. Jet engines burn fossil fuels (mostly aviation gasoline or jet fuel), releasing carbon dioxide (CO2), nitrogen oxides (NOx), and water vapor into the atmosphere. These emissions trap heat, leading to global warming. Although aviation accounts for about 3.5 percent of human-induced climate change, its impact is disproportionately high due to emissions at high altitudes. Air Quality Airplanes emit pollutants such as sulfur dioxide (SO2), particulate matter (PM), and volatile organic compounds (VOCs). These pollutants degrade air quality near airports and along flight paths. Ground-level ozone formation, which harms human health and ecosystems, is also influenced by aviation emissions. Noise Pollution The roar of jet engines disrupts communities around airports. Noise pollution affects sleep patterns, stress levels, and overall well-being. Efforts to reduce noise include quieter engine designs and flight path adjustments. Spaceships: Earth's Atmospheric Guardians Rocket Launches and Pollution Rocket launches, essential for space exploration, release pollutants into the atmosphere. The fuel used - such as unsymmetrical dimethylhydrazine (UDMH) - can be highly carcinogenic and ecologically damaging. For instance, the Baikonur Cosmodrome in Kazakhstan, the world's oldest spaceport, has left a large zone of pollution due to toxic rocket fuel seeping into the soil. Carbon Particles and Geo-Engineering Recent research highlights the impact of rocket emissions on the atmosphere. Black carbon (soot) particles from rockets can absorb heat, acting as a form of geo-engineering. As commercial space launches increase, so does the concern about their environmental effects. Balancing Exploration and Preservation Space Tourism The rise of space tourism introduces new challenges. As more people venture beyond Earth, we must consider the cumulative impact of rocket emissions. Balancing our curiosity with environmental stewardship is crucial. Sustainable Practices Efforts are underway to develop cleaner propulsion technologies, use alternative fuels, and minimize space debris. Innovations like reusable rockets and electric propulsion aim to reduce the environmental footprint of space travel. 
Looking Ahead: A Cosmic Responsibility Mars and Beyond As we dream of Mars colonies and interstellar travel, we must tread carefully. The RIDGE of space exploration - Radiation, Isolation, Distance, Gravity, and Environment - requires sustainable solutions. Let's explore the cosmos while safeguarding our home planet. In conclusion, airplanes and spaceships propel us toward the stars, but their effects ripple through our atmosphere and ecosystems. As stewards of both Earth and space, we must navigate the skies responsibly, seeking harmony between exploration and preservation. From the ground to the sky, dining experiences have transcended traditional restaurant settings. Imagine savoring gourmet meals while suspended high above the earth, with breathtaking views stretching as far as the eye can see. Welcome to the world of aerial dining, where culinary delights meet gravity-defying elegance. Dinner in the Sky: Elevating Gastronomy The Original Concept Dinner in the Sky, born in 2006, is the epitome of dining with a twist. Picture a massive table - more like a platform - hoisted almost 200 feet into the air by a sturdy crane. Guests, chefs, and waitstaff don their white hats as they ascend to the skies. The setting? A floating dinner table, surrounded by nothing but open air and panoramic vistas. The Experience As you settle into your seat, the anticipation builds. The restaurant staff orchestrates a three-course fine dining experience, all while suspended in midair. The menu features carefully crafted dishes, often prepared beforehand and finished in a convection oven right there in the sky. Each bite is accompanied by awe-inspiring views - city skylines, rolling landscapes, or even the vastness of the ocean. Safety First Before you ascend, a safety briefing ensures that you're securely strapped in. The thrill of being airborne mingles with the elegance of haute cuisine. Whether it's a romantic date night or a corporate event, Dinner in the Sky promises an unforgettable meal. Sky-High Restaurants Around the World Dubai Marina: A Feast Above the Waters Situated in Dubai Marina, this dining concept boasts some of the best views of the city skyline, surrounding waters, and the iconic Palm Jumeirah. Imagine floating above the ground while you dine - a one-of-a-kind experience you simply cannot miss. After the safety briefing near Skydive Dubai, you're hoisted 50 meters into the air, suspended over the bustling marina. The fusion of flavors meets the fusion of horizons. Las Vegas: Unparalleled Views of the Strip In the entertainment capital of the world, Dinner in the Sky Las Vegas takes fine dining to new heights - literally. As the sun sets, you ascend, and the glittering lights of the Las Vegas Strip come alive. The most unforgettable dinner you'll ever have awaits, with the cityscape stretching out beneath you. It's a feast for the senses, where culinary artistry meets architectural marvels. The Future of Aerial Gastronomy Sustainability and Innovation As we look ahead, the challenge lies in balancing indulgence with environmental responsibility. How can we minimize the carbon footprint of these lofty dining experiences? Innovations like electric-powered cranes, locally sourced ingredients, and waste reduction strategies are steps toward a more sustainable future. Beyond Earth: Space Tourism and Cosmic Cuisine With the rise of space tourism, could we soon dine among the stars? Imagine a celestial restaurant aboard a spacecraft, overlooking Earth from orbit. 
Cosmic cuisine - crafted by zero-gravity chefs - might become the ultimate bucket-list experience. As we explore the cosmos, let's ensure that our gastronomic adventures leave no trace behind. In conclusion, dining in the air transcends mere sustenance. It's a celebration of human ingenuity, a fusion of flavors and vistas, and a reminder that our appetite for exploration knows no bounds. So, raise your glass (carefully!) to the skies and savor the magic of dining aloft. Dining in the Sky is a unique and exhilarating culinary experience that elevates traditional dining to new heights - literally. Here are the key aspects of this extraordinary concept: The Setting: Up, Up, and Away! Imagine being seated at a massive table suspended high above the ground, often hundreds of feet in the air. The dining platform is typically hoisted by a sturdy crane or other mechanical means. Guests, chefs, and waitstaff ascend together, creating an unforgettable communal experience. The Experience: A Feast with a View As you settle into your seat, anticipation builds. The thrill of being airborne mingles with the elegance of haute cuisine. The menu features carefully crafted dishes, often prepared beforehand and finished on-site. Whether it's breakfast, lunch, or dinner, each course is served against a backdrop of breathtaking views - city skylines, rolling landscapes, or even the vastness of the ocean. The floating table becomes a stage for culinary artistry, where flavors dance amidst the clouds. Safety First: Buckle Up! Before ascending, guests receive a safety briefing. Straps secure them to their seats, ensuring a worry-free dining experience. The focus shifts from gravity to gastronomy as the platform rises, leaving the ground far below. Locations Around the World: Where the Sky Meets the Plate Dubai Marina: Suspended above the bustling marina, diners enjoy views of the city skyline and the iconic Palm Jumeirah. Las Vegas: As the sun sets, guests ascend over the glittering lights of the Las Vegas Strip, creating an unparalleled dining spectacle. The Future: Sustainability and Cosmic Cuisine Balancing indulgence with environmental responsibility is crucial. Innovations like electric-powered cranes and locally sourced ingredients aim to reduce the carbon footprint. Could cosmic cuisine be next? With the rise of space tourism, imagine dining aboard a spacecraft, overlooking Earth from orbit. Zero-gravity chefs crafting celestial dishes - it's a tantalizing prospect. Introduction The sky, our celestial canvas, is a dynamic theater where cosmic phenomena unfold. From twinkling stars to majestic planets, the sky offers a mesmerizing display that captivates astronomers and dreamers alike. In this essay, we'll explore the various elements of celestial weather, from meteor showers to planetary alignments. Stars and Constellations Stellar Climates Stars, like earthly weather patterns, exhibit their own 'climates.' Some stars burn fiercely, radiating intense heat, while others are cooler and more temperate. The constellations, those celestial neighborhoods, form intricate patterns across the night sky. Imagine them as cosmic weather maps, guiding our eyes to distant realms. Meteor Showers: Celestial Rainfall Meteor showers are cosmic storms, where Earth passes through debris left behind by comets. As these tiny particles collide with our atmosphere, they ignite, creating streaks of light - the meteors. 
The Perseids in August and the Geminids in December are celestial fireworks, painting the sky with ephemeral beauty. Planets and Their Dance Planetary Weather Systems Our solar system hosts a diverse range of planets, each with its own atmospheric conditions. Venus, shrouded in thick clouds of sulfuric acid, experiences hurricane-force winds. Mars, with its rusty surface, battles dust storms that engulf the entire planet. Jupiter's Great Red Spot - a colossal storm - has raged for centuries. Conjunctions and Oppositions Planets engage in a cosmic ballet. Conjunctions occur when two planets appear close together in the sky, as if sharing a celestial embrace. Oppositions, on the other hand, position a planet directly opposite the Sun, making it visible all night. Witnessing Mars during opposition feels like meeting an old friend. Lunar Weather Phases of the Moon The Moon, Earth's faithful companion, cycles through its phases. New Moon, First Quarter, Full Moon - the lunar weather changes predictably. During a lunar eclipse, our planet casts a shadow on the Moon, turning it coppery red. It's a cosmic reminder of our place in the grand celestial drama. Tides: The Ocean's Cosmic Response The Moon's gravitational pull orchestrates tides on Earth. High tides and low tides ebb and flow, responding to lunar cues. The celestial dance between Earth, Moon, and Sun shapes our oceans, affecting coastlines and marine life. Celestial Events Comets: Cosmic Visitors Comets, celestial vagabonds, journey through our solar system. Their icy cores release gas and dust, forming magnificent tails. Halley's Comet, a recurring visitor, graces our skies once every 76 years. Its return is a cosmic homecoming. Supernovae: Stellar Explosions When massive stars reach the end of their lives, they explode in brilliant supernovae. These cosmic fireworks outshine entire galaxies. Witnessing a supernova - a rare event - is like glimpsing the universe's raw power. Conclusion As we gaze upward, let's remember that the sky is not merely a backdrop but a living, breathing entity. Its weather - both familiar and otherworldly - shapes our cosmic experience. So, next time you look up, consider the celestial forecast: a blend of stardust, wonder, and infinite possibilities. In the words of Carl Sagan, 'The cosmos is within us. We are made of star-stuff.' Cosmic Mysteries Dark Matter and Dark Energy The sky harbors secrets beyond our comprehension. Among them are dark matter and dark energy. Dark matter, invisible and elusive, exerts gravitational influence on galaxies, holding them together. Imagine it as the cosmic glue binding the universe. Dark energy, on the other hand, accelerates the universe's expansion, pushing galaxies apart. These cosmic enigmas remain shrouded in mystery, awaiting discovery. Auroras: Celestial Light Shows When charged particles from the Sun collide with Earth's magnetic field, they create auroras - the ethereal dance of light near the poles. The Northern Lights (Aurora Borealis) and Southern Lights (Aurora Australis) paint the night sky with hues of green, pink, and purple. These celestial ballets remind us of our interconnectedness with the solar system. Celestial Timekeeping Stellar Clocks The sky serves as humanity's oldest timekeeper. Ancient civilizations relied on celestial events for calendars. The sidereal day, based on Earth's rotation relative to distant stars, is approximately 23 hours, 56 minutes, and 4 seconds. Constellations rise and set, marking the passage of time - a cosmic heartbeat. 
Eclipses: Celestial Alignments Solar and lunar eclipses are cosmic alignments. During a solar eclipse, the Moon obscures the Sun, casting a shadow on Earth. The eerie twilight and the diamond ring effect evoke awe. Lunar eclipses, when Earth's shadow engulfs the Moon, transform it into a reddish orb - an astronomical spectacle witnessed by civilizations across millennia. Cosmic Harmony Music of the Spheres Ancient philosophers believed in the 'music of the spheres.' They imagined celestial bodies - planets, stars, and moons - emitting harmonious vibrations. Each celestial note contributed to a cosmic symphony. While we no longer hear this celestial music, its metaphorical resonance persists - a reminder that the universe hums with hidden melodies. Galactic Weather Patterns Galaxies, like weather systems, evolve. Spiral galaxies, with their graceful arms, resemble cosmic hurricanes. Elliptical galaxies, shaped like celestial footballs, harbor dormant black holes at their cores. Colliding galaxies create celestial tempests, birthing new stars. The cosmic weather forecast predicts galactic collisions, stellar births, and cosmic winds. Conclusion: Our Cosmic Home As we conclude our cosmic odyssey, remember that the sky is not an abstract canvas - it's our celestial home. Whether you're stargazing from a mountaintop or contemplating the Moon's craters, you participate in the grand cosmic narrative. The sky whispers tales of creation, destruction, and eternity. So, dear reader, look up. Embrace the celestial weather - the storms and serenades. For in the vastness of space, we find wonder, humility, and a shared cosmic kinship. As Carl Sagan eloquently put it, 'We are a way for the cosmos to know itself.' Introduction The universe is a symphony, and planets are its celestial notes. These enigmatic orbs dance around stars, weaving tales of creation, destruction, and cosmic balance. In this essay, we embark on a cosmic journey to explore the eight planets of our solar system and their profound significance. Mercury: The Swift Messenger Mercury, the swiftest planet, orbits closest to the Sun. Its surface is a rugged landscape of craters and cliffs, baked by scorching temperatures during the day and chilled at night. Named after the Roman messenger god, Mercury shuttles between extremes, delivering cosmic messages across the solar system. Venus: Earth's Fiery Twin Venus, Earth's twin sister, hides behind thick clouds of sulfuric acid. Its surface resembles a volcanic inferno, with temperatures hot enough to melt lead. Yet, its beauty lies in its radiant glow - the Morning and Evening Star - illuminating our dawn and dusk. Earth: Our Blue Gem Earth, our precious home, teems with life. Its oceans, forests, and deserts form a delicate biosphere. From the icy poles to the equatorial rainforests, Earth's diverse climates sustain a symphony of ecosystems. We are its guardians, entrusted with its care. Mars: The Red Planet's Mysteries Mars, the Red Planet, beckons explorers. Its rusty surface bears ancient river valleys and polar ice caps. Could Mars harbor hidden reservoirs of life? Robotic rovers traverse its deserts, seeking answers beneath its crimson skies. Jupiter: King of the Gas Giants Jupiter, the colossal gas giant, boasts a mesmerizing tapestry of bands and storms. Its Great Red Spot - a tempest larger than Earth - has raged for centuries. Jupiter's gravitational pull shapes the solar system, protecting inner planets from cosmic debris. 
Saturn: Jewel of the Rings Saturn, adorned with majestic rings, is a cosmic jewel. These icy hoops, composed of countless particles, create a celestial ballet. Saturn's moons - Titan, Enceladus, and others - beckon us to explore their icy landscapes. Uranus: The Original Ice Giant Uranus, tipped on its side, spins like a cosmic top. Its icy blue hue conceals turbulent storms. Uranus remains a mystery, awaiting further study by future missions. Neptune: The Farthest Wanderer Neptune, shrouded in azure clouds, is the outermost planet. Its winds whip at supersonic speeds, and its icy heart harbors storms that rival Jupiter's. Voyager 2, our interstellar traveler, captured Neptune's beauty as it sailed past. Conclusion: Cosmic Harmony Planets are cosmic harmonizers. Their gravitational dances sculpt orbits, stir tides, and guide comets. They remind us of our place in the grand cosmic orchestra. As we gaze at the night sky, let us cherish these celestial companions - the guardians of harmony. In the words of Carl Sagan, 'We are made of star-stuff.' Our existence echoes the cosmic rhythm, and planets are our celestial partners in this cosmic waltz. Pluto, once considered our ninth planet, now holds the title of a dwarf planet. The International Astronomical Union (IAU) made this reclassification in 2006. Pluto didn't meet one of the three criteria the IAU uses to define a full-sized planet: it has not cleared its neighboring region of other objects. Despite its demotion, Pluto remains a fascinating member of the Kuiper belt, a ring of bodies beyond Neptune's orbit. It is the ninth-largest and tenth-most-massive known object to directly orbit the Sun. Although smaller than Earth's moon, Pluto's icy and rocky composition continues to intrigue astronomers and stargazers alike. NASA's New Horizons mission is a remarkable endeavor that has expanded our understanding of the outer reaches of our solar system. Let's delve into the details of this pioneering spacecraft: Objective: New Horizons was designed to study the dwarf planet Pluto, its moons, and other objects in the Kuiper Belt. Launch Date: On January 19, 2006, New Horizons embarked on its epic journey. Spacecraft Mass: Weighing 1,054 pounds (478 kilograms), it carried a suite of scientific instruments. Mission Design and Management: The mission was led by NASA in collaboration with the Johns Hopkins University Applied Physics Laboratory (APL). Historic Flyby: On July 14, 2015, New Horizons made history by becoming the first spacecraft to explore Pluto up close. It captured stunning images of Pluto's diverse geological features, including its icy plains, rugged mountains, and frozen canyons. Moons of Pluto: During the flyby, New Horizons also studied Pluto's five moons, including the intriguing Charon. Arrokoth Flyby: In early 2019, New Horizons achieved another milestone by flying past Arrokoth (2014 MU69). Arrokoth is a Kuiper Belt Object, making it the most distant object ever explored up close. Kuiper Belt: This region extends from about 30 AU (near Neptune's orbit) to about 50 AU from the Sun. New Horizons ventured into this uncharted territory. New Horizons carried an impressive array of instruments, including: Ralph: A visible and infrared imager/spectrometer. Alice: An ultraviolet imaging spectrometer. Radio-Science Experiment (REX): Studied radio signals. Long-Range Reconnaissance Imager (LORRI): Captured high-resolution images. Solar Wind and Plasma Spectrometer (SWAP): Analyzed solar wind. 
Pluto Energetic Particle Spectrometer Science Investigation (PEPSSI): Studied particles around Pluto. Student Dust Counter (SDC): Measured dust impacts. New Horizons provided insights into Pluto's atmosphere, surface, and geology. It revealed icy mountains, glaciers, and mysterious dark regions. The spacecraft also observed Jupiter's moons (Io, Europa, and Ganymede) during its long journey. As of 2023, New Horizons continues to explore the outer solar system, contributing to our understanding of distant bodies. In summary, New Horizons has been a trailblazer, revealing the secrets of Pluto and venturing into the cosmic frontier. Its legacy inspires future missions and fuels our curiosity about the cosmos. ",
+    "8192": "Once upon a time, in a quaint little village nestled amidst rolling hills, there existed an old teapot. But this was no ordinary teapot; it was a magical one. Its handle curved just so, and its spout seemed to whisper secrets to the wind. The villagers called it 'Elara,' and they believed it held the power to grant wishes. Elara sat on the windowsill of Mrs. Abernathy's cozy cottage. Mrs. Abernathy was a kind-hearted woman with twinkling eyes and a penchant for herbal teas. She'd inherited the teapot from her grandmother, who, in turn, had received it from a mysterious traveler. One chilly evening, as the sun dipped below the horizon, Mrs. Abernathy brewed her favorite chamomile tea. She poured the fragrant liquid into Elara, and to her astonishment, the teapot began to glow. The room filled with a soft, golden light, and Mrs. Abernathy felt a tingle in her fingertips. 'Make a wish,' whispered Elara, its spout quivering. Mrs. Abernathy hesitated. She'd heard tales of wishes gone awry - of greedy desires leading to unintended consequences. But her heart yearned for something simple: a garden filled with blooming roses. So, she closed her eyes and wished for just that. The next morning, Mrs. Abernathy stepped outside, and her breath caught. The air smelled of roses - sweet and heady. But when she looked around, she gasped. Her modest garden had transformed into a riot of colors. Roses of every hue - crimson, ivory, apricot - bloomed in profusion. They climbed the walls, twined around the picket fence, and even spilled onto the cobblestone path. Word spread throughout the village, and soon everyone wanted a turn with Elara. The baker wished for the perfect sourdough loaf, and it appeared in his oven. The blacksmith wished for strength, and his arms bulged with newfound muscle. The schoolteacher wished for wisdom, and her lectures became captivating tales. But as wishes multiplied, so did the consequences. The baker's sourdough grew sentient and demanded to be called 'Doughbert.' The blacksmith's strength made him accidentally crush his anvil. And the schoolteacher's wisdom led her to question the very fabric of reality. Mrs. Abernathy watched with a mix of amusement and concern. Elara seemed to thrive on granting wishes, but its porcelain surface bore faint cracks. Was it growing weaker? One day, a young girl named Lily approached Elara. Her eyes sparkled with innocence, and she clutched a dandelion in her hand. 'Teapot,' she said, 'I wish for a friend.' Elara hesitated. It sensed the purity of Lily's heart, but it also knew the weight of loneliness. With a shudder, it granted the wish. And so, Lily's dandelion transformed into a giggling sprite named Petal. They danced through meadows, shared secrets, and became inseparable. Elara's cracks deepened, but it didn't mind. As seasons passed, Mrs. Abernathy sat by the window, watching Elara fade. Yet, she felt no regret. For in granting wishes, the teapot had found purpose. And perhaps, just perhaps, it had one final wish left - to be remembered. And so, when Mrs. Abernathy's time came, she whispered to Elara, 'Thank you.' The teapot glowed one last time, and Mrs. Abernathy drifted away, leaving behind a garden of roses and a village full of stories. And that, my dear reader, is how the enchanted teapot became a legend - a vessel of magic, love, and wishes granted with a fragile heart. As the seasons changed, so did the village. The once-sleepy hamlet now buzzed with visitors from distant lands. 
They came seeking Elara, the legendary teapot that granted wishes. Some sought riches, others fame, but most yearned for something deeper - a connection to the mystical. Among the newcomers was a weary traveler named Ezra. His cloak was tattered, and his boots bore the marks of countless miles. He'd heard whispers of Elara's magic and hoped it could mend his broken heart. For Ezra had lost his beloved, and grief weighed upon him like an anchor. Mrs. Abernathy, now an old woman with silver hair, welcomed Ezra into her cottage. Elara sat on the windowsill, its porcelain surface etched with memories. Mrs. Abernathy poured chamomile tea into the teapot, and it glowed faintly, as if recognizing an old friend. 'Make a wish,' Mrs. Abernathy said, her voice soft. Ezra hesitated. His wish was simple yet profound: to see his love once more, if only in a dream. He closed his eyes and whispered, 'I wish for a single night with her.' Elara trembled, its spout quivering. It understood the ache of lost love - the longing that transcended time. And so, it granted Ezra's wish. That night, as the moon hung low in the sky, Ezra lay on Mrs. Abernathy's creaky bed. Elara sat beside him, its glow illuminating the room. He drifted into slumber, and there, in the realm between wakefulness and dreams, he found himself in a moonlit garden. His love, Isolde, stood before him. Her eyes were the color of forget-me-nots, and her laughter echoed like wind chimes. They danced beneath a silver canopy, twirling through memories - their first kiss, stolen moments by the river, promises whispered under ancient oaks. But dreams are fragile, and dawn approached. Isolde's form wavered, and Ezra clung to her. 'Stay,' he pleaded. 'Just a little longer.' Isolde smiled, her touch like a butterfly's kiss. 'Time bends here,' she said. 'But you must wake, my love.' As the sun peeked over the horizon, Ezra opened his eyes. Elara sat on the windowsill, its glow fading. Mrs. Abernathy watched him, her gaze knowing. 'Did you see her?' she asked. Ezra nodded, tears glistening. 'She was real, Mrs. Abernathy. I held her again.' The village marveled at Ezra's tale - the man who danced with a ghost. They flocked to Elara, each with their wishes. The blacksmith wished for forgiveness, the baker for inspiration, and the schoolteacher for courage. Elara obliged, its cracks deepening, but it never complained. One day, as winter painted the landscape white, Mrs. Abernathy grew frail. She called Ezra to her bedside. 'Elara's magic wanes,' she whispered. 'But it has one final wish.' Ezra knelt beside her. 'What is it?' 'Take Elara beyond the hills,' Mrs. Abernathy said. 'To the ancient oak where Isolde and I carved our initials. There, bury the teapot. It will become part of the earth, and its magic will seep into the roots.' And so, on a frost-kissed morning, Ezra carried Elara to the oak. He dug a small hole, placed the teapot inside, and covered it with soil. As he patted the ground, he felt a tremor - a farewell. The next spring, the oak bloomed with roses - crimson, ivory, apricot. And in its shade, a dandelion sprouted. Its petals glowed like moonlight, and when the wind whispered, it carried echoes of laughter. Ezra knew then that Elara's wish had come true. It had become part of the land, woven into the fabric of stories. And perhaps, just perhaps, it still listened, granting silent wishes to those who believed. And so, the legend of Elara lived on - a teapot turned earth, a vessel of love, and a bridge between worlds. 
In the heart of the Whispering Forest, where ancient trees leaned close and their leaves murmured secrets, lived a young girl named Evelyn. She had eyes the color of moss and hair that tangled like wild vines. Evelyn was no ordinary child; she could hear the forest's whispers - the soft rustle of leaves, the creaking of branches, and the laughter of unseen creatures. The villagers feared the Whispering Forest. They said it was cursed - a place where time flowed differently, where shadows danced with mischief, and where lost souls wandered forever. But Evelyn felt drawn to its heart. She believed the forest held answers - about her missing parents, about the world beyond the village. One moonlit night, when the forest beckoned with silver fingers, Evelyn slipped away from her tiny cottage. She wore a cloak spun from spider silk and carried a lantern that glowed like a captured star. The trees leaned in, their bark etched with ancient runes. They whispered her name - Evelyn, Evelyn - as if they knew her purpose. Deeper she ventured, past gnarled roots and dew-kissed ferns. The air smelled of moss and memories. The lantern's light flickered, casting eerie shadows on the forest floor. And then, she heard it - the melody of the Whispering Forest. It was a haunting tune, sung by unseen lips, and it tugged at her heart. 'Who are you?' Evelyn whispered. The forest answered - a chorus of voices, overlapping and harmonizing. 'We are the echoes of forgotten dreams, the guardians of lost paths. Seek what you desire, but beware the price.' Evelyn pressed on. She reached a clearing where moonflowers bloomed - a sea of pale petals that glowed like fallen stars. In their midst stood a stone pedestal, and atop it rested a silver key. It was unlike any key she'd seen - twisted and delicate, with a single emerald set in its bow. The whispers intensified. 'Take the key,' they urged. 'Unlock the door to your destiny.' Evelyn hesitated. What door? What destiny? She thought of her parents - their laughter, their scent of pine and adventure. They'd vanished when she was a baby, leaving only a crumpled map with cryptic symbols. With trembling fingers, she picked up the key. It felt warm, alive. And then, she saw it - a door, half-hidden behind an ancient oak. Its wood was etched with constellations, and its handle bore the same emerald as the key. Evelyn inserted the key into the lock. The door groaned open, revealing a tunnel - a ribbon of darkness that wound deeper into the forest. The whispers grew urgent. 'Step through, Evelyn. Find your truth.' She stepped into the tunnel, and the world shifted. Time blurred, and she glimpsed her parents - laughing, dancing, fading like smoke. The tunnel led to a chamber - a celestial cavern where stars swirled in liquid patterns. And there, on a stone pedestal, lay a crystal vial. The whispers crescendoed. 'Drink,' they urged. 'Remember.' Evelyn uncorked the vial. Memories flooded her - the scent of pine, her parents' laughter, the taste of adventure. Tears blurred her vision. She drank, and the forest embraced her - a cocoon of whispers, of love, of belonging. When Evelyn emerged, the Whispering Forest had changed. It no longer whispered of curses but sang of hope. She carried her parents' memories - their legacy - and vowed to protect the forest's secrets. And so, Evelyn became the new guardian. She tended the moonflowers, listened to the trees, and sang the haunting melody. The villagers no longer feared the forest; they sought its solace, its magic. 
And every night, as the moon rose, Evelyn stood by the ancient oak. She whispered her parents' names, and the forest whispered back - a lullaby woven from stardust and love. Beyond the Whispering Forest, where the moonflowers bloomed and the stars whispered secrets, lay a forgotten path. It was a narrow trail, overgrown with moss and guarded by ancient stones. Few dared to tread there, for it led to the Compass Grove. Lysander, a young cartographer with ink-stained fingers and a heart full of wanderlust, stumbled upon this path one misty morning. His boots sank into damp earth, and the air smelled of pine and possibility. He carried a tattered map - a relic passed down through generations. Its edges bore cryptic symbols, and its center held a blank space - an uncharted territory. The Compass Grove was said to house a mystical compass - the Wayfinder's Compass - forged by the first explorers. It was no ordinary instrument; it pointed not to north, but to one's true desire. Legends whispered that whoever held the compass could navigate not only the physical world but also the labyrinth of their own heart. Lysander's pulse quickened. He yearned for adventure - to map uncharted lands, to unravel mysteries. His parents had vanished during an expedition, leaving behind a single clue: the blank space on the map. Perhaps the Compass Grove held answers. As he pushed through brambles and ferns, the forest seemed to guide him. Sunlight filtered through leaves, casting dappled patterns on the ground. And then, he saw it - a circle of ancient stones, their surfaces etched with symbols. At the center stood a pedestal, and atop it rested the Wayfinder's Compass. Lysander's breath caught. The compass was unlike any he'd seen. Its needle shimmered like a captured star, and its dial bore not cardinal directions but enigmatic words: Dreams, Regret, Destiny, and Hope. He touched the compass, and it hummed - a vibration that resonated in his bones. The whispers began - the voices of long-lost explorers, of forgotten dreams. 'Choose,' they urged. 'Choose your path.' Lysander hesitated. Dreams? Regret? Destiny? Hope? Each word held a promise, a peril. He thought of his parents - their laughter, their courage. He thought of the blank space on the map - the uncharted territory that beckoned. And so, he turned the dial to Dreams. The needle quivered, then settled - a path leading deeper into the forest. Lysander followed, lantern in hand, heart pounding. The compass guided him past silver streams and ancient oaks. It led him to a hidden waterfall - a curtain of moonlight that shimmered like stardust. There, he glimpsed a figure - a woman with eyes like forgotten constellations. She wore a cloak spun from spider silk, and her hair flowed like a river. 'Lysander,' she said, her voice a melody. 'You seek dreams.' He nodded. 'I seek answers. About my parents.' The woman touched his forehead, and memories flooded him - the scent of pine, his parents' laughter, the taste of adventure. 'Dreams are maps,' she said. 'They guide us beyond what we see.' Lysander understood. Dreams were compasses of the soul. His parents had followed theirs, and now he would follow his. He stepped through the waterfall, and the world shifted. He found himself on a cliff overlooking a vast sea - a sea of blank parchment. Islands floated in the distance, waiting to be charted. Lysander unrolled his map - the one with the blank space - and dipped his quill. He drew coastlines, marked mountains, and named each land. 
And as he mapped, the compass glowed - a beacon of dreams fulfilled. Lysander knew then that he was not merely a cartographer; he was a dreamweaver. His parents' legacy flowed through him - their courage, their laughter, their love. And so, Lysander sailed the uncharted seas, guided by the Wayfinder's Compass. He discovered islands of forgotten myths, forests of whispered tales, and cities where stars danced in the streets. He wrote his own story - a cartography of dreams. And in the Compass Grove, the ancient stones whispered his name - Lysander, Lysander - as if they knew he'd found his true north. In the heart of the city, where cobblestone streets wound like forgotten memories, stood an abandoned mansion. Its windows were boarded up, and ivy clung to its crumbling walls. But within those decaying walls lay a secret - a clockwork garden. Evelyn, a curious girl with eyes like rain-kissed petals, discovered the mansion one rainy afternoon. She wore mismatched socks and carried a notebook filled with sketches - a testament to her love for hidden wonders. The mansion's gate creaked open, and Evelyn stepped into a world frozen in time. The clockwork garden was unlike any other. Its flowers were made of gears and springs, their petals unfolding with precise clicks. The roses ticked, the daffodils whirred, and the tulips chimed. And at the center stood a colossal mechanical tree - its branches reaching toward the sky, its leaves spinning like miniature windmills. Evelyn gasped. She'd read about clockwork wonders - the automatons that danced at royal balls, the pocket watches that whispered secrets. But this garden was alive - a symphony of metal and magic. As she explored, she noticed a silver key embedded in the tree's trunk. It gleamed, beckoning her. Evelyn hesitated. What did the key unlock? And why had the clockwork garden been abandoned? The flowers seemed to whisper. 'Unlock the tree,' they urged. 'Discover its heart.' Evelyn turned the key. The tree shuddered, and its branches parted, revealing a hidden chamber. Inside, a mechanical heart pulsed - a delicate contraption of brass and crystal. It hummed, resonating with the rhythm of forgotten time. And then, she heard it - the voice of the tree. 'I am Chronos,' it said. 'Guardian of moments.' Evelyn's heart raced. 'Moments?' 'Every petal, every leaf,' Chronos explained. 'They hold memories - the laughter of lovers, the tears of parting, the whispers of dreams. But time has fractured. The clockwork garden is frozen, and I am fading.' Evelyn understood. The mansion's former owner - a clockmaker named Lysander - had built this garden to capture fleeting moments. But Lysander had vanished, leaving Chronos incomplete. 'I can mend you,' Evelyn said. 'But why was the garden abandoned?' Chronos sighed - a sound like winding gears. 'Lysander sought eternity. He believed that by freezing time, he could preserve love, prevent loss. But he forgot that life thrives in impermanence.' Evelyn touched the mechanical heart. 'Can we fix it?' Chronos nodded. 'You must find Lysander's final creation - the Celestial Gear. It lies beyond the city, where the river meets the stars.' And so, Evelyn embarked on her quest. She followed the river, past moonlit bridges and forgotten docks. The Celestial Gear awaited - a constellation of interlocking wheels, its center a pulsing light. As she placed the gear in Chronos's heart, the clockwork garden stirred. Flowers bloomed, petals unfurling with joy. 
The mechanical tree's leaves spun faster, and time flowed once more. But Chronos grew weaker. 'I am bound to this place,' it said. 'My purpose fulfilled.' Evelyn wept. 'Can't you come with me?' Chronos smiled - a clockwork smile. 'I am part of the garden now. But you, dear Evelyn, carry its memory.' And so, she returned to the mansion, where the clockwork garden thrived. She sketched its wonders, capturing gears and petals on paper. And when she closed her eyes, she heard the whispers - the laughter of lovers, the tears of parting, the echoes of dreams. Evelyn became the new guardian. She tended the flowers, wound the tree, and listened to Chronos's fading heartbeat. And every night, as the stars wheeled overhead, she whispered her thanks. For in the heart of the clockwork garden, time danced - a fragile waltz of moments, preserved and cherished. In the heart of the Astronomer's Quarter, where cobblestone streets wound like celestial paths, stood an ancient observatory. Its domed roof bore the scars of countless meteor showers, and its telescopes whispered secrets to the night sky. But within those hallowed walls lay a mystery - a forgotten constellation. Aria, a young stargazer with eyes like distant galaxies, discovered the observatory one moonless night. She wore a cloak spun from stardust and carried a pocket-sized atlas - a testament to her love for the heavens. The observatory's door creaked open, and Aria stepped into a world woven with cosmic threads. The forgotten constellation was unlike any other. Its stars were elusive, their positions shifting with each passing century. Astronomers had once mapped it - a celestial tapestry of myth and memory - but over time, its name faded, its stories lost. As Aria explored, she noticed a silver quill resting on an ancient star chart. Its nib gleamed, beckoning her. Aria hesitated. What secrets did the quill hold? And why had the forgotten constellation slipped from memory? The stars seemed to whisper. 'Write,' they urged. 'Illuminate the night.' Aria dipped the quill in ink. The constellations above shifted - a celestial dance awaiting completion. She traced the forgotten lines - the Hunter's Bow, the Weaver's Loom, the Lost Lyre. And then, she saw it - a gap in the sky, a void where a constellation once blazed. The quill hummed - a vibration that resonated in her bones. The whispers intensified. 'Remember,' they urged. 'Remember the story.' And so, Aria wrote - a tale woven from stardust and longing. She penned the forgotten constellation's name: Lyra's Veil. Its stars had once guided lovers across oceans, inspired poets to verses, and cradled dreams in their luminous arms. But Lyra's Veil had vanished - a casualty of time's relentless march. Its stories faded, its purpose lost. Aria vowed to restore it - to stitch the celestial fabric, thread by thread. She climbed to the observatory's rooftop, where telescopes pointed toward infinity. Aria gazed at the sky, her breath mingling with the Milky Way. And there, in the gap, she saw it - the faint glimmer of Lyra's Veil. The quill guided her. She drew the missing lines - the Weaver's Loom reconnected, the Lost Lyre's melody restored. And as she wrote, the stars responded. Lyra's Veil emerged - a constellation reborn. But Aria felt a pull - a cosmic yearning. She touched the quill to her heart, and memories flooded her - the scent of stardust, her grandmother's bedtime stories, the taste of wonder. 'Guard it,' whispered the stars. 'Guard Lyra's Veil.' And so, Aria became the new guardian. 
She tended the observatory, charted the skies, and whispered the forgotten stories. The astronomers marveled - the gap was gone, and Lyra's Veil blazed once more. But Aria knew her duty. She would write new tales - of love, of courage, of dreams stitched together. And every night, as the constellations wheeled overhead, she whispered her thanks. For in the heart of the forgotten constellation, time danced - a fragile waltz of memory, preserved and cherished. In the heart of the bustling city, where skyscrapers touched the clouds and neon signs flickered like distant stars, lived a forgotten runner named Evelyn. She wasn't famous like the sprinters on billboards or the marathon champions with their gleaming medals. No, Evelyn was an ordinary woman who ran for the sheer joy of it. Every morning, before the sun peeked over the horizon, Evelyn laced up her worn-out sneakers. She followed the same route - a loop around the park, past the fountain where pigeons bathed, and along the riverbank where willow trees whispered secrets. Her pace was steady, her breaths rhythmic. She ran not to win races but to escape the noise of life - to find solace in the rhythm of her footsteps. But the city had forgotten Evelyn. The sports channels didn't broadcast her runs, and the local newspapers didn't write about her achievements. She was a lone figure - a silhouette against the dawn, chasing dreams that no one else cared about. One chilly morning, as Evelyn jogged along the river, she noticed a poster taped to a lamppost. It announced the city's annual marathon - the grand event that drew elite athletes from around the world. Evelyn's heart skipped a beat. She'd never run a marathon, but the idea tugged at her like a distant constellation. She tore off the poster and studied it. The race would wind through the city's streets, past cheering crowds and historic landmarks. The finish line was the grand stadium - the same stadium where she'd watched her heroes cross the tape, their names echoing through the loudspeakers. Evelyn hesitated. She wasn't a professional runner. She didn't have a coach or a team. But something stirred within her - a longing to be part of the marathon, to leave her mark on the city she loved. And so, she trained. She woke earlier, ran farther, and pushed her limits. She practiced pacing, fueled by oatmeal and determination. The other runners didn't notice her - a middle-aged woman with graying hair - but Evelyn didn't mind. She was a comet streaking through the pre-dawn darkness, fueled by her own quiet fire. On marathon day, the city buzzed with excitement. The streets were lined with spectators - families with homemade signs, old couples in folding chairs, children waving tiny flags. The elite runners surged ahead, their strides effortless. But Evelyn was in the middle of the pack - a forgotten runner among thousands. As she crossed each mile marker, Evelyn felt a surge of pride. She wasn't breaking records, but she was breaking barriers - the ones she'd built around herself. The cheers of the crowd fueled her - their encouragement like solar winds pushing her forward. And then, at mile 20, exhaustion hit. Evelyn's legs wobbled, her breaths came in ragged gasps. She glanced at the grand stadium - the finish line shimmering like a distant galaxy. But her body rebelled. She wanted to collapse, to fade into anonymity. And that's when she saw him - a young boy with a crumpled sign. It read, 'Go, Evelyn! You're not forgotten.' Tears blurred her vision. 
She pushed through the pain, her heartbeat a metronome of determination. As Evelyn crossed the finish line, the crowd erupted. The loudspeakers blared her name - Evelyn, Evelyn - and the forgotten runner became a star. She collapsed into the arms of a volunteer, her legs trembling. But she'd done it. She'd run the marathon - the one that mattered to her. The newspapers wrote about her - the woman who defied odds, who ran not for glory but for love. And the city remembered Evelyn - the forgotten runner who'd become a constellation, lighting the way for others. Lysander stood at the finish line of the marathon, his chest heaving, sweat-soaked shirt clinging to his skin. The stadium roared - a symphony of applause and encouragement. But amidst the cheers, he felt a void - an ache that no medal could fill. He'd run the race - the one that mattered to him. Yet, as he caught his breath, Lysander wondered about the blank space on his map. The uncharted territory - the reason his parents had vanished - still haunted him. A shadow fell across the track. It was Evelyn, the forgotten runner. Her eyes sparkled with determination, and her worn-out sneakers bore the marks of countless miles. She'd finished the marathon too, her name echoing through the loudspeakers. 'Evelyn,' Lysander said, his voice hoarse. 'Why do we run?' She leaned against the railing, gazing at the city beyond. 'For the same reason we map,' she replied. 'To find what's lost.' Lysander nodded. 'The Compass Grove,' he said. 'The Wayfinder's Compass.' Evelyn's eyes widened. 'You know of it?' He traced the blank space on his map - the gap where the forgotten constellation should be. 'My parents sought it,' Lysander confessed. 'They believed it held answers - about time, about destiny.' Evelyn's fingers brushed the silver quill in her pocket. 'And did they find it?' He shook his head. 'They vanished. But I won't stop searching.' Together, they left the stadium - the forgotten runner and the cartographer. They followed the same path - the one that led beyond the city, into the Whispering Forest. The compass guided them - the needle pointing not to north, but to dreams. As they reached the ancient stones of the Compass Grove, Evelyn gasped. 'Look,' she said, her voice hushed. There, etched into the stones, were symbols - the Weaver's Loom, the Lost Lyre, and the Hunter's Bow. And at the center stood the pedestal - the Wayfinder's Compass. Lysander touched it - the needle quivering. 'What do we seek?' he asked. Evelyn's eyes held galaxies. 'Not just answers,' she said. 'But connection - to the forgotten, to each other.' And so, they turned the dial - to Hope. The compass hummed, and the forest whispered. A path opened - a ribbon of moonlight leading deeper. They stepped through, and the world shifted. Stars swirled - a celestial dance. And there, in the gap, they saw it - the forgotten constellation. Lyra's Veil blazed - a tapestry of memories, stitched by stardust. Its stars guided lovers, inspired poets, and cradled dreams. Lysander and Evelyn held hands - the cartographer and the runner. They traced the lines - the Weaver's Loom reconnected, the Lost Lyre's melody restored. And as they gazed at Lyra's Veil, they felt it - a cosmic yearning. Not for fame or medals, but for eternity - the kind woven into forgotten constellations. Together, they whispered their thanks - to the stars, to the forest, to each other. In the small town of Maplewood, basketball was more than a game - it was a way of life. 
The local high school gym, with its creaky wooden floors and flickering lights, held memories etched into the hearts of generations. Tommy Reynolds, a lanky teenager with dreams as big as the full moon, had grown up shooting hoops in that gym. His father, a former basketball star, had taught him the art of the game - the perfect arc of a jump shot, the rhythm of dribbling, and the magic of teamwork. But Tommy wasn't like his father. He lacked the height and the natural talent. Still, he practiced tirelessly, his sneakers squeaking on the polished floor. He'd stare at the faded championship banners hanging from the rafters - the ones his father had helped win - and imagine his own name there someday. Senior year arrived, and Tommy made the varsity team. He wasn't a star player, but he hustled, diving for loose balls and setting screens. The crowd cheered louder for the flashy slam dunks, but Tommy's heart beat for the fundamentals - the bounce pass, the defensive stance, the pick-and-roll. The state championship game loomed - a David-and-Goliath matchup against the undefeated Oakwood Tigers. They had a towering center, a lightning-fast point guard, and a reputation for crushing opponents. Maplewood was the underdog, the team with heart but not much else. As the final seconds ticked away, the score was tied. Tommy stood at center court, sweat dripping down his face. The gym seemed to hold its breath. He glanced at the banners - the ghosts of champions past urging him on. The ball found its way to Tommy. He dribbled, eyes scanning the court. His father's voice echoed in his mind: 'Trust your instincts, son.' He drove toward the basket, the Tigers' defense closing in. But instead of taking the shot, Tommy passed - the perfect bounce pass to his teammate, Danny. Danny leaped, releasing the ball just as the buzzer sounded. The gym erupted. The ball swirled through the net - a miracle shot that defied physics. Maplewood had won - the underdogs had toppled the giants. Tommy's teammates lifted him on their shoulders. The crowd chanted his name. But as he glanced at the banners, he knew the truth. It wasn't just his shot - it was the culmination of every bounce pass, every defensive stance, every pick-and-roll. His father hugged him - a rare display of emotion. 'You did it, Tommy,' he whispered. 'You made your mark.' And there, in the glow of victory, Tommy realized that sometimes the greatest miracles happen at center court - not in the spotlight, but in the quiet moments of practice, persistence, and heart."
+}
diff --git a/onnxruntime/python/tools/transformers/models/llama/quant_kv_dataloader.py b/onnxruntime/python/tools/transformers/models/llama/quant_kv_dataloader.py
index e8b563261001b..33084aec214c2 100644
--- a/onnxruntime/python/tools/transformers/models/llama/quant_kv_dataloader.py
+++ b/onnxruntime/python/tools/transformers/models/llama/quant_kv_dataloader.py
@@ -1,3 +1,8 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation.  All rights reserved.
+# Licensed under the MIT License.  See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
 import argparse
 
 import numpy as np
diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt
index acd9c23aa42d0..307afbc122901 100644
--- a/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt
+++ b/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt
@@ -2,3 +2,4 @@
 # Please manually install torch>=2.2.0 with CUDA enabled for the CUDA version installed in your system.
 # Instructions can be found here: https://pytorch.org/get-started/locally/
 onnxruntime-gpu>=1.16.2
+py3nvml
\ No newline at end of file
diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements.txt b/onnxruntime/python/tools/transformers/models/llama/requirements.txt
index 8b57279295e35..e991c2f27a1a3 100644
--- a/onnxruntime/python/tools/transformers/models/llama/requirements.txt
+++ b/onnxruntime/python/tools/transformers/models/llama/requirements.txt
@@ -1,6 +1,7 @@
 optimum>=1.14.1
-transformers>=4.33.2
+transformers>=4.33.2,<= 4.37.2
 torch>=2.2.0
 onnx>=1.14.0
 datasets>=2.8.0
 protobuf==3.20.2
+psutil
\ No newline at end of file
diff --git a/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py b/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py
index 51a967cf22608..ab92a12343732 100644
--- a/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py
+++ b/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py
@@ -335,7 +335,7 @@ def test_ort(args, device) -> List[Dict[str, Any]]:
 
     onnx_model_path = find_onnx_model(model_name) if not args.onnx else args.onnx
 
-    optimized = onnx_model_path.endswith("_fp16.onnx") or onnx_model_path.endswith("_fp32.onnx")
+    optimized = onnx_model_path.endswith("_fp16.onnx") or onnx_model_path.endswith("_fp32.onnx")  # noqa: PIE810
     precision = "fp32" if not onnx_model_path.endswith("_fp16.onnx") else "fp16"
 
     model = load_torch_model(model_name, device)
@@ -590,7 +590,7 @@ def run_tests(
     logger.info(f"ORT_LONGFORMER_COMPACT_MEMORY={compact_memory}")
 
     os.environ["ORT_LONGFORMER_USE_HALF4"] = "1" if use_half4 else "0"
-    logger.info("ORT_LONGFORMER_USE_HALF4={}".format("1" if use_half4 else "0"))
+    logger.info("ORT_LONGFORMER_USE_HALF4={}".format("1" if use_half4 else "0"))  # noqa: G001
 
     results = []
     test_times = 1000
diff --git a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py
index b7881d064067d..8083778423241 100644
--- a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py
@@ -13,6 +13,7 @@
 import torch
 from benchmark_helper import Precision
 from fusion_options import AttentionOpType
+from onnx_model import OnnxModel
 from transformers import AutoConfig, AutoModelForCausalLM
 
 from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer
@@ -138,9 +139,6 @@ def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str):
             # We keep last three layers of Attention as float32 or bfloat16 to avoid overflow.
             node_block_list = (
                 [
-                    "GroupQueryAttention_29",
-                    "GroupQueryAttention_30",
-                    "GroupQueryAttention_31",
                     "Attention_29",
                     "Attention_30",
                     "Attention_31",
@@ -171,6 +169,58 @@ def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str):
             quant.process()
             quant.model.save_model_to_file(onnx_path_opt, use_external_data_format=True)
 
+    # This function currently only works for the phi2 model
+    def convert_to_use_cuda_graph(self, in_onnx_path: str, out_onnx_path: str):
+        onnx_model = OnnxModel(onnx.load(in_onnx_path, load_external_data=True))
+
+        from onnx import TensorProto, helper
+
+        graph = onnx_model.graph()
+        new_inputs = []
+        for vi in graph.input:
+            if "attention_mask" in vi.name:
+                vi_seqlen_k = helper.make_tensor_value_info(
+                    "seqlens_k",
+                    elem_type=TensorProto.INT32,
+                    shape=["batch_size"],
+                )
+                vi_total_seq_len = helper.make_tensor_value_info(
+                    "total_sequence_length",
+                    elem_type=TensorProto.INT32,
+                    shape=[1],
+                )
+                new_inputs.extend([vi_seqlen_k, vi_total_seq_len])
+            else:
+                new_inputs.append(vi)
+
+        graph.ClearField("input")
+        graph.input.extend(new_inputs)
+
+        gqas = onnx_model.get_nodes_by_op_type("GroupQueryAttention")
+        gqa = gqas[0]
+        seqlens_path = onnx_model.match_parent_path(
+            gqa,
+            ["Cast", "Sub", "ReduceSum", "Cast"],
+            [5, 0, 0, 0],
+        )
+        if seqlens_path is None:
+            raise RuntimeError("Failed to find seqlens path for GroupQueryAttention node.")
+        total_seq_len_path = onnx_model.match_parent_path(
+            gqa,
+            ["Cast", "Gather", "Shape"],
+            [6, 0, 0],
+        )
+        if total_seq_len_path is None:
+            raise RuntimeError("Failed to find total_seq_len path for GroupQueryAttention node.")
+        onnx_model.remove_nodes(seqlens_path)
+        onnx_model.remove_nodes(total_seq_len_path)
+
+        for gqa in gqas:
+            gqa.input[5] = "seqlens_k"
+            gqa.input[6] = "total_sequence_length"
+
+        onnx_model.save(onnx_model.model, out_onnx_path, save_as_external_data=True)
+
 
 def parse_arguments():
     parser = argparse.ArgumentParser()
@@ -238,6 +288,13 @@ def parse_arguments():
         help="Generate int4 ONNX model for ORT VLLM",
     )
 
+    parser.add_argument(
+        "--use_cuda_graph",
+        required=False,
+        action="store_true",
+        help="Use CUDA Graph in decoding process",
+    )
+
     parser.add_argument(
         "--overwrite",
         required=False,
@@ -268,6 +325,13 @@ def parse_arguments():
         help="Run ORT inference example",
     )
 
+    parser.add_argument(
+        "--run_benchmark",
+        required=False,
+        action="store_true",
+        help="Run ORT benchmark",
+    )
+
     parser.add_argument(
         "--skip_export",
         required=False,
@@ -378,6 +442,9 @@ def run_optimize_phi2_onnx(
         ):
             converter.init_attn_type_and_precision(attention_type, precision)
             converter.optimize_phi2_onnx(original_onnx_path, optimized_onnx_path)
+            if args.use_cuda_graph:
+                assert args.fp16_gpu_sm8x or args.int4_gpu_sm8x
+                converter.convert_to_use_cuda_graph(optimized_onnx_path, optimized_onnx_path)
 
         processes = []
         if args.fp32_cpu:
@@ -450,7 +517,7 @@ def run_optimize_phi2_onnx(
         [p.start() for p in processes]
         [p.join() for p in processes]
 
-    if args.run_example:
+    if args.run_example or args.run_benchmark:
         from inference_example import run_phi2
 
         if args.fp16_gpu_sm8x:
@@ -460,6 +527,8 @@ def run_optimize_phi2_onnx(
                 use_buffer_share=True,
                 device_id=args.device_id,
                 use_step=True,
+                use_cuda_graph=args.use_cuda_graph,
+                run_benchmark=args.run_benchmark,
             )
         if args.int4_gpu_sm8x:
             logging.info("Running int4_gpu_sm8x example...")
@@ -468,6 +537,8 @@ def run_optimize_phi2_onnx(
                 use_buffer_share=True,
                 device_id=args.device_id,
                 use_step=True,
+                use_cuda_graph=args.use_cuda_graph,
+                run_benchmark=args.run_benchmark,
             )
         if args.fp32_gpu:
             logging.info("Running fp32_gpu example...")
@@ -477,6 +548,7 @@ def run_optimize_phi2_onnx(
                 device_id=args.device_id,
                 packed_kv=True,
                 use_fp16=False,
+                run_benchmark=args.run_benchmark,
             )
         if args.fp16_gpu:
             logging.info("Running fp16_gpu example...")
@@ -485,6 +557,7 @@ def run_optimize_phi2_onnx(
                 use_buffer_share=False,
                 device_id=args.device_id,
                 packed_kv=True,
+                run_benchmark=args.run_benchmark,
             )
         if args.int4_gpu:
             logging.info("Running int4_gpu example...")
@@ -493,6 +566,7 @@ def run_optimize_phi2_onnx(
                 use_buffer_share=False,
                 device_id=args.device_id,
                 packed_kv=True,
+                run_benchmark=args.run_benchmark,
             )
         if args.fp32_cpu or args.int4_cpu or args.fp16_vllm or args.int4_vllm:
             raise NotImplementedError("CPU/vllm inference example is not implemented yet.")
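The `convert_to_use_cuda_graph` step added above swaps the model's `attention_mask` graph input for explicit `seqlens_k` and `total_sequence_length` inputs, removes the mask-derived subgraphs, and repoints the GroupQueryAttention nodes at the new inputs. As a minimal sketch of just the graph-input rewrite with the stock `onnx` API (hypothetical model paths, external-data handling omitted):

```python
# Sketch of the graph-input rewrite performed by convert_to_use_cuda_graph above.
# Paths are placeholders; the real converter also rewires the GroupQueryAttention
# nodes and removes the subgraphs that previously derived these values from the mask.
import onnx
from onnx import TensorProto, helper

model = onnx.load("phi2_optimized.onnx")  # hypothetical path
graph = model.graph

new_inputs = []
for vi in graph.input:
    if "attention_mask" in vi.name:
        # Replace the mask with the two inputs GQA expects when CUDA graphs are enabled.
        new_inputs.append(helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, ["batch_size"]))
        new_inputs.append(helper.make_tensor_value_info("total_sequence_length", TensorProto.INT32, [1]))
    else:
        new_inputs.append(vi)

graph.ClearField("input")
graph.input.extend(new_inputs)
onnx.save(model, "phi2_cuda_graph.onnx")  # hypothetical output path
```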
diff --git a/onnxruntime/python/tools/transformers/models/phi2/inference_example.py b/onnxruntime/python/tools/transformers/models/phi2/inference_example.py
index 28828ffb853cb..eb66533f00834 100644
--- a/onnxruntime/python/tools/transformers/models/phi2/inference_example.py
+++ b/onnxruntime/python/tools/transformers/models/phi2/inference_example.py
@@ -17,6 +17,17 @@
 }
 
 
+def cuda_memcpy(dst, src):
+    from cuda import cudart
+
+    cudart.cudaMemcpy(
+        dst.data_ptr(),
+        src.data_ptr(),
+        src.element_size() * src.nelement(),
+        cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice,
+    )
+
+
 class ORTGenerator:
     def __init__(self, decoder_path):
         self.onnx_decoder_path = decoder_path
@@ -24,13 +35,68 @@ def __init__(self, decoder_path):
         self.head_size = 80
         self.num_layers = 32
         self.max_sequence_length = 2048
+        self.device_id = 0
+        self.use_cuda_graph = False
+        self.use_traced_inputs = False
+        self.static_inputs_map = {}
+
+    def append_static_inputs(self, batch_size):
+        # Only use this function with GQA and with use_cuda_graph=True
+        if batch_size in self.static_inputs_map:
+            return
+
+        cpu_device = torch.device("cpu")
+        cuda_device = torch.device("cuda", self.device_id)
+
+        static_io = {}
+        static_io["input_ids"] = torch.zeros((batch_size, 1), dtype=torch.int32, device=cuda_device)
+        static_io["step"] = torch.tensor([0], dtype=torch.int64, device=cuda_device)
+        static_io["seqlens_k"] = torch.tensor(batch_size * [0], dtype=torch.int32, device=cuda_device)
+        static_io["total_sequence_length"] = torch.tensor([0], dtype=torch.int32, device=cpu_device)
+
+        cache_shape = (batch_size, self.num_heads, self.max_sequence_length, self.head_size)
+        for i in range(self.num_layers):
+            cache = torch.zeros(cache_shape, device=cuda_device, dtype=torch.float16)
+            static_io.update({f"past_key_{i}": cache.contiguous(), f"past_value_{i}": cache.clone().contiguous()})
+
+        static_io["logits"] = torch.zeros((batch_size, 1, 51200), dtype=torch.float16, device=cuda_device)
+
+        self.static_inputs_map[batch_size] = static_io
 
     def get_initial_inputs_and_outputs(self, encodings_dict):
         self.torch_dtype = torch.float16 if self.use_fp16 else torch.float32
 
         input_ids = torch.tensor(encodings_dict["input_ids"], device=self.device, dtype=torch.int32)
         attention_mask = torch.tensor(encodings_dict["attention_mask"], device=self.device, dtype=torch.int32)
-        step = torch.tensor([0], device=self.device, dtype=torch.int64)
+
+        batch_size, sequence_length = input_ids.shape
+
+        self.use_traced_inputs = (
+            self.use_cuda_graph
+            and (batch_size in self.static_inputs_map)
+            and self.use_buffer_share
+            and not self.packed_kv
+        )
+
+        step = (
+            torch.tensor([0], device=self.device, dtype=torch.int64)
+            if not self.use_traced_inputs
+            else self.static_inputs_map[batch_size]["step"]
+        )
+
+        seqlens_k = (
+            torch.tensor(batch_size * [0], device=self.device, dtype=torch.int32)
+            if not self.use_traced_inputs
+            else self.static_inputs_map[batch_size]["seqlens_k"]
+        )
+        cuda_memcpy(seqlens_k, attention_mask.sum(1).sub(1).to(torch.int32))
+
+        total_seq_length = (
+            torch.tensor([0], device=torch.device("cpu"), dtype=torch.int32)
+            if not self.use_traced_inputs
+            else self.static_inputs_map[batch_size]["total_sequence_length"]
+        )
+        total_seq_length[0] = sequence_length
 
         inputs = {
             "input_ids": input_ids.contiguous(),
@@ -40,7 +106,10 @@ def get_initial_inputs_and_outputs(self, encodings_dict):
         if self.use_step:
             inputs["step"] = step.contiguous()
 
-        batch_size, sequence_length = input_ids.shape
+        if self.use_cuda_graph:
+            inputs["seqlens_k"] = seqlens_k.contiguous()
+            inputs["total_sequence_length"] = total_seq_length.contiguous()
+            del inputs["attention_mask"]
 
         past_seq_length = self.max_sequence_length if self.use_buffer_share else 0
         past_shape = (
@@ -48,11 +117,23 @@ def get_initial_inputs_and_outputs(self, encodings_dict):
             if self.packed_kv
             else (batch_size, self.num_heads, past_seq_length, self.head_size)
         )
-        for i in range(self.num_layers):
-            past = torch.zeros(past_shape, device=self.device, dtype=self.torch_dtype)
-            inputs.update(
-                {f"past_key_{i}": past.contiguous(), f"past_value_{i}": past.clone().contiguous()}
-            ) if not self.packed_kv else inputs.update({f"past_{i}": past.contiguous()})
+
+        if not self.use_traced_inputs:
+            for i in range(self.num_layers):
+                past = torch.zeros(past_shape, device=self.device, dtype=self.torch_dtype)
+                (
+                    inputs.update({f"past_key_{i}": past.contiguous(), f"past_value_{i}": past.clone().contiguous()})
+                    if not self.packed_kv
+                    else inputs.update({f"past_{i}": past.contiguous()})
+                )
+        else:
+            for i in range(self.num_layers):
+                inputs.update(
+                    {
+                        f"past_key_{i}": self.static_inputs_map[batch_size][f"past_key_{i}"].contiguous(),
+                        f"past_value_{i}": self.static_inputs_map[batch_size][f"past_value_{i}"].contiguous(),
+                    }
+                )
 
         logits = torch.zeros(batch_size, sequence_length, 51200, device=self.device, dtype=self.torch_dtype)
         outputs = {"logits": logits.contiguous()}
@@ -65,9 +146,13 @@ def get_initial_inputs_and_outputs(self, encodings_dict):
             )
             for i in range(self.num_layers):
                 present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype)
-                outputs.update(
-                    {f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.contiguous()}
-                ) if not self.packed_kv else outputs.update({f"present_{i}": present.contiguous()})
+                (
+                    outputs.update(
+                        {f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.contiguous()}
+                    )
+                    if not self.packed_kv
+                    else outputs.update({f"present_{i}": present.contiguous()})
+                )
 
         return inputs, outputs
 
@@ -111,12 +196,23 @@ def apply_io_binding(self, model: ort.InferenceSession, inputs: dict, outputs: d
 
         return io_binding
 
-    def create_session(self, device_id, use_fp16=True, use_buffer_share=True, packed_kv=False, use_step=False):
+    def create_session(
+        self, device_id, use_fp16=True, use_buffer_share=True, packed_kv=False, use_step=False, use_cuda_graph=False
+    ):
+        self.device_id = device_id
         sess_options = ort.SessionOptions()
-        ep = ("CUDAExecutionProvider", {"device_id": device_id}) if device_id >= 0 else "CPUExecutionProvider"
+        sess_options.log_verbosity_level = 4
+        sess_options.log_severity_level = 4
+        self.use_cuda_graph = use_cuda_graph
+        ep = (
+            ("CUDAExecutionProvider", {"device_id": self.device_id, "enable_cuda_graph": self.use_cuda_graph})
+            if self.device_id >= 0
+            else "CPUExecutionProvider"
+        )
         self.sess = ort.InferenceSession(self.onnx_decoder_path, sess_options=sess_options, providers=[ep])
+        self.ro = ort.RunOptions()
 
-        self.device = torch.device("cuda", device_id) if torch.cuda.is_available() else torch.device("cpu")
+        self.device = torch.device("cuda", self.device_id) if torch.cuda.is_available() else torch.device("cpu")
         self.use_fp16 = use_fp16
         self.use_buffer_share = use_buffer_share
         self.packed_kv = packed_kv
@@ -125,9 +221,7 @@ def create_session(self, device_id, use_fp16=True, use_buffer_share=True, packed
         self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
         self.tokenizer.pad_token = "[PAD]"
 
-    def generate(self, prompt, max_length):
-        encodings_dict = self.tokenizer.batch_encode_plus(prompt, padding=True)
-
+    def generate_impl(self, encodings_dict, max_length, cuda_graph_annotation, benchmark=False):
         inputs, outputs = self.get_initial_inputs_and_outputs(encodings_dict)
 
         all_token_ids = inputs["input_ids"].clone()
@@ -136,13 +230,38 @@ def generate(self, prompt, max_length):
         current_length = sequence_length
         has_eos = torch.zeros(batch_size, device=self.device, dtype=torch.bool)
 
+        if benchmark:
+            import time
+
+            latency = []
+
+        prompt_run = True
         while current_length < max_length:
             io_binding = self.apply_io_binding(self.sess, inputs, outputs)
 
+            if benchmark:
+                start = time.time()
+
             io_binding.synchronize_inputs()
-            self.sess.run_with_iobinding(io_binding)
+            if prompt_run:
+                if self.use_cuda_graph:
+                    # Disable CUDA graph for the prompt run
+                    self.ro.add_run_config_entry("gpu_graph_id", "-1")
+                self.sess.run_with_iobinding(io_binding, self.ro)
+                if self.use_cuda_graph:
+                    # Enable CUDA graph for the decoding run
+                    self.ro.add_run_config_entry(
+                        "gpu_graph_id", str(cuda_graph_annotation) if self.use_traced_inputs else "-1"
+                    )
+                prompt_run = False
+            else:
+                self.sess.run_with_iobinding(io_binding, self.ro)
             io_binding.synchronize_outputs()
 
+            if benchmark:
+                end = time.time()
+                latency.append(end - start)
+
             # Sample with argmax (greedy search)
             next_token_logits = outputs["logits"][:, -1, :]
             next_tokens = torch.argmax(next_token_logits, dim=-1)
@@ -161,16 +280,37 @@ def generate(self, prompt, max_length):
 
             # Update inputs for next inference run
             current_length += 1
+
             inputs["input_ids"] = tokens_to_add.to(torch.int32)
+            if self.use_traced_inputs:
+                cuda_memcpy(self.static_inputs_map[batch_size]["input_ids"], inputs["input_ids"])
+                inputs["input_ids"] = self.static_inputs_map[batch_size]["input_ids"]
+
             if self.use_step:
                 inputs["step"] = torch.tensor([current_length - 1], device=self.device, dtype=torch.int64)
-            inputs["attention_mask"] = torch.cat([inputs["attention_mask"], (~has_eos).reshape(batch_size, 1)], 1).to(
-                torch.int32
-            )
+                if self.use_traced_inputs:
+                    cuda_memcpy(self.static_inputs_map[batch_size]["step"], inputs["step"])
+                    inputs["step"] = self.static_inputs_map[batch_size]["step"]
+
+            if self.use_cuda_graph:
+                previous_seqlens_k = inputs["seqlens_k"]
+                inputs["seqlens_k"] = (previous_seqlens_k + (~has_eos).reshape(batch_size, 1)).to(torch.int32)
+                inputs["total_sequence_length"][0] = current_length
+                if self.use_traced_inputs:
+                    cuda_memcpy(self.static_inputs_map[batch_size]["seqlens_k"], inputs["seqlens_k"])
+                    inputs["seqlens_k"] = self.static_inputs_map[batch_size]["seqlens_k"]
+                    self.static_inputs_map[batch_size]["total_sequence_length"][0] = inputs["total_sequence_length"][0]
+                    inputs["total_sequence_length"] = self.static_inputs_map[batch_size]["total_sequence_length"]
+            else:
+                inputs["attention_mask"] = torch.cat(
+                    [inputs["attention_mask"], (~has_eos).reshape(batch_size, 1)], 1
+                ).to(torch.int32)
 
             # Set logits to zeros for next inference run and re-use memory buffer
             if outputs["logits"].shape[1] != 1:
                 outputs["logits"] = outputs["logits"][:, :1, :].contiguous()
+                if self.use_traced_inputs:
+                    outputs["logits"] = self.static_inputs_map[batch_size]["logits"]
             outputs["logits"].zero_()
 
             if not self.use_buffer_share:
@@ -189,15 +329,70 @@ def generate(self, prompt, max_length):
                 )
                 for i in range(self.num_layers):
                     present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype)
-                    outputs.update(
-                        {f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.clone().contiguous()}
-                    ) if not self.packed_kv else outputs.update({f"present_{i}": present.contiguous()})
+                    (
+                        outputs.update(
+                            {
+                                f"present_key_{i}": present.contiguous(),
+                                f"present_value_{i}": present.clone().contiguous(),
+                            }
+                        )
+                        if not self.packed_kv
+                        else outputs.update({f"present_{i}": present.contiguous()})
+                    )
+
+        if benchmark:
+            print(
+                f"Batch size: {batch_size}, Sequence length: {sequence_length}, Token num: {max_length - sequence_length}"
+            )
+            print(f"Prompt letency: {1000 * latency[0]}ms, Token latency: {1000 * np.mean(latency[1:])}ms")
+            return
 
         texts = self.tokenizer.batch_decode(all_token_ids, skip_special_tokens=True)
         return texts
 
+    def generate(self, prompt, max_length, cuda_graph_annotation):
+        encodings_dict = self.tokenizer.batch_encode_plus(prompt, padding=True)
+
+        return self.generate_impl(encodings_dict, max_length, cuda_graph_annotation)
+
+    def generate_benchmark(self, prompt_shape, token_num, cuda_graph_annotation):
+        batch_size, sequence_length = prompt_shape
+        max_length = sequence_length + token_num
+
+        encodings_dict = {}
+        encodings_dict["input_ids"] = torch.randint(0, 50264, (batch_size, sequence_length), dtype=torch.int32).tolist()
+        encodings_dict["attention_mask"] = torch.ones((batch_size, sequence_length), dtype=torch.int32).tolist()
+
+        # Warm up run
+        self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=False)
+
+        # Benchmark run
+        self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=True)
+
+
+def run_phi2(
+    onnx_model_path,
+    use_buffer_share,
+    device_id,
+    packed_kv=False,
+    use_fp16=True,
+    use_step=False,
+    use_cuda_graph=False,
+    run_benchmark=False,
+):
+    generator = ORTGenerator(onnx_model_path)
+    generator.create_session(device_id, use_fp16, use_buffer_share, packed_kv, use_step, use_cuda_graph)
+
+    def simple_run(prompt):
+        example_batch_size = len(prompt)
+        if use_cuda_graph:
+            generator.append_static_inputs(batch_size=example_batch_size)
+        texts = generator.generate(prompt, max_length=210, cuda_graph_annotation=example_batch_size)
+
+        for i in range(len(texts)):
+            print("Prompt: ", prompt[i])
+            print("Texts: ", texts[i])
 
-def run_phi2(onnx_model_path, use_buffer_share, device_id, packed_kv=False, use_fp16=True, use_step=False):
     prompt = [
         '''```python
     def print_prime(n):
@@ -206,10 +401,14 @@ def print_prime(n):
     """'''
     ]
 
-    generator = ORTGenerator(onnx_model_path)
-    generator.create_session(device_id, use_fp16, use_buffer_share, packed_kv, use_step)
-    texts = generator.generate(prompt, max_length=200)
-
-    for i in range(len(texts)):
-        print("Prompt: ", prompt[i])
-        print("Texts: ", texts[i])
+    if not run_benchmark:
+        simple_run(prompt)
+
+    # Run simple benchmark. Time the decoder only.
+    if run_benchmark:
+        token_num = 32
+        for batch_size in [1, 2, 4, 8]:
+            generator.append_static_inputs(batch_size)
+            for sequence_length in [16, 512]:
+                prompt_shape = (batch_size, sequence_length)
+                generator.generate_benchmark(prompt_shape, token_num, cuda_graph_annotation=batch_size)
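The generator changes above rely on two pieces to make CUDA graph replay safe: static, preallocated device buffers (so bound tensor addresses never change between decoding steps) and the `gpu_graph_id` run-config entry, which skips capture during the prompt run and annotates the decoding runs. A minimal sketch of that run-option toggle, assuming a hypothetical `decoder.onnx` whose I/O is already bound to fixed CUDA buffers:

```python
# Sketch of the gpu_graph_id toggle used by ORTGenerator.generate_impl above.
# "decoder.onnx" is a placeholder; with enable_cuda_graph the bound input/output
# buffers must keep the same device addresses across runs for replay to be valid.
import onnxruntime as ort

ep = ("CUDAExecutionProvider", {"device_id": 0, "enable_cuda_graph": True})
sess = ort.InferenceSession("decoder.onnx", providers=[ep])

io_binding = sess.io_binding()
# ... bind every input/output to preallocated CUDA OrtValues here (omitted) ...

ro = ort.RunOptions()

# Prompt run: "-1" disables CUDA graph capture/replay for this call.
ro.add_run_config_entry("gpu_graph_id", "-1")
sess.run_with_iobinding(io_binding, ro)

# Decoding runs: a non-negative annotation (the batch size in the code above) makes ORT
# capture the graph on the first annotated run and replay it on later identical runs.
ro.add_run_config_entry("gpu_graph_id", "1")
sess.run_with_iobinding(io_binding, ro)
```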
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py
index 6c337af78e0a9..3879e25386d53 100755
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py
@@ -315,13 +315,13 @@ def get_optimum_ort_pipeline(
                 directory,
                 provider=provider,
                 session_options=None,
-                use_io_binding=False,
+                use_io_binding=False,  # Not supported by Optimum version 1.17.1 at the time of verification.
             )
         else:
             pipeline = ORTStableDiffusionPipeline.from_pretrained(
                 directory,
                 provider=provider,
-                use_io_binding=False,
+                use_io_binding=False,  # Not supported by Optimum version 1.17.1 at the time of verification.
             )
     elif "xl" in model_name:
         pipeline = ORTStableDiffusionXLPipeline.from_pretrained(
@@ -329,7 +329,7 @@ def get_optimum_ort_pipeline(
             export=True,
             provider=provider,
             session_options=None,
-            use_io_binding=False,
+            use_io_binding=False,  # Not supported by Optimum version 1.17.1 at the time of verification.
         )
         pipeline.save_pretrained(directory)
     else:
@@ -337,7 +337,7 @@ def get_optimum_ort_pipeline(
             model_name,
             export=True,
             provider=provider,
-            use_io_binding=False,
+            use_io_binding=False,  # Not supported by Optimum version 1.17.1 at the time of verification.
         )
         pipeline.save_pretrained(directory)
 
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py
index 2cd64e8784c6b..a3caba138f44a 100644
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py
@@ -32,13 +32,8 @@
     repeat_prompt,
 )
 
-if __name__ == "__main__":
-    coloredlogs.install(fmt="%(funcName)20s: %(message)s")
-
-    parser = arg_parser("Options for Stable Diffusion Demo")
-    add_controlnet_arguments(parser)
-    args = parse_arguments(is_xl=False, parser=parser)
 
+def main(args):
     controlnet_images, controlnet_scale = process_controlnet_arguments(args)
 
     pipeline, refiner = load_pipelines(args)
@@ -88,3 +83,20 @@ def run_inference(warmup=False):
     pipeline.save_images(images, prompt, negative_prompt, metadata)
 
     pipeline.teardown()
+
+
+if __name__ == "__main__":
+    coloredlogs.install(fmt="%(funcName)20s: %(message)s")
+
+    parser = arg_parser("Options for Stable Diffusion Demo")
+    add_controlnet_arguments(parser)
+    args = parse_arguments(is_xl=False, parser=parser)
+
+    if args.user_compute_stream:
+        import torch
+
+        s = torch.cuda.Stream()
+        with torch.cuda.stream(s):
+            main(args)
+    else:
+        main(args)
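Both Stable Diffusion demos can now run the whole pipeline inside a user-created torch CUDA stream, and the ORT_CUDA engine builder later in this patch forwards the current torch stream when it builds provider options. A minimal sketch of the general pattern, assuming a hypothetical `unet.onnx` and using the CUDA EP's documented `user_compute_stream` provider option:

```python
# Sketch of running an ORT CUDA session on a user-provided torch stream, the
# pattern the demos above enable with --user-compute-stream. "unet.onnx" is a placeholder.
import torch
import onnxruntime as ort

stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
    provider_options = {
        "device_id": 0,
        # Hand ORT the raw CUDA stream handle so EP work is enqueued on this stream.
        "user_compute_stream": str(torch.cuda.current_stream().cuda_stream),
    }
    sess = ort.InferenceSession("unet.onnx", providers=[("CUDAExecutionProvider", provider_options)])
```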
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py
index 19bbb45d77c93..24fa6a2c51343 100644
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py
@@ -132,9 +132,11 @@ def run_demo(args):
 
 
 def run_dynamic_shape_demo(args):
-    """Run demo of generating images with different settings with ORT CUDA provider."""
+    """
+    Run demo of generating images with different settings with ORT CUDA provider.
+    Try "python demo_txt2img_xl.py --max-cuda-graphs 3 --user-compute-stream" to see the effect of multiple CUDA graphs.
+    """
     args.engine = "ORT_CUDA"
-    args.disable_cuda_graph = True
     base, refiner = load_pipelines(args, 1)
 
     prompts = [
@@ -216,7 +218,6 @@ def run_dynamic_shape_demo(args):
 def run_turbo_demo(args):
     """Run demo of generating images with test prompts with ORT CUDA provider."""
     args.engine = "ORT_CUDA"
-    args.disable_cuda_graph = True
     base, refiner = load_pipelines(args, 1)
 
     from datasets import load_dataset
@@ -239,13 +240,7 @@ def run_turbo_demo(args):
         refiner.teardown()
 
 
-if __name__ == "__main__":
-    coloredlogs.install(fmt="%(funcName)20s: %(message)s")
-
-    parser = arg_parser("Options for Stable Diffusion XL Demo")
-    add_controlnet_arguments(parser)
-    args = parse_arguments(is_xl=True, parser=parser)
-
+def main(args):
     no_prompt = isinstance(args.prompt, list) and len(args.prompt) == 1 and not args.prompt[0]
     if no_prompt:
         if args.version == "xl-turbo":
@@ -254,3 +249,20 @@ def run_turbo_demo(args):
             run_dynamic_shape_demo(args)
     else:
         run_demo(args)
+
+
+if __name__ == "__main__":
+    coloredlogs.install(fmt="%(funcName)20s: %(message)s")
+
+    parser = arg_parser("Options for Stable Diffusion XL Demo")
+    add_controlnet_arguments(parser)
+    args = parse_arguments(is_xl=True, parser=parser)
+
+    if args.user_compute_stream:
+        import torch
+
+        s = torch.cuda.Stream()
+        with torch.cuda.stream(s):
+            main(args)
+    else:
+        main(args)
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py
index 369f31511faca..a50940933eb82 100644
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py
@@ -23,7 +23,7 @@
 import os
 import sys
 from importlib.metadata import PackageNotFoundError, version
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 import controlnet_aux
 import cv2
@@ -246,6 +246,8 @@ def parse_arguments(is_xl: bool, parser):
 
     group = parser.add_argument_group("Options for ORT_CUDA engine only")
     group.add_argument("--enable-vae-slicing", action="store_true", help="True will feed only one image to VAE once.")
+    group.add_argument("--max-cuda-graphs", type=int, default=1, help="Max number of cuda graphs to use. Default 1.")
+    group.add_argument("--user-compute-stream", action="store_true", help="Use user compute stream.")
 
     # TensorRT only options
     group = parser.add_argument_group("Options for TensorRT (--engine=TRT) only")
@@ -400,15 +402,16 @@ def initialize_pipeline(
     max_image_size: int = 1024,
     max_batch_size: int = 16,
     opt_batch_size: int = 1,
-    build_all_tactics=False,
-    do_classifier_free_guidance=False,
-    lcm=False,
+    build_all_tactics: bool = False,
+    do_classifier_free_guidance: bool = False,
+    lcm: bool = False,
     controlnet=None,
     lora_weights=None,
-    lora_scale=1.0,
-    use_fp16_vae=True,
-    use_vae=True,
-    framework_model_dir=None,
+    lora_scale: float = 1.0,
+    use_fp16_vae: bool = True,
+    use_vae: bool = True,
+    framework_model_dir: Optional[str] = None,
+    max_cuda_graphs: int = 1,
 ):
     pipeline_info = PipelineInfo(
         version,
@@ -465,6 +468,7 @@ def initialize_pipeline(
             tmp_dir=os.path.join(work_dir or ".", engine_type.name, pipeline_info.short_name(), "tmp"),
             device_id=torch.cuda.current_device(),
             import_engine_dir=import_engine_dir,
+            max_cuda_graphs=max_cuda_graphs,
         )
     elif engine_type == EngineType.ORT_TRT:
         pipeline.backend.build_engines(
@@ -562,6 +566,7 @@ def load_pipelines(args, batch_size=None):
         "use_fp16_vae": "xl" in args.version,
         "use_vae": True,
         "framework_model_dir": args.framework_model_dir,
+        "max_cuda_graphs": args.max_cuda_graphs,
     }
 
     if "xl" in args.version:
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py
index 10af22e44d3a5..c2cfc165e32cf 100644
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py
@@ -414,7 +414,6 @@ def get_profile_id(self, batch_size, image_height, image_width, static_batch, st
 
     def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
         """For TensorRT"""
-        pass
 
     def get_shape_dict(self, batch_size, image_height, image_width):
         pass
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py
index 6ab4858f11f23..56012e223b18c 100644
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py
@@ -6,7 +6,7 @@
 import gc
 import logging
 import os
-from typing import List, Optional
+from typing import Dict, List, Optional
 
 import onnx
 import torch
@@ -15,25 +15,25 @@
 from packaging import version
 
 import onnxruntime as ort
-from onnxruntime.transformers.io_binding_helper import CudaSession
+from onnxruntime.transformers.io_binding_helper import CudaSession, GpuBindingManager
 from onnxruntime.transformers.onnx_model import OnnxModel
 
 logger = logging.getLogger(__name__)
 
 
-class OrtCudaEngine(CudaSession):
+class OrtCudaEngine:
     def __init__(
         self,
         onnx_path,
         device_id: int = 0,
         enable_cuda_graph: bool = False,
         disable_optimization: bool = False,
+        max_cuda_graphs: int = 1,
     ):
         self.onnx_path = onnx_path
         self.provider = "CUDAExecutionProvider"
-        self.provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph)
-        # self.provider_options["enable_skip_layer_norm_strict_mode"] = True
-
+        self.stream = torch.cuda.current_stream().cuda_stream
+        self.provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph, self.stream)
         session_options = ort.SessionOptions()
 
         # When the model has been optimized by onnxruntime, we can disable optimization to save session creation time.
@@ -52,10 +52,33 @@ def __init__(
         logger.info("created CUDA EP session for %s", onnx_path)
 
         device = torch.device("cuda", device_id)
-        super().__init__(ort_session, device, enable_cuda_graph)
+        self.enable_cuda_graph = enable_cuda_graph
+
+        # Support multiple CUDA graphs for different input shapes.
+        # For clip2 model that disabled cuda graph, max_cuda_graphs is updated to 0 here.
+        self.gpu_binding_manager = GpuBindingManager(
+            ort_session=ort_session,
+            device=device,
+            stream=self.stream,
+            max_cuda_graphs=max_cuda_graphs if enable_cuda_graph else 0,
+        )
+
+        self.current_gpu_binding = None
+
+    def metadata(self, name: str):
+        data = {}
+        if self.current_gpu_binding is not None:
+            if self.current_gpu_binding.last_run_gpu_graph_id >= 0:
+                data[f"{name}.gpu_graph_id"] = self.current_gpu_binding.last_run_gpu_graph_id
+        return data
+
+    def infer(self, feed_dict: Dict[str, torch.Tensor]):
+        return self.current_gpu_binding.infer(feed_dict=feed_dict, disable_cuda_graph_in_run=not self.enable_cuda_graph)
 
     def allocate_buffers(self, shape_dict, device):
-        super().allocate_buffers(shape_dict)
+        self.current_gpu_binding = self.gpu_binding_manager.get_binding(
+            shape_dict=shape_dict, use_cuda_graph=self.enable_cuda_graph
+        )
 
 
 class _ModelConfig:
@@ -220,6 +243,7 @@ def build_engines(
         device_id: int = 0,
         save_fp32_intermediate_model: bool = False,
         import_engine_dir: Optional[str] = None,
+        max_cuda_graphs: int = 1,
     ):
         self.torch_device = torch.device("cuda", device_id)
         self.load_models(framework_model_dir)
@@ -352,6 +376,7 @@ def build_engines(
                 device_id=device_id,
                 enable_cuda_graph=use_cuda_graph,
                 disable_optimization=False,
+                max_cuda_graphs=max_cuda_graphs,
             )
 
             logger.info("%s options for %s: %s", engine.provider, model_name, engine.provider_options)
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py
index 0ad8b13b6091c..1629537dc294f 100644
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py
@@ -547,7 +547,7 @@ def pt_to_numpy(images: torch.FloatTensor):
         return ((images + 1) / 2).clamp(0, 1).detach().permute(0, 2, 3, 1).float().cpu().numpy()
 
     def metadata(self) -> Dict[str, Any]:
-        return {
+        data = {
             "actual_steps": self.actual_steps,
             "seed": self.get_current_seed(),
             "name": self.pipeline_info.name(),
@@ -555,6 +555,12 @@ def metadata(self) -> Dict[str, Any]:
             "custom_unet": self.pipeline_info.custom_unet(),
         }
 
+        if self.engine_type == EngineType.ORT_CUDA:
+            for engine_name, engine in self.backend.engines.items():
+                data.update(engine.metadata(engine_name))
+
+        return data
+
     def save_images(self, images: List, prompt: List[str], negative_prompt: List[str], metadata: Dict[str, Any]):
         session_id = str(random.randint(1000, 9999))
         for i, image in enumerate(images):
diff --git a/onnxruntime/python/tools/transformers/models/whisper/README.md b/onnxruntime/python/tools/transformers/models/whisper/README.md
index 02100266200f8..b44124340a2cd 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/README.md
+++ b/onnxruntime/python/tools/transformers/models/whisper/README.md
@@ -1,5 +1,23 @@
 # Whisper
 
+## Prerequisites
+
+Please note the package versions needed for using Whisper, listed in the requirements file that fits your scenario:
+- `requirements-cpu.txt`
+  - For running Whisper on CPU
+- `requirements-cuda.txt`
+  - For running Whisper on CUDA
+  - Note that `torch` with CUDA enabled is not installed automatically, because it must match the CUDA version installed on your machine. Please visit [the PyTorch website](https://pytorch.org/get-started/locally/) to install a `torch` build that matches your CUDA version and satisfies the version requirement listed in the file.
+- `requirements.txt`
+  - Package versions needed in each of the above files
+- `ffmpeg-python` is also required, but please install it from source with only the allowed codecs enabled to avoid any patent risks.
+
+In addition to the above packages, you will need to install `ffmpeg` on your machine. Visit the [FFmpeg website](https://ffmpeg.org/) for details. You can also install it natively using package managers.
+
+- Linux: `sudo apt-get install ffmpeg`
+- macOS: `brew install ffmpeg` (Homebrew should not be run with `sudo`)
+- Windows: Download a prebuilt binary from the FFmpeg website linked above
+
 ## Exporting Whisper with Beam Search
 
 There are several ways to export Whisper with beam search (using Whisper tiny as an example).
@@ -10,10 +28,10 @@ There are several ways to export Whisper with beam search (using Whisper tiny as
 # From source
 $ git clone https://github.com/microsoft/onnxruntime
 $ cd onnxruntime/onnxruntime/python/tools/transformers/
-$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format
+$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format
 
 # From wheel
-$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format
+$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format
 ```
 
 ### Option 2: end-to-end model from [Olive](https://github.com/microsoft/Olive/tree/main/examples/whisper)
@@ -39,40 +57,49 @@ model.save_pretrained(model_name.split("/")[-1] + "-onnx")
 
 Here are some additional examples for exporting Whisper with beam search.
 
+To see all available options
+```
+# From source:
+$ python3 -m models.whisper.convert_to_onnx --help
+
+# From wheel:
+$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx --help
+```
+
 Export with Forced Decoder Input Ids
 ```
 # From source:
-$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --use_forced_decoder_ids
+$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --use_forced_decoder_ids
 
 # From wheel:
-$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --use_forced_decoder_ids
+$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --use_forced_decoder_ids
 ```
 
 Export + Optimize for FP32
 ```
 # From source:
-$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp32
+$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp32
 
 # From wheel:
-$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp32
+$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp32
 ```
 
 Export + Optimize for FP16 and GPU
 ```
 # From source:
-$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision
+$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision
 
 # From wheel:
-$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision
+$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision
 ```
 
 Export + Quantize for INT8
 ```
 # From source:
-$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --precision int8 --quantize_embedding_layer
+$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --precision int8 --quantize_embedding_layer
 
 # From wheel:
-$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --precision int8 --quantize_embedding_layer
+$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --precision int8 --quantize_embedding_layer
 ```
 
 ## Benchmark Whisper
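The prerequisites added above require `ffmpeg` plus `ffmpeg-python` for audio decoding. As a hedged sketch of decoding an audio file to the 16 kHz mono float32 waveform Whisper models expect (the file name is a placeholder, and the export/benchmark scripts may apply their own preprocessing):

```python
# Sketch of decoding audio with ffmpeg-python, matching the prerequisites above.
# "audio.mp3" is a placeholder input file.
import shutil

import ffmpeg  # pip package: ffmpeg-python
import numpy as np

assert shutil.which("ffmpeg") is not None, "the ffmpeg binary must be installed and on PATH"

out, _ = (
    ffmpeg.input("audio.mp3")
    .output("pipe:", format="s16le", acodec="pcm_s16le", ac=1, ar=16000)  # 16 kHz mono PCM
    .run(capture_stdout=True, capture_stderr=True)
)
waveform = np.frombuffer(out, np.int16).astype(np.float32) / 32768.0
print(waveform.shape)
```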
diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py
index 759ae6d14f184..3f7a292a02748 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py
@@ -1,3 +1,9 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation.  All rights reserved.
+# Licensed under the MIT License.  See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
 import argparse
 import ast
 import datetime
@@ -54,6 +60,8 @@ def load_via_numpy():
             inputs["decoder_input_ids"] = np.array([args.decoder_input_ids], dtype=np.int32)
         if args.has_logits_processor:
             inputs["logits_processor"] = np.array([args.logits_processor], dtype=np.int32)
+        if args.has_temperature:
+            inputs["temperature"] = np.array([args.temperature], dtype=np.float32)
 
     # Measure time taken to load audio file
     logger.info(f"Load audio: {args.audio_path}")
@@ -137,10 +145,10 @@ def get_model(args: argparse.Namespace):
         start_time = time.time()
         model = ORTModelForSpeechSeq2Seq.from_pretrained(
             args.hf_ort_dir_path,
-            use_io_binding=(args.device != "cpu"),
             provider=provider,
             provider_options=provider_options,
             session_options=sess_options,
+            use_io_binding=True,  # Avoid memory copy overhead
         )
         end_time = time.time()
 
@@ -163,6 +171,7 @@ def get_model(args: argparse.Namespace):
 def time_fn(args, fn, inputs):
     warmup_inputs = inputs[0] if type(inputs) is tuple else inputs
     benchmark_inputs = inputs[1] if type(inputs) is tuple else inputs
+    torch_device = torch.device(args.target_device)
 
     # Warm up
     warmup_range = (
@@ -180,7 +189,7 @@ def time_fn(args, fn, inputs):
 
     # Benchmark
     if args.device != "cpu":
-        torch.cuda.synchronize()
+        torch.cuda.synchronize(torch_device)
     start_time = time.time()
 
     bench_range = (
@@ -192,7 +201,7 @@ def time_fn(args, fn, inputs):
         fn(benchmark_inputs)
 
     if args.device != "cpu":
-        torch.cuda.synchronize()
+        torch.cuda.synchronize(torch_device)
     end_time = time.time()
 
     # Newline print after trange in order to print metrics on new lines without progress bar on same line
@@ -401,7 +410,8 @@ def handle_output(output):
         actual_output = handle_output(ort_outputs[0][0])
         logger.info(f"Generated token length: {len(actual_output)} tokens")
         transcription = args.processor.batch_decode(ort_outputs[0], skip_special_tokens=True)[0]
-        logger.info(f"Transcription: {transcription}")
+        # Print the transcription to stdout so it can be compared against the expected output
+        print(f"{transcription}")
 
     measure_fn(args, generate_fn, ort_inputs)
 
@@ -500,7 +510,13 @@ def parse_args():
         "--logits-processor",
         type=int,
         default=1,
-        help="Type of logits processor to use. See `BeamSearch` in https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/contrib_ops/contrib_defs.cc for details.",
+        help="Whether to use timestamps logits processor or not (0 for false, 1 for true).",
+    )
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=1.0,
+        help="Temperature value for generation.",
     )
 
     # Args for accessing detailed info
@@ -581,6 +597,7 @@ def main():
         args.has_audio_stream = "audio_stream" in ort_model_inputs
         setattr(args, "has_decoder_input_ids", "decoder_input_ids" in ort_model_inputs)  # noqa: B010
         setattr(args, "has_logits_processor", "logits_processor" in ort_model_inputs)  # noqa: B010
+        setattr(args, "has_temperature", "temperature" in ort_model_inputs)  # noqa: B010
 
         if args.decoder_input_ids == []:
             args.decoder_input_ids = [config.decoder_start_token_id]
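
As a quick reference for the new optional input handled above, here is a minimal standalone sketch (not the benchmark script itself) of how temperature is only fed when the exported model declares it as a graph input; the model path is hypothetical, and only the optional inputs touched by this change are shown.

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("whisper_beamsearch.onnx")  # hypothetical model path
model_input_names = {i.name for i in session.get_inputs()}

inputs = {}
# Optional graph inputs are only fed when the exported model declares them,
# mirroring the has_logits_processor / has_temperature checks above.
if "logits_processor" in model_input_names:
    inputs["logits_processor"] = np.array([1], dtype=np.int32)
if "temperature" in model_input_names:
    inputs["temperature"] = np.array([1.0], dtype=np.float32)
```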
diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py
index d205a2d340721..814b0dd1ef6ac 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py
@@ -1,3 +1,9 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation.  All rights reserved.
+# Licensed under the MIT License.  See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
 import argparse
 import datetime
 import json
diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
index bb697fe1e1506..5921e4ed42936 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
@@ -28,17 +28,25 @@
 def parse_arguments(argv=None):
     parser = argparse.ArgumentParser()
 
-    pretrained_models = PRETRAINED_WHISPER_MODELS
-    parser.add_argument(
+    conversion_args = parser.add_argument_group("Conversion Process Args")
+    optional_inputs = parser.add_argument_group("Optional Inputs (for WhisperBeamSearch op)")
+    optional_outputs = parser.add_argument_group("Optional Outputs (for WhisperBeamSearch op)")
+    quant_args = parser.add_argument_group("INT8 Quantization Args")
+
+    #################################
+    # Conversion options for Whisper
+    #################################
+
+    conversion_args.add_argument(
         "-m",
         "--model_name_or_path",
         required=False,
         default=PRETRAINED_WHISPER_MODELS[0],
         type=str,
-        help="Model path, or pretrained model name in the list: " + ", ".join(pretrained_models),
+        help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_WHISPER_MODELS),
     )
 
-    parser.add_argument(
+    conversion_args.add_argument(
         "--model_impl",
         required=False,
         default="hf",
@@ -47,7 +55,7 @@ def parse_arguments(argv=None):
         help="Select implementation for export of encoder and decoder subgraphs",
     )
 
-    parser.add_argument(
+    conversion_args.add_argument(
         "--cache_dir",
         required=False,
         type=str,
@@ -55,7 +63,7 @@ def parse_arguments(argv=None):
         help="Directory to cache pre-trained models",
     )
 
-    parser.add_argument(
+    conversion_args.add_argument(
         "--output",
         required=False,
         type=str,
@@ -63,19 +71,24 @@ def parse_arguments(argv=None):
         help="Output directory",
     )
 
-    parser.add_argument(
+    conversion_args.add_argument(
         "-o",
         "--optimize_onnx",
         required=False,
         action="store_true",
         help="Use optimizer.py to optimize onnx model",
     )
-    parser.set_defaults(optimize_onnx=False)
+    conversion_args.set_defaults(optimize_onnx=False)
 
-    parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference")
-    parser.set_defaults(use_gpu=False)
+    conversion_args.add_argument(
+        "--use_gpu",
+        required=False,
+        action="store_true",
+        help="Use GPU for model inference",
+    )
+    conversion_args.set_defaults(use_gpu=False)
 
-    parser.add_argument(
+    conversion_args.add_argument(
         "-p",
         "--precision",
         required=False,
@@ -85,221 +98,226 @@ def parse_arguments(argv=None):
         help="Precision of model to run. fp32 for full precision, fp16 for half precision, int8 for quantization",
     )
 
-    parser.add_argument("--verbose", required=False, action="store_true")
-    parser.set_defaults(verbose=False)
-
-    parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true")
-    parser.set_defaults(use_external_data_format=False)
-
-    parser.add_argument(
-        "-s",
-        "--use_decoder_start_token",
+    conversion_args.add_argument(
+        "--use_int64_inputs",
         required=False,
         action="store_true",
-        help="Use config.decoder_start_token_id. Otherwise, add an extra graph input to \
-              the encoder-decoder-init subgraph for decoder_input_ids.",
+        help="Use int64 instead of int32 for input_ids and attention_mask.",
     )
-    parser.set_defaults(use_decoder_start_token=False)
+    conversion_args.set_defaults(use_int64_inputs=False)
 
-    parser.add_argument(
-        "-f",
-        "--use_forced_decoder_ids",
+    conversion_args.add_argument(
+        "--disable_auto_mixed_precision",
         required=False,
         action="store_true",
-        help="Use decoder_input_ids as an extra graph input to the beam search op",
+        help="Use pure fp16 instead of mixed precision",
     )
-    parser.set_defaults(use_forced_decoder_ids=False)
+    conversion_args.set_defaults(disable_auto_mixed_precision=False)
 
-    parser.add_argument(
-        "-l",
-        "--use_logits_processor",
+    conversion_args.add_argument(
+        "-r",
+        "--provider",
         required=False,
-        action="store_true",
-        help="Use logits_processor as an extra graph input to enable specific logits processing",
+        type=str,
+        default="cpu",
+        choices=list(PROVIDERS.keys()),
+        help="Provider to benchmark. Default is CPUExecutionProvider.",
     )
-    parser.set_defaults(use_specific_logits_processor=False)
 
-    parser.add_argument(
-        "-v",
-        "--use_vocab_mask",
+    conversion_args.add_argument(
+        "--verbose",
         required=False,
         action="store_true",
-        help="Use vocab_mask as an extra graph input to enable specific logits processing",
+        help="Enable verbose logging",
     )
-    parser.set_defaults(use_vocab_mask=False)
+    conversion_args.set_defaults(verbose=False)
 
-    parser.add_argument(
-        "-u",
-        "--use_prefix_vocab_mask",
+    conversion_args.add_argument(
+        "-e",
+        "--use_external_data_format",
         required=False,
         action="store_true",
-        help="Use prefix_vocab_mask as an extra graph input to enable specific logits processing",
+        help="Save weights in external file. Necessary for 'small', 'medium', and 'large' models. Optional for 'tiny' and 'base' models.",
     )
-    parser.set_defaults(use_prefix_vocab_mask=False)
+    conversion_args.set_defaults(use_external_data_format=False)
 
-    parser.add_argument(
+    conversion_args.add_argument(
         "-w",
         "--overwrite",
         required=False,
         action="store_true",
-        help="overwrite existing ONNX model",
+        help="Overwrite existing ONNX model",
     )
-    parser.set_defaults(overwrite=False)
+    conversion_args.set_defaults(overwrite=False)
 
-    parser.add_argument(
-        "--disable_auto_mixed_precision",
+    conversion_args.add_argument(
+        "--separate_encoder_and_decoder_init",
         required=False,
         action="store_true",
-        help="use pure fp16 instead of mixed precision",
+        help="Do not merge encoder and decoder init to initialize past KV caches. Output 3 instead of 2 ONNX models.",
     )
-    parser.set_defaults(disable_auto_mixed_precision=False)
+    conversion_args.set_defaults(separate_encoder_and_decoder_init=False)
 
-    parser.add_argument(
-        "--separate_encoder_and_decoder_init",
+    conversion_args.add_argument(
+        "--no_beam_search_op",
         required=False,
         action="store_true",
-        help="Do not merge encode and decoder init. Output 3 instead of 2 onnx models.",
+        help="Do not produce model with WhisperBeamSearch op, which chains encdecinit and decoder models into one op.",
     )
-    parser.set_defaults(separate_encoder_and_decoder_init=False)
+    conversion_args.set_defaults(no_beam_search_op=False)
 
-    parser.add_argument(
-        "--use_int64_inputs",
+    conversion_args.add_argument(
+        "--state_dict_path",
+        type=str,
+        default="",
+        help="Filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)",
+    )
+
+    #############################################################
+    # Optional inputs for Whisper
+    # (listed below in the order that WhisperBeamSearch expects)
+    #############################################################
+
+    optional_inputs.add_argument(
+        "-v",
+        "--use_vocab_mask",
         required=False,
         action="store_true",
-        help="Use int64 instead of int32 for input_ids, position_ids and attention_mask.",
+        help="Use vocab_mask as an extra graph input to enable specific logits processing",
     )
-    parser.set_defaults(use_int64_inputs=False)
+    optional_inputs.set_defaults(use_vocab_mask=False)
 
-    parser.add_argument(
-        "--chain_model",
+    optional_inputs.add_argument(
+        "-u",
+        "--use_prefix_vocab_mask",
         required=False,
         action="store_true",
-        help="Produce beam search model with chained encdecinit and decoder.",
+        help="Use prefix_vocab_mask as an extra graph input to enable specific logits processing",
     )
-    parser.set_defaults(chain_model=True)
+    optional_inputs.set_defaults(use_prefix_vocab_mask=False)
 
-    parser.add_argument(
-        "--use_whisper_beamsearch",
+    optional_inputs.add_argument(
+        "-f",
+        "--use_forced_decoder_ids",
         required=False,
         action="store_true",
-        help="When chain_model, using WhisperBeamSearch operator rather than BeamSearch operator. \
-              It will be set to true when collect_cross_qk, extra_decoding_ids or output_no_speech_probs is set.",
+        help="Use decoder_input_ids as an extra graph input to the beam search op",
     )
-    parser.set_defaults(use_whisper_beamsearch=False)
+    optional_inputs.set_defaults(use_forced_decoder_ids=False)
 
-    parser.add_argument(
-        "--extra_decoding_ids",
+    optional_inputs.add_argument(
+        "-l",
+        "--use_logits_processor",
         required=False,
         action="store_true",
-        help="Need extra starting decoding ids for some feature like cross qk. Default if false.",
+        help="Use logits_processor as an extra graph input to enable specific logits processing",
     )
-    parser.set_defaults(extra_decoding_ids=False)
+    optional_inputs.set_defaults(use_logits_processor=False)
 
-    parser.add_argument(
+    optional_inputs.add_argument(
         "--collect_cross_qk",
         required=False,
         action="store_true",
         help="Beam search model collect stacked cross QK.",
     )
-    parser.set_defaults(collect_cross_qk=False)
+    optional_inputs.set_defaults(collect_cross_qk=False)
 
-    parser.add_argument(
-        "--output_cross_qk",
+    optional_inputs.add_argument(
+        "--extra_decoding_ids",
         required=False,
         action="store_true",
-        help="Beam search model output collected qk as output. Also hint collect_cross_qk",
+        help="Need extra starting decoding ids for some feature like cross qk. Default if false.",
     )
-    parser.set_defaults(output_cross_qk=False)
+    optional_inputs.set_defaults(extra_decoding_ids=False)
 
-    parser.add_argument(
-        "--no_speech_token_id",
-        default=50362,
+    optional_inputs.add_argument(
+        "-t",
+        "--use_temperature",
+        required=False,
+        action="store_true",
+        help="Use temperature as an extra graph input for the WhisperBeamSearch op",
+    )
+    optional_inputs.set_defaults(use_temperature=False)
+
+    optional_inputs.add_argument(
+        "--no_repeat_ngram_size",
         type=int,
-        help="specify no_speech_token_id. Default is 50362. if >= 0, will be add into beam search attr. \
-              Note that default value maybe different between the multilingual and English-only models.",
+        default=0,
+        help="default to 0",
     )
 
-    parser.add_argument(
-        "--output_no_speech_probs",
+    #############################################################
+    # Optional outputs for Whisper
+    # (listed below in the order that WhisperBeamSearch expects)
+    #############################################################
+
+    optional_outputs.add_argument(
+        "--output_sequence_scores",
         required=False,
         action="store_true",
-        help="Beam search model output no speech probs which is computed from the encoder/context-decoder graph.",
+        help="Beam search model output scores for each generated sequence.",
     )
-    parser.set_defaults(output_no_speech_probs=False)
+    optional_outputs.set_defaults(output_sequence_scores=False)
 
-    parser.add_argument(
+    optional_outputs.add_argument(
         "--output_scores",
         required=False,
         action="store_true",
         help="Beam search model output scores over vocab per generated token.",
     )
-    parser.set_defaults(output_scores=False)
+    optional_outputs.set_defaults(output_scores=False)
 
-    parser.add_argument(
-        "--output_sequence_scores",
+    optional_outputs.add_argument(
+        "--output_cross_qk",
         required=False,
         action="store_true",
-        help="Beam search model output scores for each generated sequence.",
+        help="Beam search model output collected qk as output. Also hint collect_cross_qk",
     )
-    parser.set_defaults(output_sequence_scores=False)
+    optional_outputs.set_defaults(output_cross_qk=False)
 
-    parser.add_argument(
+    optional_outputs.add_argument(
         "--cross_qk_onnx_model",
         required=False,
         type=str,
         default=None,
-        help="the model which consume cross_qk.",
+        help="The model which consumes cross_qk outputs.",
     )
 
-    parser.add_argument(
-        "--beam_output_model",
-        type=str,
-        default="whisper_beamsearch.onnx",
-        help="default name is whisper_beamsearch.onnx.",
+    optional_outputs.add_argument(
+        "--output_no_speech_probs",
+        required=False,
+        action="store_true",
+        help="Beam search model output no speech probs which is computed from the encoder/context-decoder graph.",
     )
+    optional_outputs.set_defaults(output_no_speech_probs=False)
 
-    parser.add_argument(
+    ###################################
+    # Quantization options for Whisper
+    ###################################
+
+    quant_args.add_argument(
         "--quantize_embedding_layer",
         required=False,
         action="store_true",
         help="Quantize MatMul, GEMM, and Gather.",
     )
-    parser.set_defaults(quantize_embedding_layer=False)
+    quant_args.set_defaults(quantize_embedding_layer=False)
 
-    parser.add_argument(
+    quant_args.add_argument(
         "--quantize_per_channel",
         required=False,
         action="store_true",
         help="Quantize weights per each channel.",
     )
-    parser.set_defaults(quantize_per_channel=False)
+    quant_args.set_defaults(quantize_per_channel=False)
 
-    parser.add_argument(
+    quant_args.add_argument(
         "--quantize_reduce_range",
         required=False,
         action="store_true",
         help="Quantize weights with 7 bits.",
     )
-    parser.set_defaults(quantize_reduce_range=False)
-
-    parser.add_argument("--no_repeat_ngram_size", type=int, default=0, help="default to 0")
-
-    parser.add_argument(
-        "--state_dict_path",
-        type=str,
-        default="",
-        help="filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)",
-    )
-
-    parser.add_argument(
-        "-r",
-        "--provider",
-        required=False,
-        type=str,
-        default="cpu",
-        choices=list(PROVIDERS.keys()),
-        help="Provider to benchmark. Default is CPUExecutionProvider.",
-    )
+    quant_args.set_defaults(quantize_reduce_range=False)
 
     args = parser.parse_args(argv)
     args.collect_cross_qk = args.collect_cross_qk or args.output_cross_qk
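
The restructuring above relies on argparse argument groups, which only affect how the --help output is organized; groups share the parent parser's namespace, so parsing behaves exactly as before. A small self-contained sketch of the pattern, with illustrative flags:

```python
import argparse

parser = argparse.ArgumentParser()
conversion_args = parser.add_argument_group("Conversion Process Args")
quant_args = parser.add_argument_group("INT8 Quantization Args")

conversion_args.add_argument("--use_gpu", action="store_true", help="Use GPU for model inference")
quant_args.add_argument("--quantize_per_channel", action="store_true", help="Quantize weights per each channel")

# Argument groups only change how --help is rendered; the parsed namespace stays flat.
args = parser.parse_args(["--use_gpu"])
print(args.use_gpu, args.quantize_per_channel)  # True False
```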
@@ -317,7 +335,7 @@ def export_onnx_models(
     optimize_onnx,
     precision,
     verbose,
-    use_decoder_start_token: bool = False,
+    use_forced_decoder_ids: bool = False,
     merge_encoder_and_decoder_init: bool = True,
     overwrite: bool = False,
     disable_auto_mixed_precision: bool = False,
@@ -362,7 +380,6 @@ def export_onnx_models(
                 onnx_path,
                 verbose,
                 use_external_data_format,
-                use_decoder_input_ids=not use_decoder_start_token,
                 use_int32_inputs=use_int32_inputs,
             )
         else:
@@ -397,16 +414,16 @@ def export_onnx_models(
                     quantization.quantize_dynamic(
                         onnx_path,
                         output_path,
-                        op_types_to_quantize=["MatMul", "Gemm", "Gather"]
-                        if quantize_embedding_layer
-                        else ["MatMul", "Gemm"],
+                        op_types_to_quantize=(
+                            ["MatMul", "Gemm", "Gather"] if quantize_embedding_layer else ["MatMul", "Gemm"]
+                        ),
                         use_external_data_format=use_external_data_format,
                         per_channel=quantize_per_channel,
                         reduce_range=quantize_reduce_range,
                         extra_options={"MatMulConstBOnly": True},
                     )
             else:
-                logger.info(f"Skip optimizing: existed ONNX model {onnx_path}")
+                logger.info(f"Skip optimizing: existing ONNX model {onnx_path}")
         else:
             output_path = onnx_path
 
@@ -449,7 +466,7 @@ def main(argv=None):
         args.optimize_onnx,
         args.precision,
         args.verbose,
-        args.use_decoder_start_token,
+        args.use_forced_decoder_ids,
         not args.separate_encoder_and_decoder_init,
         args.overwrite,
         args.disable_auto_mixed_precision,
@@ -462,7 +479,7 @@ def main(argv=None):
     )
 
     max_diff = 0
-    if args.chain_model:
+    if not args.no_beam_search_op:
         logger.info("Chaining model ... :")
         args.beam_model_output_dir = WhisperHelper.get_onnx_path(
             output_dir,
diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt
new file mode 100644
index 0000000000000..db2cd95324328
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt
@@ -0,0 +1,2 @@
+-r requirements.txt
+onnxruntime>=1.17.1
\ No newline at end of file
diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements-cuda.txt
new file mode 100644
index 0000000000000..9bd215de9bc09
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/models/whisper/requirements-cuda.txt
@@ -0,0 +1,4 @@
+-r requirements.txt
+# Please manually install torch>=1.13.0 with CUDA enabled for the CUDA version installed in your system.
+# Instructions can be found here: https://pytorch.org/get-started/locally/
+onnxruntime-gpu>=1.17.1
diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt
new file mode 100644
index 0000000000000..4cb808501713c
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt
@@ -0,0 +1,13 @@
+torch>=1.13.0
+transformers>=4.24.0
+openai-whisper
+datasets
+soundfile
+librosa
+optimum
+onnxruntime-extensions>=0.9.0
+onnx>=1.15.0
+protobuf==3.20.2
+numpy==1.23.3
+psutil
+py3nvml
diff --git a/onnxruntime/python/tools/transformers/models/whisper/test/1272-141231-0002.mp3 b/onnxruntime/python/tools/transformers/models/whisper/test/1272-141231-0002.mp3
new file mode 100644
index 0000000000000..6d220f5ede6a7
Binary files /dev/null and b/onnxruntime/python/tools/transformers/models/whisper/test/1272-141231-0002.mp3 differ
diff --git a/onnxruntime/python/tools/transformers/models/whisper/test/whisper_ort_output.txt b/onnxruntime/python/tools/transformers/models/whisper/test/whisper_ort_output.txt
new file mode 100644
index 0000000000000..e3dbef248d0b2
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/models/whisper/test/whisper_ort_output.txt
@@ -0,0 +1 @@
+ the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about
diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
index a74666b7af297..0b128f122e0f4 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
@@ -1,3 +1,9 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation.  All rights reserved.
+# Licensed under the MIT License.  See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
 import logging
 import os
 
@@ -9,7 +15,7 @@
     update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha,
 )
 from onnx import TensorProto, helper
-from transformers import WhisperConfig
+from transformers import WhisperConfig, WhisperTokenizer
 
 logger = logging.getLogger(__name__)
 
@@ -23,11 +29,22 @@ def verify_inputs(beam_inputs, graph_inputs):
         assert graph_input.name in beam_input
 
 
+def clean_list(arr, remove_all_strings=True):
+    if remove_all_strings:
+        # Remove all empty strings in list
+        return list(filter(lambda elm: elm != "", arr))
+
+    # Remove empty strings at end of list
+    while len(arr) > 0:
+        if arr[-1] == "":
+            arr.pop()
+        else:
+            break
+    return arr
+
+
 def chain_model(args):
-    # Load encoder/decoder and insert necessary (but unused) graph inputs expected by BeamSearch op or WhisperBeamSearch op
-    args.use_whisper_beamsearch = (
-        args.use_whisper_beamsearch or args.collect_cross_qk or args.output_no_speech_probs or args.extra_decoding_ids
-    )
+    # Load encoder/decoder and insert necessary (but unused) graph inputs expected by WhisperBeamSearch op
     encoder_model = onnx.load_model(args.encoder_path, load_external_data=True)
     encoder_model.graph.name = "encoderdecoderinit subgraph"
 
@@ -35,7 +52,10 @@ def chain_model(args):
     decoder_model.graph.name = "decoder subgraph"
 
     config = WhisperConfig.from_pretrained(args.model_name_or_path)
+    tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path)
 
+    # Create inputs/outputs for WhisperBeamSearch op
+    temperature_name = "temperature_fp16" if args.precision == Precision.FLOAT16 else "temperature"
     beam_inputs = [
         "input_features_fp16" if args.precision == Precision.FLOAT16 else "input_features",
         "max_length",
@@ -44,38 +64,27 @@ def chain_model(args):
         "num_return_sequences",
         "length_penalty_fp16" if args.precision == Precision.FLOAT16 else "length_penalty",
         "repetition_penalty_fp16" if args.precision == Precision.FLOAT16 else "repetition_penalty",
-        "vocab_mask" if args.use_prefix_vocab_mask else "",
+        "vocab_mask" if args.use_vocab_mask else "",
         "prefix_vocab_mask" if args.use_prefix_vocab_mask else "",
         "",  # attention mask
         "decoder_input_ids" if args.use_forced_decoder_ids else "",
         "logits_processor" if args.use_logits_processor else "",
+        "cross_qk_layer_head" if args.collect_cross_qk else "",
+        "extra_decoding_ids" if args.extra_decoding_ids else "",
+        temperature_name if args.use_temperature else "",
     ]
 
-    beam_outputs = ["sequences"]
-    if args.output_sequence_scores:
-        beam_outputs.append("sequence_scores_fp16" if args.precision == Precision.FLOAT16 else "sequence_scores")
-    if args.output_scores:
-        beam_outputs.append("scores_fp16" if args.precision == Precision.FLOAT16 else "scores")
-
-    if args.use_whisper_beamsearch:
-        assert len(beam_inputs) == 12
-        beam_inputs.extend(
-            [
-                "cross_qk_layer_head" if args.collect_cross_qk else "",
-                "extra_decoding_ids" if args.extra_decoding_ids else "",
-            ]
-        )
-        if args.collect_cross_qk:
-            while len(beam_outputs) < 3:
-                beam_outputs.extend([""])
-            beam_outputs.extend(["cross_qk"])
-        if args.output_no_speech_probs:
-            while len(beam_outputs) < 4:
-                beam_outputs.extend([""])
-            beam_outputs.extend(["no_speech_probs_beam"])
-
-    input_features_cast_node, len_pen_cast_node, rep_pen_cast_node = None, None, None
-    output_scores_cast_node = output_sequence_scores_cast_node = None
+    sequence_scores_name = "sequence_scores_fp16" if args.precision == Precision.FLOAT16 else "sequence_scores"
+    scores_name = "scores_fp16" if args.precision == Precision.FLOAT16 else "scores"
+    beam_outputs = [
+        "sequences",
+        sequence_scores_name if args.output_sequence_scores else "",
+        scores_name if args.output_scores else "",
+        "cross_qk" if args.collect_cross_qk else "",
+        "no_speech_probs_beam" if args.output_no_speech_probs else "",
+    ]
+
+    graph_nodes = []
     if args.precision == Precision.FLOAT16:
         input_features_cast_node = helper.make_node(
             "Cast",
@@ -98,6 +107,18 @@ def chain_model(args):
             name="CastRepetitionPenaltyToFp16",
             to=TensorProto.FLOAT16,
         )
+        graph_nodes.extend([input_features_cast_node, len_pen_cast_node, rep_pen_cast_node])
+
+        if args.use_temperature:
+            temp_cast_node = helper.make_node(
+                "Cast",
+                inputs=["temperature"],
+                outputs=["temperature_fp16"],
+                name="temperature_to_fp16",
+                to=TensorProto.FLOAT16,
+            )
+            graph_nodes.append(temp_cast_node)
+
         if args.output_sequence_scores:
             output_sequence_scores_cast_node = helper.make_node(
                 "Cast",
@@ -106,6 +127,8 @@ def chain_model(args):
                 name="CastOutputSequenceScoresToFp32",
                 to=TensorProto.FLOAT,
             )
+            graph_nodes.append(output_sequence_scores_cast_node)
+
         if args.output_scores:
             output_scores_cast_node = helper.make_node(
                 "Cast",
@@ -114,26 +137,40 @@ def chain_model(args):
                 name="CastScoresToFp32",
                 to=TensorProto.FLOAT,
             )
-
-    operator_type = "WhisperBeamSearch" if args.use_whisper_beamsearch else "BeamSearch"
-    node = helper.make_node(operator_type, inputs=beam_inputs, outputs=beam_outputs, name="BeamSearch_zcode")
-    node.domain = "com.microsoft"
-    node.attribute.extend(
-        [
-            helper.make_attribute("eos_token_id", config.eos_token_id),
-            helper.make_attribute("pad_token_id", config.pad_token_id),
-            helper.make_attribute("decoder_start_token_id", config.decoder_start_token_id),
-            helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
-            helper.make_attribute("early_stopping", True),
-            helper.make_attribute("model_type", 2),
-        ]
+            graph_nodes.append(output_scores_cast_node)
+
+    # Create WhisperBeamSearch op
+    beam_search_attrs = [
+        helper.make_attribute("eos_token_id", config.eos_token_id),
+        helper.make_attribute("pad_token_id", config.pad_token_id),
+        helper.make_attribute(
+            "decoder_start_token_id", config.decoder_start_token_id
+        ),  # same as tokenizer.convert_tokens_to_ids(['<|startoftranscript|>'])[0]
+        helper.make_attribute("translate_token_id", tokenizer.convert_tokens_to_ids(["<|translate|>"])[0]),
+        helper.make_attribute("transcribe_token_id", tokenizer.convert_tokens_to_ids(["<|transcribe|>"])[0]),
+        helper.make_attribute("start_of_lm_token_id", tokenizer.convert_tokens_to_ids(["<|startoflm|>"])[0]),
+        (
+            helper.make_attribute("no_speech_token_id", tokenizer.convert_tokens_to_ids(["<|nospeech|>"])[0])
+            if args.output_no_speech_probs
+            else ""
+        ),
+        helper.make_attribute("no_timestamps_token_id", tokenizer.convert_tokens_to_ids(["<|notimestamps|>"])[0]),
+        helper.make_attribute("beginning_timestamp_token_id", tokenizer.convert_tokens_to_ids(["<|0.00|>"])[0]),
+        helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
+        helper.make_attribute("early_stopping", True),
+        helper.make_attribute("model_type", 2),
+        helper.make_attribute("decoder_output_cross_qk", 1) if args.collect_cross_qk else "",
+    ]
+    node = helper.make_node(
+        "WhisperBeamSearch",
+        inputs=clean_list(beam_inputs, remove_all_strings=False),
+        outputs=clean_list(beam_outputs, remove_all_strings=False),
+        name="BeamSearch",
+        domain="com.microsoft",
     )
-    if args.use_whisper_beamsearch:
-        if args.collect_cross_qk:
-            node.attribute.extend([helper.make_attribute("decoder_output_cross_qk", 1)])
-        if args.no_speech_token_id >= 0:
-            node.attribute.extend([helper.make_attribute("no_speech_token", args.no_speech_token_id)])
+    node.attribute.extend(clean_list(beam_search_attrs, remove_all_strings=True))
 
+    # Graph inputs
     input_features = helper.make_tensor_value_info(
         "input_features", TensorProto.FLOAT, ["batch_size", "feature_size", "sequence_length"]
     )
@@ -143,73 +180,63 @@ def chain_model(args):
     num_return_sequences = helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1])
     length_penalty = helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1])
     repetition_penalty = helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1])
+    vocab_mask = helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [config.vocab_size])
+    prefix_vocab_mask = helper.make_tensor_value_info(
+        "prefix_vocab_mask", TensorProto.INT32, ["batch_size", config.vocab_size]
+    )
+    decoder_input_ids = helper.make_tensor_value_info(
+        "decoder_input_ids", TensorProto.INT32, ["batch_size", "initial_sequence_length"]
+    )
+    logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1])
+    cross_qk_layer_head = helper.make_tensor_value_info("cross_qk_layer_head", TensorProto.INT32, ["num_layer_head", 2])
+    extra_decoding_ids = helper.make_tensor_value_info(
+        "extra_decoding_ids", TensorProto.INT32, ["batch_size", "extra_decoding_ids_len"]
+    )
+    temperature = helper.make_tensor_value_info("temperature", TensorProto.FLOAT, [1])
 
-    graph_inputs = [
-        input_features,
-        max_length,
-        min_length,
-        num_beams,
-        num_return_sequences,
-        length_penalty,
-        repetition_penalty,
-    ]
-    if args.use_vocab_mask:
-        vocab_mask = helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [config.vocab_size])
-        graph_inputs.append(vocab_mask)
-
-    if args.use_prefix_vocab_mask:
-        prefix_vocab_mask = helper.make_tensor_value_info(
-            "prefix_vocab_mask", TensorProto.INT32, ["batch_size", config.vocab_size]
-        )
-        graph_inputs.append(prefix_vocab_mask)
-
-    if args.use_forced_decoder_ids:
-        decoder_input_ids = helper.make_tensor_value_info(
-            "decoder_input_ids", TensorProto.INT32, ["batch_size", "initial_sequence_length"]
-        )
-        graph_inputs.append(decoder_input_ids)
-
-    if args.use_logits_processor:
-        logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1])
-        graph_inputs.append(logits_processor)
-
-    if args.collect_cross_qk:
-        cross_qk_layer_head = helper.make_tensor_value_info(
-            "cross_qk_layer_head", TensorProto.INT32, ["num_layer_head", 2]
-        )
-        graph_inputs.append(cross_qk_layer_head)
-
-    if args.extra_decoding_ids:
-        extra_decoding_ids = helper.make_tensor_value_info(
-            "extra_decoding_ids", TensorProto.INT32, ["batch_size", "extra_decoding_ids_len"]
-        )
-        graph_inputs.append(extra_decoding_ids)
+    graph_inputs = clean_list(
+        [
+            input_features,
+            max_length,
+            min_length,
+            num_beams,
+            num_return_sequences,
+            length_penalty,
+            repetition_penalty,
+            vocab_mask if args.use_vocab_mask else "",
+            prefix_vocab_mask if args.use_prefix_vocab_mask else "",
+            decoder_input_ids if args.use_forced_decoder_ids else "",
+            logits_processor if args.use_logits_processor else "",
+            cross_qk_layer_head if args.collect_cross_qk else "",
+            extra_decoding_ids if args.extra_decoding_ids else "",
+            temperature if args.use_temperature else "",
+        ]
+    )
 
-    # graph outputs
+    # Graph outputs
     sequences = helper.make_tensor_value_info(
         "sequences", TensorProto.INT32, ["batch_size", "num_return_sequences", "max_length"]
     )
-    graph_outputs = [sequences]
-    if args.output_cross_qk or (not args.cross_qk_onnx_model and args.collect_cross_qk):
-        cross_qk = helper.make_tensor_value_info(
-            "cross_qk",
-            TensorProto.FLOAT,
-            ["batch_size", "num_return_sequences", "num_layer_head_cross_qk", "max_length", "frames"],
-        )
-        graph_outputs.extend([cross_qk])
-
-    if args.output_no_speech_probs:
-        no_speech_probs = helper.make_tensor_value_info("no_speech_probs", TensorProto.FLOAT, ["batch_size"])
-        graph_outputs.extend([no_speech_probs])
-
-    if args.output_sequence_scores:
-        sequence_scores = helper.make_tensor_value_info("sequence_scores", TensorProto.FLOAT, ["batch_size"])
-        graph_outputs.extend([sequence_scores])
+    sequence_scores = helper.make_tensor_value_info("sequence_scores", TensorProto.FLOAT, ["batch_size"])
+    scores = helper.make_tensor_value_info("scores", TensorProto.FLOAT, ["batch_size"])
+    cross_qk = helper.make_tensor_value_info(
+        "cross_qk",
+        TensorProto.FLOAT,
+        ["batch_size", "num_return_sequences", "num_layer_head_cross_qk", "max_length", "frames"],
+    )
+    no_speech_probs = helper.make_tensor_value_info("no_speech_probs", TensorProto.FLOAT, ["batch_size"])
 
-    if args.output_scores:
-        scores = helper.make_tensor_value_info("scores", TensorProto.FLOAT, ["batch_size"])
-        graph_outputs.extend([scores])
+    graph_outputs = clean_list(
+        [
+            sequences,
+            sequence_scores if args.output_sequence_scores else "",
+            scores if args.output_scores else "",
+            cross_qk if args.output_cross_qk or (not args.cross_qk_onnx_model and args.collect_cross_qk) else "",
+            no_speech_probs if args.output_no_speech_probs else "",
+        ]
+    )
 
+    # Replace MultiHeadAttention with DecoderMaskedMultiHeadAttention for CUDA EP inference
     if hasattr(args, "use_gpu") and args.use_gpu:
         if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph):
             logger.info("Updated whisper decoder subgraph to use DecoderMaskedMultiHeadAttention successfully!")
@@ -230,19 +257,7 @@ def chain_model(args):
 
     opset_import = [helper.make_opsetid(domain="com.microsoft", version=1), helper.make_opsetid(domain="", version=17)]
 
-    graph_nodes = (
-        [
-            input_features_cast_node,
-            len_pen_cast_node,
-            rep_pen_cast_node,
-            node,
-            output_sequence_scores_cast_node,
-            output_scores_cast_node,
-        ]
-        if args.precision == Precision.FLOAT16
-        else [node]
-    )
-    graph_nodes = [node for node in graph_nodes if node is not None]
+    graph_nodes.append(node)
     if args.output_no_speech_probs:
         prob_cast_node = helper.make_node(
             "Cast",
@@ -251,9 +266,16 @@ def chain_model(args):
             name="no_speech_probs_cast_to_fp32",
             to=TensorProto.FLOAT,
         )
-        graph_nodes.extend([prob_cast_node])
-
-    beam_graph = helper.make_graph(graph_nodes, "beam-search-test", graph_inputs, graph_outputs, initializers)
+        graph_nodes.append(prob_cast_node)
+
+    # Make graph with WhisperBeamSearch op
+    beam_graph = helper.make_graph(
+        graph_nodes,
+        name="WhisperBeamSearch Graph",
+        inputs=graph_inputs,
+        outputs=graph_outputs,
+        initializer=initializers,
+    )
     beam_graph_input_names = [gi.name for gi in graph_inputs]
     beam_graph_output_names = [go.name for go in graph_outputs]
 
@@ -287,10 +309,12 @@ def chain_model(args):
         ir_version=decoder_model.ir_version,
     )
 
+    # Save WhisperBeamSearch graph and external data
     if os.path.isfile(args.beam_model_output_dir):
         logger.info(f"Overwriting {args.beam_model_output_dir} and {args.beam_model_output_dir + '.data'}")
         os.remove(args.beam_model_output_dir)
         os.remove(args.beam_model_output_dir + ".data")
+
     onnx.save(
         beam_model,
         args.beam_model_output_dir,
diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py
index 0d69960a095ac..93fd64c9eb7d3 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py
@@ -170,7 +170,7 @@ def create_dummy(
             cross_attention_past_shape = [
                 batch_size,
                 num_attention_heads,
-                past_decode_sequence_length,
+                encode_sequence_length,
                 head_size,
             ]
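
This fix reflects that the cross-attention KV cache is computed once from the encoder hidden states, so its sequence dimension is the fixed encoder length rather than the number of tokens decoded so far; only the self-attention cache grows during decoding. A small sketch of the two dummy shapes, with illustrative sizes not taken from the script:

```python
import torch

# Illustrative sizes only
batch_size, num_attention_heads, head_size = 2, 8, 64
past_decode_sequence_length = 5   # grows as tokens are generated
encode_sequence_length = 1500     # fixed once the audio is encoded

# Self-attention cache: one K/V entry per token decoded so far
self_attn_past = torch.rand(batch_size, num_attention_heads, past_decode_sequence_length, head_size)

# Cross-attention cache: K/V computed once from the encoder hidden states,
# so its length is the encoder sequence length, not the decoded length
cross_attn_past = torch.rand(batch_size, num_attention_heads, encode_sequence_length, head_size)

print(self_attn_past.shape, cross_attn_past.shape)
```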
 
diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py
index 351173f525727..832f692e9980d 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py
@@ -75,7 +75,7 @@ def create_dummy(
         config: WhisperConfig,
         batch_size: int,
         encode_sequence_length: int,
-        use_decoder_input_ids: int,
+        use_decoder_input_ids: bool,
         device: torch.device,
         use_int32_inputs: bool = False,
     ):  # -> WhisperEncoderDecoderInitInputs:
@@ -125,7 +125,7 @@ def export_onnx(
             model.config,
             batch_size=2,
             encode_sequence_length=3000,
-            use_decoder_input_ids=use_decoder_input_ids,
+            use_decoder_input_ids=True,
             device=device,
             use_int32_inputs=use_int32_inputs,
         )
@@ -159,7 +159,7 @@ def export_onnx(
         hidden_size = str(model.config.d_model)
         head_size = str(model.config.d_model // model.config.encoder_attention_heads)
         dynamic_axes = {
-            "encoder_input_ids": {0: "batch_size", 1: "encode_sequence_length"},
+            "encoder_input_ids": {0: "batch_size", 1: "feature_size"},
             "encoder_hidden_states": {
                 0: "batch_size",
                 1: "encode_sequence_length",
diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py
index e2dc79ca247ce..adf7f69470ae7 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py
@@ -6,12 +6,14 @@
 
 import logging
 import os
-import sys
 from pathlib import Path
 from typing import Dict, Tuple, Union
 
 import numpy as np
 import torch
+from float16 import float_to_float16_max_diff
+from onnx_model import OnnxModel
+from optimizer import optimize_model
 from packaging import version
 from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor
 from transformers import __version__ as transformers_version
@@ -21,24 +23,20 @@
 
 from onnxruntime import InferenceSession
 
-sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
-from float16 import float_to_float16_max_diff
-from onnx_model import OnnxModel
-from optimizer import optimize_model
-
 logger = logging.getLogger(__name__)
 
 PRETRAINED_WHISPER_MODELS = [
     "whisper-tiny",
     "whisper-tiny.en",
+    "whisper-base",
+    "whisper-base.en",
     "whisper-small",
     "whisper-small.en",
     "whisper-medium",
     "whisper-medium.en",
-    "whisper-base",
-    "whisper-base.en",
     "whisper-large",
     "whisper-large-v2",
+    "whisper-large-v3",
 ]
 
 
@@ -336,7 +334,7 @@ def verify_onnx(
         try:
             from datasets import load_dataset
         except Exception as e:
-            logger.error(f"An error occurred while importing `datasets`: {e}", exc_info=True)
+            logger.error(f"An error occurred while importing `datasets`: {e}", exc_info=True)  # noqa: G201
             install_cmd = "pip install datasets"
             logger.warning(f"Could not import `datasets`. Attempting to install `datasets` via `{install_cmd}`.")
             os.system(install_cmd)
@@ -346,7 +344,12 @@ def verify_onnx(
         ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
         input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features
 
-        batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 26, 0, 5, 1
+        start_id = [config.decoder_start_token_id]  # ex: [50258]
+        prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")
+        prompt_ids = list(map(lambda token: token[1], prompt_ids))  # ex: [50259, 50358, 50363]
+        forced_decoder_ids = start_id + prompt_ids  # ex: [50258, 50259, 50358, 50363]
+
+        batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 30, 0, 1, 1
         length_penalty, repetition_penalty = 1.0, 1.0
         inputs = {
             "input_features": input_features.to(device),
@@ -383,43 +386,51 @@ def verify_onnx(
             elif name == "prefix_vocab_mask":
                 inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype])
             elif name == "decoder_input_ids":
-                raw_input_ids = (
-                    [[config.decoder_start_token_id]]
-                    if use_extra_decoding_ids
-                    else [[config.decoder_start_token_id, 50259, 50359, 50363]]
-                )
+                raw_input_ids = [start_id] if use_extra_decoding_ids else [forced_decoder_ids]
                 inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype])
             elif name == "logits_processor":
                 inputs[name] = np.array([1], dtype=ort_to_np[dtype])
             elif name == "cross_qk_layer_head":
                 inputs[name] = np.array([[0, 0]], dtype=ort_to_np[dtype])
             elif name == "extra_decoding_ids":
-                inputs[name] = np.repeat(np.array([[50259, 50359, 50363]], dtype=ort_to_np[dtype]), batch_size, 0)
+                inputs[name] = np.repeat(np.array([prompt_ids], dtype=ort_to_np[dtype]), batch_size, 0)
+            elif name == "temperature":
+                inputs[name] = np.array([1.0], dtype=ort_to_np[dtype])
             else:
                 inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype])
         ort_outputs = ort_session.run(None, inputs)[0][0]
 
-        if pt_outputs.shape != ort_outputs.shape:
-            logger.warning("PyTorch and ONNX Runtime outputs do not have the same shape")
+        expected_transcription_no_comma = (
+            " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
+        )
+        expected_transcription_with_comma = (
+            " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
+        )
+        expected_transcription_with_quote_and_comma = (
+            ' "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
+        )
+        expected_transcription_options = {
+            expected_transcription_no_comma,
+            expected_transcription_with_comma,
+            expected_transcription_with_quote_and_comma,
+        }
+        pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True)[0]
+        ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True)[0]
 
-        diff = pt_outputs - ort_outputs
-        max_diff = max(diff.min(), diff.max(), key=abs)
+        parity = (
+            pt_transcription in expected_transcription_options and ort_transcription in expected_transcription_options
+        )
+        max_diff = 0
 
-        if max_diff > 0:
-            # For ONNX Runtime INT8 model
-            pt_expected_transcription = (
-                " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
-            )
-            pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True)
-            ort_expected_transcription = (
-                " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
-            )
-            ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True)
+        if not parity:
+            if pt_outputs.shape != ort_outputs.shape:
+                diff = pt_outputs - ort_outputs[:, : len(pt_outputs[0])]
+            else:
+                diff = pt_outputs - ort_outputs
+            max_diff = max(diff.min(), diff.max(), key=abs)
 
-            parity = (
-                pt_expected_transcription == pt_transcription[0] and ort_expected_transcription == ort_transcription[0]
-            )
-            if parity:
-                max_diff = 0
+        if max_diff != 0:
+            logger.warning(f"PyTorch outputs: {pt_transcription}")
+            logger.warning(f"ONNX Runtime outputs: {ort_transcription}")
 
         return max_diff
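
verify_onnx now compares decoded transcriptions first and only falls back to a token-level diff (truncating the ONNX Runtime output when the lengths differ) if the text does not match. A condensed, standalone sketch of that comparison logic; the expected options here are a subset for brevity:

```python
import numpy as np

EXPECTED = {
    " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
    " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
}

def transcription_parity(pt_text, ort_text, pt_ids, ort_ids):
    """Return 0 when both transcriptions are acceptable, else the largest token-id difference."""
    if pt_text in EXPECTED and ort_text in EXPECTED:
        return 0
    # Fall back to comparing token IDs, truncating the ORT output if the lengths differ
    if pt_ids.shape != ort_ids.shape:
        diff = pt_ids - ort_ids[:, : len(pt_ids[0])]
    else:
        diff = pt_ids - ort_ids
    return max(diff.min(), diff.max(), key=abs)

sample = " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
print(transcription_parity(sample, sample, np.zeros((1, 4), dtype=np.int32), np.zeros((1, 4), dtype=np.int32)))  # 0
```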
diff --git a/onnxruntime/python/tools/transformers/onnx_model_phi.py b/onnxruntime/python/tools/transformers/onnx_model_phi.py
index e68c3120e3f09..05a27ba487f4d 100644
--- a/onnxruntime/python/tools/transformers/onnx_model_phi.py
+++ b/onnxruntime/python/tools/transformers/onnx_model_phi.py
@@ -80,14 +80,17 @@ def set_attention_op_type(self, attn_op_type: AttentionOpType):
     def get_uname(self, layer_id, name):
         return name + "_" + str(layer_id)
 
-    def get_io_by_name(self, node, name):
-        for input in node.input:
-            if input == name or input.endswith(name) or input.startswith(name):
-                return input
-        for output in node.output:
-            if output == name or output.endswith(name) or output.startswith(name):
-                return output
-        raise Exception(f"input {name} not found in node {node.name}")
+    def get_edge_by_name(self, edges, name):
+        for edge in edges:
+            if edge == name or edge.endswith(name) or edge.startswith(name):
+                return edge
+        raise ValueError(f"Edge {name} not found")
+
+    def get_input_by_name(self, node, name):
+        return self.get_edge_by_name(node.input, name)
+
+    def get_output_by_name(self, node, name):
+        return self.get_edge_by_name(node.output, name)
 
     def process_initializer(self, initializer_name, functor, custom_name=None):
         i = self.model.get_initializer(initializer_name)
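
get_edge_by_name matches an edge by exact name, suffix, or prefix, and the new get_input_by_name / get_output_by_name wrappers make the search direction explicit instead of silently falling through from a node's inputs to its outputs. A standalone sketch of the matching rule, with hypothetical edge names:

```python
def get_edge_by_name(edges, name):
    # An edge matches if it equals the query or has it as a prefix or suffix
    for edge in edges:
        if edge == name or edge.endswith(name) or edge.startswith(name):
            return edge
    raise ValueError(f"Edge {name} not found")

# Hypothetical edge names, just to illustrate the matching behavior
node_inputs = ["model_layers_3_self_attn_q_proj_weight", "past_key_3"]
print(get_edge_by_name(node_inputs, "q_proj_weight"))  # suffix match
print(get_edge_by_name(node_inputs, "past_key"))       # prefix match
```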
@@ -287,7 +290,6 @@ def __init__(self, model: ModelProto, num_heads: int, hidden_size: int):
         self.num_attention_heads = num_heads
         self.hidden_size = hidden_size
 
-        self.phi2_edge_dict = self.get_phi2_edge_dict()
         self.func_name = "modeling_phi_PhiModel_model_1"
 
     def get_phi2_edge_dict(self) -> dict:
@@ -296,11 +298,20 @@ def get_phi2_edge_dict(self) -> dict:
         edge_dict["l_input_ids_"] = "input_ids"
         edge_dict["key_states"] = "past_key_0"
         edge_dict["value_states"] = "past_value_0"
-        for i in range(self.num_hidden_layers):
+        for i in range(1, self.num_hidden_layers, 1):
             edge_dict[f"key_states_{i}"] = f"past_key_{i}"
             edge_dict[f"value_states_{i}"] = f"past_value_{i}"
             edge_dict[f"model_layers_{i}_1"] = f"present_key_{i}"
             edge_dict[f"model_layers_{i}_1_1"] = f"present_value_{i}"
+
+        outputs = [o.name for o in self.model.graph.output]
+        if "model_layers_0_1_1" in outputs and "model_layers_0_1_2" in outputs:
+            edge_dict["model_layers_0_1_1"] = "present_key_0"
+            edge_dict["model_layers_0_1_2"] = "present_value_0"
+        else:
+            assert "model_layers_0_1" in outputs and "model_layers_0_1_1" in outputs
+            edge_dict["model_layers_0_1"] = "present_key_0"
+            edge_dict["model_layers_0_1_1"] = "present_value_0"
         return edge_dict
 
     def simplify_phi2_op_type(self):
@@ -342,8 +353,10 @@ def process_graph_io(self, attn_op_type: AttentionOpType):
                     elem_type=TensorProto.INT64,
                     shape=[1],
                 )
-                new_inputs.extend([vi_iid, vi_step, vi_mask]) if not self.use_vllm else new_inputs.extend(
-                    [vi_iid, vi_pid, vi_meta]
+                (
+                    new_inputs.extend([vi_iid, vi_step, vi_mask])
+                    if not self.use_vllm
+                    else new_inputs.extend([vi_iid, vi_pid, vi_meta])
                 )
             if self.use_attn:
                 if "past_key" in vi.name:
@@ -441,7 +454,7 @@ def preprocess_onnx(self, attn_op_type: AttentionOpType):
                 break
         assert function_name is not None
         self.unroll_function(function_name)
-        self.update_edges(self.phi2_edge_dict)
+        self.update_edges(self.get_phi2_edge_dict())
         self.simplify_phi2_op_type()
         self.remove_dropout_layer()
         if attn_op_type == AttentionOpType.PagedAttention:
@@ -465,7 +478,7 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
         input = node.input[0]
         output = node.output[0]
 
-        embedding = self.get_io_by_name(node, "embed_tokens.weight")
+        embedding = self.get_input_by_name(node, "embed_tokens.weight")
 
         layer_known_edges_names = [input, output, embedding]
 
@@ -499,8 +512,8 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
         input = node.input[0]
         output = node.output[0]
 
-        ln_weight = self.get_io_by_name(node, "final_layernorm.weight")
-        ln_bias = self.get_io_by_name(node, "final_layernorm.bias")
+        ln_weight = self.get_input_by_name(node, "final_layernorm.weight")
+        ln_bias = self.get_input_by_name(node, "final_layernorm.bias")
 
         layer_known_edges_names = [input, output, ln_weight, ln_bias]
 
@@ -532,8 +545,8 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
         input = node.input[2]
         output = node.output[0]
 
-        fc_weight = self.process_initializer(self.get_io_by_name(node, "lm_head.weight"), ProcessGemmWFunc())
-        fc_bias = self.get_io_by_name(node, "lm_head.bias")
+        fc_weight = self.process_initializer(self.get_input_by_name(node, "lm_head.weight"), ProcessGemmWFunc())
+        fc_bias = self.get_input_by_name(node, "lm_head.bias")
 
         layer_known_edges_names = [input, output, fc_weight, fc_bias]
 
@@ -670,15 +683,15 @@ def fuse(
         layer_id = self.get_layer_id(node)
 
         i_hidden_states = node.input[0]
-        i_key_cache = self.get_io_by_name(node, "past_key")
-        i_value_cache = self.get_io_by_name(node, "past_value")
+        i_key_cache = self.get_input_by_name(node, "past_key")
+        i_value_cache = self.get_input_by_name(node, "past_value")
 
-        o_hidden_states = node.output[3]
-        o_key_cache = self.get_io_by_name(node, "present_key")
-        o_value_cache = self.get_io_by_name(node, "present_value")
+        o_hidden_states = node.output[-1]
+        o_key_cache = self.get_output_by_name(node, "present_key")
+        o_value_cache = self.get_output_by_name(node, "present_value")
 
-        ln_weight = self.get_io_by_name(node, "input_layernorm.weight")
-        ln_bias = self.get_io_by_name(node, "input_layernorm.bias")
+        ln_weight = self.get_input_by_name(node, "input_layernorm.weight")
+        ln_bias = self.get_input_by_name(node, "input_layernorm.bias")
 
         attn_q_weight, attn_q_bias, attn_k_weight, attn_k_bias, attn_v_weight, attn_v_bias = (
             None,
@@ -693,45 +706,45 @@ def fuse(
 
         if self.attn_op_type != AttentionOpType.Attention:
             attn_q_weight = self.process_initializer(
-                self.get_io_by_name(node, "self_attn.q_proj.weight"), ProcessGemmWFunc()
+                self.get_input_by_name(node, "self_attn.q_proj.weight"), ProcessGemmWFunc()
             )
             attn_k_weight = self.process_initializer(
-                self.get_io_by_name(node, "self_attn.k_proj.weight"), ProcessGemmWFunc()
+                self.get_input_by_name(node, "self_attn.k_proj.weight"), ProcessGemmWFunc()
             )
             attn_v_weight = self.process_initializer(
-                self.get_io_by_name(node, "self_attn.v_proj.weight"), ProcessGemmWFunc()
+                self.get_input_by_name(node, "self_attn.v_proj.weight"), ProcessGemmWFunc()
             )
-            attn_q_bias = self.get_io_by_name(node, "self_attn.q_proj.bias")
-            attn_k_bias = self.get_io_by_name(node, "self_attn.k_proj.bias")
-            attn_v_bias = self.get_io_by_name(node, "self_attn.v_proj.bias")
+            attn_q_bias = self.get_input_by_name(node, "self_attn.q_proj.bias")
+            attn_k_bias = self.get_input_by_name(node, "self_attn.k_proj.bias")
+            attn_v_bias = self.get_input_by_name(node, "self_attn.v_proj.bias")
 
             cos_cache = self.process_initializer(
-                self.get_io_by_name(node, "rotary_emb.cos_cached"), ProcessRotCacheFunc()
+                self.get_input_by_name(node, "rotary_emb.cos_cached"), ProcessRotCacheFunc()
             )
             sin_cache = self.process_initializer(
-                self.get_io_by_name(node, "rotary_emb.sin_cached"), ProcessRotCacheFunc()
+                self.get_input_by_name(node, "rotary_emb.sin_cached"), ProcessRotCacheFunc()
             )
         else:
             attn_qkv_weight, attn_qkv_bias = self.pack_qkv_gemm(
-                self.get_io_by_name(node, "self_attn.q_proj.weight"),
-                self.get_io_by_name(node, "self_attn.k_proj.weight"),
-                self.get_io_by_name(node, "self_attn.v_proj.weight"),
-                self.get_io_by_name(node, "self_attn.q_proj.bias"),
-                self.get_io_by_name(node, "self_attn.k_proj.bias"),
-                self.get_io_by_name(node, "self_attn.v_proj.bias"),
+                self.get_input_by_name(node, "self_attn.q_proj.weight"),
+                self.get_input_by_name(node, "self_attn.k_proj.weight"),
+                self.get_input_by_name(node, "self_attn.v_proj.weight"),
+                self.get_input_by_name(node, "self_attn.q_proj.bias"),
+                self.get_input_by_name(node, "self_attn.k_proj.bias"),
+                self.get_input_by_name(node, "self_attn.v_proj.bias"),
                 self.get_uname(layer_id, "attn_qkv_weight"),
                 self.get_uname(layer_id, "attn_qkv_bias"),
             )
 
         attn_out_weight = self.process_initializer(
-            self.get_io_by_name(node, "self_attn.dense.weight"), ProcessGemmWFunc()
+            self.get_input_by_name(node, "self_attn.dense.weight"), ProcessGemmWFunc()
         )
-        attn_out_bias = self.get_io_by_name(node, "self_attn.dense.bias")
+        attn_out_bias = self.get_input_by_name(node, "self_attn.dense.bias")
 
-        mlp_fc1_weight = self.process_initializer(self.get_io_by_name(node, "mlp.fc1.weight"), ProcessGemmWFunc())
-        mlp_fc2_weight = self.process_initializer(self.get_io_by_name(node, "mlp.fc2.weight"), ProcessGemmWFunc())
-        mlp_fc1_bias = self.get_io_by_name(node, "mlp.fc1.bias")
-        mlp_fc2_bias = self.get_io_by_name(node, "mlp.fc2.bias")
+        mlp_fc1_weight = self.process_initializer(self.get_input_by_name(node, "mlp.fc1.weight"), ProcessGemmWFunc())
+        mlp_fc2_weight = self.process_initializer(self.get_input_by_name(node, "mlp.fc2.weight"), ProcessGemmWFunc())
+        mlp_fc1_bias = self.get_input_by_name(node, "mlp.fc1.bias")
+        mlp_fc2_bias = self.get_input_by_name(node, "mlp.fc2.bias")
 
         layer_known_edges_names = []
         layer_known_edges_names.extend([i_hidden_states, i_key_cache, i_value_cache])
@@ -771,6 +784,7 @@ def fuse(
             subgraph_nodes.extend(self.gemm(["ln_out", attn_q_weight, attn_q_bias], ["query"], "Q_"))
             subgraph_nodes.extend(self.gemm(["ln_out", attn_k_weight, attn_k_bias], ["key"], "K_"))
             subgraph_nodes.extend(self.gemm(["ln_out", attn_v_weight, attn_v_bias], ["value"], "V_"))
+            # The vLLM engine requires full position ids as input
             pos_ids_name = "position_ids" if self.attn_op_type == AttentionOpType.PagedAttention else "step"
             subgraph_nodes.extend(self.rotary(["query", pos_ids_name, cos_cache, sin_cache], ["query_rot"], "Q_"))
             subgraph_nodes.extend(self.rotary(["key", pos_ids_name, cos_cache, sin_cache], ["key_rot"], "K_"))
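
The fused layer above routes either full `position_ids` (the PagedAttention/vLLM path) or a scalar `step` into the rotary nodes. As a rough, self-contained illustration of what a rotary embedding driven by explicit position ids computes, here is a NumPy sketch; the half-split rotation convention and the cache shapes are assumptions made for illustration, not necessarily the exact layout of the contrib rotary op.

```python
import numpy as np

def apply_rotary(x, position_ids, cos_cache, sin_cache):
    """Rotate the first half of each head dimension against the second half.

    x:             (batch, seq, head_dim)
    position_ids:  (batch, seq) integer positions used to index the caches
    cos/sin_cache: (max_pos, head_dim // 2)
    """
    cos = cos_cache[position_ids]  # (batch, seq, head_dim // 2)
    sin = sin_cache[position_ids]
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

# Toy usage: one sequence of two tokens, head_dim = 4.
head_dim, max_pos = 4, 8
inv_freq = 1.0 / (10000.0 ** (np.arange(0, head_dim, 2) / head_dim))
angles = np.outer(np.arange(max_pos), inv_freq)  # (max_pos, head_dim // 2)
q = np.random.rand(1, 2, head_dim).astype(np.float32)
pos = np.array([[0, 1]])  # full position ids, as the vLLM path expects
print(apply_rotary(q, pos, np.cos(angles), np.sin(angles)).shape)  # (1, 2, 4)
```
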
diff --git a/onnxruntime/python/tools/transformers/onnx_model_unet.py b/onnxruntime/python/tools/transformers/onnx_model_unet.py
index 01298b3576eb1..77e24986f0fde 100644
--- a/onnxruntime/python/tools/transformers/onnx_model_unet.py
+++ b/onnxruntime/python/tools/transformers/onnx_model_unet.py
@@ -127,7 +127,7 @@ def optimize(self, options: Optional[FusionOptions] = None):
 
             with logging_redirect_tqdm():
                 steps = 18
-                progress_bar = tqdm.tqdm(range(0, steps), initial=0, desc="fusion")
+                progress_bar = tqdm.tqdm(range(steps), initial=0, desc="fusion")
                 self._optimize(options, progress_bar)
         else:
             logger.info("tqdm is not installed. Run optimization without progress bar")
diff --git a/onnxruntime/python/tools/transformers/quantize_helper.py b/onnxruntime/python/tools/transformers/quantize_helper.py
index a449e881ad361..6a25196dbc24c 100644
--- a/onnxruntime/python/tools/transformers/quantize_helper.py
+++ b/onnxruntime/python/tools/transformers/quantize_helper.py
@@ -7,7 +7,7 @@
 import logging
 import os
 
-import onnx  # noqa: F401
+import onnx
 import torch
 from transformers.modeling_utils import Conv1D
 
@@ -69,6 +69,7 @@ def quantize_onnx_model(onnx_model_path, quantized_model_path, use_external_data
             onnx_model_path,
             quantized_model_path,
             use_external_data_format=use_external_data_format,
+            extra_options={"DefaultTensorType": onnx.TensorProto.FLOAT},
         )
         logger.info(f"quantized model saved to:{quantized_model_path}")
         # TODO: include external data in total model size.
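
For reference, the new `extra_options` entry can be exercised directly through the public quantization API. A minimal sketch, assuming an already-exported FP32 model; the file paths below are placeholders:

```python
import onnx
from onnxruntime.quantization import quantize_dynamic

fp32_path = "model_fp32.onnx"  # hypothetical input model
int8_path = "model_int8.onnx"  # hypothetical output path

quantize_dynamic(
    fp32_path,
    int8_path,
    use_external_data_format=False,
    # Fall back to float32 when a tensor's element type cannot be inferred,
    # matching the option added to quantize_helper.py above.
    extra_options={"DefaultTensorType": onnx.TensorProto.FLOAT},
)
```
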
diff --git a/onnxruntime/python/tools/transformers/shape_optimizer.py b/onnxruntime/python/tools/transformers/shape_optimizer.py
index ac62188662990..503930b23229f 100644
--- a/onnxruntime/python/tools/transformers/shape_optimizer.py
+++ b/onnxruntime/python/tools/transformers/shape_optimizer.py
@@ -133,9 +133,7 @@ def use_static_input(self, inputs, batch_size=1, max_seq_len=128):
                     dim_proto.dim_value = max_seq_len
                 elif dim_proto.HasField("dim_value") and dim_proto.dim_value != max_seq_len:
                     raise ValueError(
-                        "Unable to set dimension value to {} for axis {} of {}. Contradicts existing dimension value {}.".format(
-                            max_seq_len, 1, input.name, dim_proto.dim_value
-                        )
+                        f"Unable to set dimension value to {max_seq_len} for axis {1} of {input.name}. Contradicts existing dimension value {dim_proto.dim_value}."
                     )
 
     def create_dummy_inputs(
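
Aside from the f-string cleanup, `use_static_input` is essentially pinning symbolic dimensions to fixed values. A simplified sketch of that idea with the onnx protobuf API; it skips the batch axis and the contradiction check shown above, and the paths in the usage note are hypothetical:

```python
import onnx

def freeze_sequence_length(model: onnx.ModelProto, max_seq_len: int = 128) -> None:
    """Pin axis 1 of every graph input to max_seq_len."""
    for graph_input in model.graph.input:
        dims = graph_input.type.tensor_type.shape.dim
        if len(dims) > 1:
            dims[1].ClearField("dim_param")  # drop a symbolic name such as "seq_len"
            dims[1].dim_value = max_seq_len

# Usage:
#   model = onnx.load("bert.onnx")
#   freeze_sequence_length(model, max_seq_len=128)
#   onnx.save(model, "bert_static.onnx")
```
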
diff --git a/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py b/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py
index f3e67930adbff..66f24c47f6cdb 100644
--- a/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py
+++ b/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py
@@ -4,6 +4,7 @@
 # --------------------------------------------------------------------------
 
 import torch
+from torch._C._onnx import OperatorExportTypes
 
 TrainingMode = torch.onnx.TrainingMode
 from packaging.version import Version  # noqa: E402
@@ -18,7 +19,7 @@ def torch_onnx_export(
     training=TrainingMode.EVAL,
     input_names=None,
     output_names=None,
-    operator_export_type=None,
+    operator_export_type=OperatorExportTypes.ONNX,
     opset_version=None,
     _retain_param_name=None,
     do_constant_folding=True,
diff --git a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py
index 8bf7cbf80eb37..9dee6564509d5 100644
--- a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py
+++ b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py
@@ -29,5 +29,5 @@ def load_aten_op_executor_cpp_extension():
     from onnxruntime.training.ortmodule.torch_cpp_extensions import aten_op_executor
 
     _C.register_aten_op_executor(
-        str(aten_op_executor.is_cpu_argument_address()), str(aten_op_executor.execute_aten_operator_address())
+        str(aten_op_executor.is_tensor_argument_address()), str(aten_op_executor.execute_aten_operator_address())
     )
diff --git a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc
index 903a394a06ef3..4148e63d58619 100644
--- a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc
+++ b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc
@@ -34,25 +34,30 @@ struct ATenOperator {
   std::vector<bool> is_optional_arguments;
   std::vector<c10::optional<c10::IValue>> default_values;
   size_t return_size;
+  std::vector<c10::TypeKind> ret_kinds;
 
   c10::IValue ToIValueArgument(const DLManagedTensor* dlpack, size_t index) const {
     TORCH_INTERNAL_ASSERT(index < argument_size);
     bool is_optional = is_optional_arguments[index];
-    TORCH_INTERNAL_ASSERT(dlpack || is_optional || default_values[index]);
+    TORCH_INTERNAL_ASSERT(dlpack || is_optional || default_values[index] ||
+                          elem_kinds[index] == c10::TypeKind::TensorType);
     if (!dlpack) {
       if (is_optional) {
         // Optional argument always has no default value.
         return c10::IValue(c10::nullopt);
       }
-
-      return *default_values[index];
+      if (default_values[index]) {
+        return *default_values[index];
+      }
+      // For a backward function, the input may be an undefined tensor from the forward outputs; dlpack is nullptr in that case.
+      return c10::IValue(at::Tensor());
     }
 
     bool is_list = is_list_arguments[index];
     c10::IValue i_value;
     // Create the torch tensor from this DLPack no matter we need it or not below,
     // so that the dlpack's deleter will be triggered when torch tensor is out of scope.
-    at::Tensor tensor = at::fromDLPack(dlpack);
+    at::Tensor tensor = at::fromDLPack(const_cast<DLManagedTensor*>(dlpack));
     switch (elem_kinds[index]) {
       case c10::TypeKind::TensorType: {
         i_value = is_optional ? c10::IValue(c10::optional<at::Tensor>(tensor)) : c10::IValue(tensor);
@@ -142,7 +147,10 @@ class ATenOperatorCache {
       }
       aten_op.return_size = schema.returns().size();
       for (const auto& ret : schema.returns()) {
-        TORCH_INTERNAL_ASSERT(ret.type()->kind() == c10::TypeKind::TensorType);
+        c10::TypeKind ret_type = ret.type()->kind();
+        // Support tensor or int only for now.
+        TORCH_INTERNAL_ASSERT(ret_type == c10::TypeKind::TensorType || ret_type == c10::TypeKind::IntType);
+        aten_op.ret_kinds.emplace_back(ret_type);
       }
       ops_.emplace(key, aten_op);
     }
@@ -154,32 +162,15 @@ class ATenOperatorCache {
   std::unordered_map<std::pair<std::string, std::string>, ATenOperator, PairHash> ops_;
 };
 
-const std::unordered_map<std::string, std::unordered_set<size_t>> kCpuTensorInputsMap = {
-    {"_efficient_attention_forward", {4, 5, 11, 12}}, {"_efficient_attention_backward", {6, 7, 12, 13}}};
-
-const std::unordered_map<std::string, std::unordered_set<size_t>> kCpuTensorOutputsMap = {
-    {"_efficient_attention_forward", {2, 3}}};
-
-// Backend uses this function to check if an argument is CPU input or not.
-bool IsCpuArgument(const char* op_name, const char* overload_name, size_t index, bool is_input) {
+// Backend uses this function to check whether an argument is of tensor type.
+bool IsTensorArgument(const char* op_name, const char* overload_name, size_t index, bool is_input) {
+  const auto& aten_op = ATenOperatorCache::Instance().GetOperator(op_name, overload_name);
   if (is_input) {
-    // If the argument is non-tensor type, it's CPU argument.
-    const auto& aten_op = ATenOperatorCache::Instance().GetOperator(op_name, overload_name);
     TORCH_INTERNAL_ASSERT(index < aten_op.argument_size);
-    if (aten_op.elem_kinds[index] != c10::TypeKind::TensorType) {
-      return true;
-    }
-  }
-
-  std::string full_name = std::string(op_name);
-  std::string overload_name_str = std::string(overload_name);
-  if (overload_name_str != "") {
-    full_name += ("." + overload_name_str);
+    return aten_op.elem_kinds[index] == c10::TypeKind::TensorType;
   }
-
-  const auto& cpu_tensors_map = is_input ? kCpuTensorInputsMap : kCpuTensorOutputsMap;
-  return cpu_tensors_map.find(full_name) != cpu_tensors_map.end() &&
-         cpu_tensors_map.at(full_name).find(index) != cpu_tensors_map.at(full_name).end();
+  TORCH_INTERNAL_ASSERT(index < aten_op.return_size);
+  return aten_op.ret_kinds[index] == c10::TypeKind::TensorType;
 }
 
 void ExecuteATenOperator(const char* op_name, const char* overload_name, size_t input_size,
@@ -216,16 +207,23 @@ void ExecuteATenOperator(const char* op_name, const char* overload_name, size_t
   TORCH_INTERNAL_ASSERT(output_size == aten_op.return_size);
   size_t output_index = 0;
   for (const auto& ret : torch::jit::pop(stack, output_size)) {
-    const auto& tensor = ret.toTensor();
-    dlpack_outputs[output_index++] =
-        tensor.defined() ? at::toDLPack(tensor.is_contiguous() ? tensor : tensor.contiguous()) : nullptr;
+    if (ret.isTensor()) {
+      const auto& tensor = ret.toTensor();
+      dlpack_outputs[output_index++] =
+          tensor.defined() ? at::toDLPack(tensor.is_contiguous() ? tensor : tensor.contiguous()) : nullptr;
+    } else if (ret.isInt()) {
+      at::Tensor scalar = at::scalar_to_tensor(at::Scalar(ret.toInt()));
+      dlpack_outputs[output_index++] = at::toDLPack(scalar);
+    } else {
+      TORCH_INTERNAL_ASSERT(false);
+    }
   }
 }
 
-size_t is_cpu_argument_address() { return reinterpret_cast<size_t>(&IsCpuArgument); }
+size_t is_tensor_argument_address() { return reinterpret_cast<size_t>(&IsTensorArgument); }
 size_t execute_aten_operator_address() { return reinterpret_cast<size_t>(&ExecuteATenOperator); }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("is_cpu_argument_address", &is_cpu_argument_address, "Address of tensor argument check.");
+  m.def("is_tensor_argument_address", &is_tensor_argument_address, "Address of tensor argument check.");
   m.def("execute_aten_operator_address", &execute_aten_operator_address, "Address of Aten operator executor");
 }
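
The executor change above admits non-tensor (int) returns by materializing them as 0-d tensors before they cross the DLPack boundary. The same idea sketched in Python, assuming PyTorch is installed; this mirrors the logic for illustration and is not the extension's actual code path:

```python
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

def export_return_value(ret):
    """Tensors go through DLPack directly (made contiguous first); plain ints
    become 0-d tensors, mirroring the at::scalar_to_tensor branch added above."""
    if isinstance(ret, torch.Tensor):
        return to_dlpack(ret if ret.is_contiguous() else ret.contiguous())
    if isinstance(ret, int):
        return to_dlpack(torch.tensor(ret))
    raise TypeError(f"unsupported return type: {type(ret)!r}")

# Round-trip check: an int return comes back as a 0-d int64 tensor.
print(from_dlpack(export_return_value(7)))  # tensor(7)
```
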
diff --git a/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py b/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py
index 329fba5aa670a..7d5716b85db30 100644
--- a/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py
+++ b/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py
@@ -5,7 +5,7 @@
 
 from onnxruntime.capi import _pybind_state as _C
 
-from .aten_op_executor import execute_aten_operator_address, is_cpu_argument_address
+from .aten_op_executor import execute_aten_operator_address, is_tensor_argument_address
 
 
 def run_once_aten_op_executor(f):
@@ -30,7 +30,7 @@ def aten_op_executor_wrapper(*args, **kwargs):
 
 @run_once_aten_op_executor
 def load_aten_op_executor_cpp_extension():
-    _C.register_aten_op_executor(str(is_cpu_argument_address()), str(execute_aten_operator_address()))
+    _C.register_aten_op_executor(str(is_tensor_argument_address()), str(execute_aten_operator_address()))
 
 
 def init_aten_op_executor():
diff --git a/onnxruntime/test/common/cuda_op_test_utils.cc b/onnxruntime/test/common/cuda_op_test_utils.cc
new file mode 100644
index 0000000000000..bab4e9a60e2ed
--- /dev/null
+++ b/onnxruntime/test/common/cuda_op_test_utils.cc
@@ -0,0 +1,36 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#ifdef USE_CUDA
+#include "cuda_runtime_api.h"
+#endif
+
+namespace onnxruntime {
+namespace test {
+
+int GetCudaArchitecture() {
+  // This will cache the result so we only call cudaGetDeviceProperties once.
+  // Usually we test on a single GPU or on multiple GPUs of the same architecture, so it is fine to cache the result.
+  static int cuda_arch = -1;
+
+#ifdef USE_CUDA
+  if (cuda_arch == -1) {
+    int current_device_id = 0;
+    cudaGetDevice(&current_device_id);
+    // Wait until the GPU is idle; otherwise cudaGetDeviceProperties might fail.
+    cudaDeviceSynchronize();
+    cudaDeviceProp prop;
+
+    // When cudaGetDeviceProperties fails, just return -1 and raise no error.
+    // If the CUDA device has an issue, the test will fail anyway, so there is no need to raise an error here.
+    if (cudaSuccess == cudaGetDeviceProperties(&prop, current_device_id)) {
+      cuda_arch = prop.major * 100 + prop.minor * 10;
+    }
+  }
+#endif
+
+  return cuda_arch;
+}
+
+}  // namespace test
+}  // namespace onnxruntime
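
The new helper queries the device capability once and keeps it in a function-local static. The same compute-once pattern, sketched in Python; delegating the capability query to torch is only an assumption made so the sketch is runnable, not something the C++ helper does:

```python
from functools import lru_cache

@lru_cache(maxsize=1)
def get_cuda_architecture() -> int:
    """Query the device capability once and cache it, mirroring the static
    variable in GetCudaArchitecture(). Returns -1 when no CUDA device is usable."""
    try:
        import torch  # stand-in for the CUDA runtime query in this sketch
    except ImportError:
        return -1
    if not torch.cuda.is_available():
        return -1
    major, minor = torch.cuda.get_device_capability()
    return 100 * major + 10 * minor

print(get_cuda_architecture())  # e.g. 800 on A100, -1 without a usable CUDA device
```
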
diff --git a/onnxruntime/test/common/cuda_op_test_utils.h b/onnxruntime/test/common/cuda_op_test_utils.h
index 043e3059c38d7..6f3e460628566 100644
--- a/onnxruntime/test/common/cuda_op_test_utils.h
+++ b/onnxruntime/test/common/cuda_op_test_utils.h
@@ -4,37 +4,20 @@
 #pragma once
 
 #include "test/util/include/default_providers.h"
-#ifdef USE_CUDA
-#include "cuda_runtime_api.h"
-#endif
 
 namespace onnxruntime {
 namespace test {
 
+// CUDA architecture of the current device, computed as 100 * major + 10 * minor.
+// Please call this function only after the CUDA EP has been enabled.
+int GetCudaArchitecture();
+
 inline bool HasCudaEnvironment(int min_cuda_architecture) {
   if (DefaultCudaExecutionProvider().get() == nullptr) {
     return false;
   }
 
-  if (min_cuda_architecture == 0) {
-    return true;
-  }
-
-  int cuda_architecture = 0;
-
-#ifdef USE_CUDA
-  int currentCudaDevice = 0;
-  cudaGetDevice(&currentCudaDevice);
-  cudaDeviceSynchronize();
-  cudaDeviceProp prop;
-  if (cudaSuccess != cudaGetDeviceProperties(&prop, currentCudaDevice)) {
-    return false;
-  }
-
-  cuda_architecture = prop.major * 100 + prop.minor * 10;
-#endif
-
-  return cuda_architecture >= min_cuda_architecture;
+  return GetCudaArchitecture() >= min_cuda_architecture;
 }
 
 inline bool NeedSkipIfCudaArchLowerThan(int min_cuda_architecture) {
diff --git a/onnxruntime/test/common/trt_op_test_utils.h b/onnxruntime/test/common/trt_op_test_utils.h
new file mode 100644
index 0000000000000..a0b0b9bb1931f
--- /dev/null
+++ b/onnxruntime/test/common/trt_op_test_utils.h
@@ -0,0 +1,33 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "test/common/cuda_op_test_utils.h"
+
+namespace onnxruntime {
+namespace test {
+
+// TensorRT EP Segmentation fault on A100: https://github.com/microsoft/onnxruntime/issues/19530
+inline const std::unordered_set<std::string> ExcludeTrtOnA100() {
+  // Note: GetCudaArchitecture needs USE_CUDA to be defined. Currently, it is defined when the TRT EP is enabled.
+  // If we want to make TRT EP independent of CUDA EP, we need to change the implementation of GetCudaArchitecture.
+  if (DefaultTensorrtExecutionProvider() != nullptr && GetCudaArchitecture() == 800) {
+    return {kTensorrtExecutionProvider};
+  }
+
+  return {};
+}
+
+// Add TensorRT EP to an excluded provider list when running on A100
+inline const std::unordered_set<std::string>& ExcludeTrtOnA100(std::unordered_set<std::string>& excluded_providers) {
+  if (DefaultTensorrtExecutionProvider() != nullptr && GetCudaArchitecture() == 800) {
+    excluded_providers.insert(kTensorrtExecutionProvider);
+    return excluded_providers;
+  }
+
+  return excluded_providers;
+}
+
+}  // namespace test
+}  // namespace onnxruntime
diff --git a/onnxruntime/test/contrib_ops/activation_op_test.cc b/onnxruntime/test/contrib_ops/activation_op_test.cc
index b1e54ec605a39..061fffa572be2 100644
--- a/onnxruntime/test/contrib_ops/activation_op_test.cc
+++ b/onnxruntime/test/contrib_ops/activation_op_test.cc
@@ -22,7 +22,8 @@ namespace test {
 TEST_F(ActivationOpTest, ThresholdedRelu_version_1_to_9) {
   float alpha = 0.1f;
   TestActivationOp<float>(
-      "ThresholdedRelu", input_values, [alpha](float x) { return (x >= alpha) ? x : 0; }, {{"alpha", alpha}}, true, 1);
+      "ThresholdedRelu", input_values, [alpha](float x) { return (x >= alpha) ? x : 0; }, {{"alpha", alpha}}, {},
+      true, 1);
 }
 
 TEST_F(ActivationOpTest, ScaledTanh) {
@@ -46,14 +47,18 @@ TEST_F(ActivationOpTest, ParametricSoftplus) {
         else
           return alpha * logf(expf(bx) + 1);
       },
-      {{"alpha", alpha}, {"beta", beta}}, false);  // Disable TensorRT due to result mismatch
+      {{"alpha", alpha}, {"beta", beta}}, {}, false);  // Disable TensorRT due to result mismatch
 }
 
+// [TODO] Temporarily ignore this test for OpenVINO
+// Fails due to accuracy mismatch
+#if !defined(USE_OPENVINO)
 TEST_F(ActivationOpTest, Gelu) {
   TestActivationOp<float>(
       "Gelu", input_values, [](float x) { return x * 0.5f * (1.0f + std::erf(x * static_cast<float>(M_SQRT1_2))); }, {},
-      false, 1, kMSDomain);
+      {}, false, 1, kMSDomain);
 }
+#endif
 
 #if defined(USE_DNNL)
 std::vector<BFloat16> expected_output_bfloat16(const std::vector<float>& input_data) {
@@ -115,7 +120,7 @@ TEST_F(ActivationOpTest, QuickGelu) {
           y = tmp >= 0 ? y : 1 - y;
           return x * y;
         },
-        {{"alpha", alpha}}, false, 1, kMSDomain);
+        {{"alpha", alpha}}, {}, false, 1, kMSDomain);
   }
 
   // Silu = x*sigmoid(x), i.e., alpha = 1.0f.
@@ -129,7 +134,7 @@ TEST_F(ActivationOpTest, QuickGelu) {
           y = tmp >= 0 ? y : 1 - y;
           return x * y;
         },
-        {{"alpha", alpha}}, false, 1, kMSDomain);
+        {{"alpha", alpha}}, {}, false, 1, kMSDomain);
   }
 
   // Negative alpha.
@@ -143,7 +148,7 @@ TEST_F(ActivationOpTest, QuickGelu) {
           y = tmp >= 0 ? y : 1 - y;
           return x * y;
         },
-        {{"alpha", alpha}}, false, 1, kMSDomain);
+        {{"alpha", alpha}}, {}, false, 1, kMSDomain);
   }
 }
 
diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc
index b652e0723f5aa..a8e2fccdd0462 100644
--- a/onnxruntime/test/contrib_ops/attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/attention_op_test.cc
@@ -227,6 +227,12 @@ static void RunAttentionTest(
       tester.AddOptionalInputEdge<int32_t>();
     }
 
+    if (use_float16) {
+      tester.SetOutputTolerance(0.005f);
+    } else {
+      tester.SetOutputTolerance(0.001f, 0.001f);
+    }
+
     if (enable_cuda) {
       std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
       execution_providers.push_back(DefaultCudaExecutionProvider());
@@ -254,6 +260,9 @@ static void RunAttentionTest(
     if (enable_dml) {
       std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
       execution_providers.push_back(DefaultDmlExecutionProvider());
+      if (use_float16) {
+        tester.SetOutputTolerance(0.02f);
+      }
       tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
     }
   }
@@ -2013,13 +2022,6 @@ TEST(AttentionTest, AttentionMaskIndexOutOfRange) {
 #if !defined(__wasm__)
 // TODO: fix in web assembly
 TEST(AttentionTest, AttentionPastState_dynamic) {
-  // ORT enables TF32 in GEMM for A100. TF32 will cause precsion loss and fail this test.
-  // Do not run this test unless TF32 is disabled explicitly.
-  if (HasCudaEnvironment(800) && ParseEnvironmentVariableWithDefault<int>("NVIDIA_TF32_OVERRIDE", 1) != 0) {
-    GTEST_SKIP() << "Skipping AttentionPastState_dynamic in A100 since TF32 is enabled";
-    return;
-  }
-
   // create rand inputs
   RandomValueGenerator random{};
 
@@ -2101,13 +2103,6 @@ static void RunModelWithRandomInput(
     std::vector<int32_t>& mask_index_data,
     std::string& onnx_model,
     bool is_float16) {
-  // ORT enables TF32 in GEMM for A100. TF32 will cause precsion loss and fail this test.
-  // Do not run this test unless TF32 is disabled explicitly.
-  if (HasCudaEnvironment(800) && ParseEnvironmentVariableWithDefault<int>("NVIDIA_TF32_OVERRIDE", 1) != 0) {
-    GTEST_SKIP() << "Skipping RunModelWithRandomInput in A100 since TF32 is enabled";
-    return;
-  }
-
   RandomValueGenerator random{234};
 
   constexpr int hidden_size = 768;
diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc
index 156ed3799fc22..6ce9f5de68f11 100644
--- a/onnxruntime/test/contrib_ops/beam_search_test.cc
+++ b/onnxruntime/test/contrib_ops/beam_search_test.cc
@@ -8,6 +8,10 @@
 #include "core/session/onnxruntime_cxx_api.h"
 #include "test/common/cuda_op_test_utils.h"
 
+#ifdef USE_CUDA
+#include "core/providers/cuda/cuda_provider_options.h"
+#endif
+
 extern std::unique_ptr<Ort::Env> ort_env;
 
 namespace onnxruntime {
@@ -70,7 +74,9 @@ TEST(BeamSearchTest, GptBeamSearchFp32) {
 
   Ort::SessionOptions session_options;
 #ifdef USE_CUDA
-  Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
+  OrtCUDAProviderOptionsV2 cuda_options;
+  cuda_options.use_tf32 = false;
+  session_options.AppendExecutionProvider_CUDA_V2(cuda_options);
 #endif
 
 #ifdef USE_ROCM
@@ -161,7 +167,9 @@ TEST(BeamSearchTest, GptBeamSearchFp16) {
   if (enable_cuda || enable_rocm) {
     Ort::SessionOptions session_options;
 #ifdef USE_CUDA
-    Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
+    OrtCUDAProviderOptionsV2 cuda_options;
+    cuda_options.use_tf32 = false;
+    session_options.AppendExecutionProvider_CUDA_V2(cuda_options);
 #endif
 
 #ifdef USE_ROCM
@@ -254,7 +262,9 @@ TEST(BeamSearchTest, GptBeamSearchWithInitDecoderFp16) {
   if (enable_cuda || enable_rocm) {
     Ort::SessionOptions session_options;
 #ifdef USE_CUDA
-    Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
+    OrtCUDAProviderOptionsV2 cuda_options;
+    cuda_options.use_tf32 = false;
+    session_options.AppendExecutionProvider_CUDA_V2(cuda_options);
 #endif
 
 #ifdef USE_ROCM
@@ -346,7 +356,9 @@ TEST(BeamSearchTest, GptBeamSearchFp16_VocabPadded) {
   if (enable_cuda || enable_rocm) {
     Ort::SessionOptions session_options;
 #ifdef USE_CUDA
-    Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
+    OrtCUDAProviderOptionsV2 cuda_options;
+    cuda_options.use_tf32 = false;
+    session_options.AppendExecutionProvider_CUDA_V2(cuda_options);
 #endif
 
 #ifdef USE_ROCM
diff --git a/onnxruntime/test/contrib_ops/decoder_attention_op_test.cc b/onnxruntime/test/contrib_ops/decoder_attention_op_test.cc
index 88a2bdf6a4849..8a37ef921fd2b 100644
--- a/onnxruntime/test/contrib_ops/decoder_attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/decoder_attention_op_test.cc
@@ -31,10 +31,8 @@ static void RunAttentionTest(
     const std::vector<float>* new_value_cache = nullptr,
     const std::vector<float>* key_cache = nullptr,
     const std::vector<float>* value_cache = nullptr,
-    const std::initializer_list<bool>* key_padding_mask_data = nullptr,
-    bool use_float16 = false) {
-  int min_cuda_architecture = use_float16 ? 530 : 0;
-  bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
+    const std::initializer_list<bool>* key_padding_mask_data = nullptr) {
+  bool enable_cuda = HasCudaEnvironment(0);
   bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get());
   bool enable_cpu = false;
 
@@ -99,6 +97,7 @@ static void RunAttentionTest(
       tester.AddOutput<float>("new_key_cache", output_cache_dims, *new_key_cache);
       tester.AddOutput<float>("new_value_cache", output_cache_dims, *new_value_cache);
     }
+    tester.SetOutputTolerance(0.001f, 0.001f);
 
     std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
     if (enable_cuda) {
diff --git a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc
index 6afb61bd1f0a1..17c9e8592f64e 100644
--- a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc
@@ -463,12 +463,12 @@ std::vector<MLFloat16> QK_Transpose(MLFloat16* q_matrix, MLFloat16* k_transpose_
 
 // Softmax_QK_Transpose
 template <typename T>
-std::vector<T> Softmax_QK_Transpose(T* qk_transpose_matrix,
-                                    int batch_size, int num_heads, int sequence_length, int total_sequence_length, int head_size);
+std::vector<T> Softmax_QK_Transpose(T* qk_transpose_matrix, int batch_size, int num_heads,
+                                    int sequence_length, int total_sequence_length, int head_size);
 
 template <>
-std::vector<float> Softmax_QK_Transpose(float* qk_transpose_matrix,
-                                        int batch_size, int num_heads, int sequence_length, int total_sequence_length, int head_size) {
+std::vector<float> Softmax_QK_Transpose(float* qk_transpose_matrix, int batch_size, int num_heads,
+                                        int sequence_length, int total_sequence_length, int /*head_size*/) {
   if (sequence_length != 1) {
     throw std::runtime_error("Not supported");
   }
@@ -506,8 +506,8 @@ std::vector<float> Softmax_QK_Transpose(float* qk_transpose_matrix,
 }
 
 template <>
-std::vector<MLFloat16> Softmax_QK_Transpose(MLFloat16* qk_transpose_matrix,
-                                            int batch_size, int num_heads, int sequence_length, int total_sequence_length, int head_size) {
+std::vector<MLFloat16> Softmax_QK_Transpose(MLFloat16* qk_transpose_matrix, int batch_size, int num_heads,
+                                            int sequence_length, int total_sequence_length, int /*head_size*/) {
   if (sequence_length != 1) {
     throw std::runtime_error("Not supported");
   }
@@ -640,248 +640,283 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp32) {
     return;
   }
 
-  // Vary batch size
-  for (int batch_size = 1; batch_size <= 5; batch_size += 2) {
-    // Vary kv_lengths
-    for (int past_sequence_length = 1; past_sequence_length <= 3000; past_sequence_length += 150) {
-      int sequence_length = 1;
-      int number_of_heads = 12;
-      // Vary head_size / hidden_size
-      int hidden_sizes[3] = {384, 768, 1536};
-      for (int hidden_size : hidden_sizes) {
-        int head_size = (hidden_size / number_of_heads);
-        int total_sequence_length = sequence_length + past_sequence_length;
-        int max_sequence_length = past_sequence_length + 1;  // Always keep >  past_sequence_length
+  // Buckets for test data:
+  // batch_size: 1, >=2
+  // past_sequence_length 0~30, 31~2046, >=2047 (so that total_sequence_length: 1~31, 32~2047, >=2048)
+  // head_size: 32, 64, 128
+  struct MyTestCase {
+    int batch_size;
+    int past_sequence_length;
+    int hidden_size;
+  } test_cases[] = {
+      {1, 0, 768},
+      {1, 1, 384},
+      {2, 30, 768},
+      {3, 31, 1536},
+      {4, 512, 384},
+      {1, 1024, 768},
+      {1, 2046, 1536},
+      {2, 2047, 384},
+      {3, 3000, 768},
+  };
+
+  constexpr int sequence_length = 1;
+  constexpr int number_of_heads = 12;
+
+  for (MyTestCase test_case : test_cases) {
+    int batch_size = test_case.batch_size;
+    int past_sequence_length = test_case.past_sequence_length;
+    int hidden_size = test_case.hidden_size;
+
+    int head_size = (hidden_size / number_of_heads);
+    int total_sequence_length = sequence_length + past_sequence_length;
+    int max_sequence_length = past_sequence_length + 1;  // Always keep >  past_sequence_length
+
+    OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain);
+    tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(number_of_heads));
+    tester.AddAttribute<int64_t>("past_present_share_buffer", static_cast<int64_t>(1));
+
+    std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
+    std::vector<int64_t> weights_dims = {hidden_size, 3 * hidden_size};
+    std::vector<int64_t> bias_dims = {3 * hidden_size};
+    std::vector<int64_t> output_dims = {batch_size, sequence_length, hidden_size};
+
+    auto input = CreateRandom<float>(batch_size * sequence_length * hidden_size);
+    tester.AddInput<float>("input", input_dims, input);
+
+    auto weight = CreateRandom<float>(hidden_size * 3 * hidden_size);
+    tester.AddInput<float>("weight", weights_dims, weight);
+
+    auto bias = CreateRandom<float>(3 * hidden_size);
+    tester.AddInput<float>("bias", bias_dims, bias);
+
+    // Mask
+    tester.AddOptionalInputEdge<int32_t>();
+
+    // Past
+    std::vector<int64_t> past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size};
+    int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size;
+
+    auto kv_cache = CreateRandom<float>(past_present_size);
+
+    auto reordered_kv_cache = ReorderKVCache<float>(kv_cache, batch_size,
+                                                    number_of_heads, past_sequence_length, head_size, max_sequence_length);
+
+    // Validate if reordering went well - by transposing and checking equality
+    int chunk_size = 16 / sizeof(float);
+    int num_chunks = head_size / chunk_size;
+    auto transposed = Transpose<float>(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size);
+    CheckEquality<float>(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks,
+                         max_sequence_length, past_sequence_length, chunk_size);
+
+    tester.AddInput<float>("past", past_dims, reordered_kv_cache);
+
+    // Rel
+    tester.AddOptionalInputEdge<float>();
+
+    // Past sequence length
+    std::vector<int32_t> arr_past_sequence_len(1, past_sequence_length);
+    tester.AddInput<int32_t>("past_sequence_length", {1}, arr_past_sequence_len);
+
+    // QKV MatMul
+    auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size);
+    auto* qkv_matrix = qkv.data();
+
+    auto pair = MergePastKWithPresentKAndTranspose<float>(kv_cache.data(), qkv_matrix + hidden_size, batch_size,
+                                                          number_of_heads, past_sequence_length,
+                                                          max_sequence_length, head_size);
+
+    auto k_merged = pair.first;
+    auto k_transpose = pair.second;
+
+    auto qk_transpose = QK_Transpose<float>(qkv_matrix, k_transpose.data(), batch_size, number_of_heads,
+                                            total_sequence_length, head_size);
+
+    auto softmax_qk_transpose = Softmax_QK_Transpose<float>(qk_transpose.data(), batch_size, number_of_heads,
+                                                            sequence_length, total_sequence_length, head_size);
+
+    auto present = MergeReorderedKVCacheWithK<float>(reordered_kv_cache, qkv_matrix + hidden_size, batch_size,
+                                                     number_of_heads, past_sequence_length, max_sequence_length, head_size);
+
+    // Validate our test logic
+    // We want to validate if our merged "unordered" K is the same as
+    // the merged "ordered" K so that the QKT we do in our test code
+    // is equivalent to the QKT we do in the kernel
+    ValidateReorderedMergedKWithK<float>(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size);
+
+    MergeReorderedKVCacheWithV<float>(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size,
+                                      number_of_heads, past_sequence_length, max_sequence_length, head_size);
+
+    auto output = Softmax_QK_Transpose_V<float>(softmax_qk_transpose.data(), present.data() + (past_present_size / 2),
+                                                batch_size, number_of_heads,
+                                                sequence_length, total_sequence_length,
+                                                max_sequence_length, head_size);
 
-        OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain);
-        tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(number_of_heads));
-        tester.AddAttribute<int64_t>("past_present_share_buffer", static_cast<int64_t>(1));
+    // Output(s)
+    tester.AddOutput<float>("output", input_dims, output);
+    tester.AddOutput<float>("present", past_dims, present);
 
-        std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
-        std::vector<int64_t> weights_dims = {hidden_size, 3 * hidden_size};
-        std::vector<int64_t> bias_dims = {3 * hidden_size};
-        std::vector<int64_t> output_dims = {batch_size, sequence_length, hidden_size};
+    tester.SetOutputTolerance(0.001f, 0.001f);
 
-        auto input = CreateRandom<float>(batch_size * sequence_length * hidden_size);
-        tester.AddInput<float>("input", input_dims, input);
-
-        auto weight = CreateRandom<float>(hidden_size * 3 * hidden_size);
-        tester.AddInput<float>("weight", weights_dims, weight);
+    // Run - Regular kernel execution path
+    {
+      std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+      execution_providers.push_back(DefaultCudaExecutionProvider());
+      tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+    }
 
-        auto bias = CreateRandom<float>(3 * hidden_size);
-        tester.AddInput<float>("bias", bias_dims, bias);
+    // Test alternate kernel path of loading more KV data "in flight"
+    {
+      ScopedEnvironmentVariables scoped_env_vars{
+          EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}};
 
-        // Mask
-        tester.AddOptionalInputEdge<int32_t>();
+      std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+      execution_providers.push_back(DefaultCudaExecutionProvider());
 
-        // Past
-        std::vector<int64_t> past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size};
-        int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size;
+      tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+    }
+  }
+}
 
-        auto kv_cache = CreateRandom<float>(past_present_size);
+TEST(DecoderMaskedSelfAttentionTest, Test_fp16) {
+  // The kernel is only supported on CC 5.3 or higher GPUs
+  if (NeedSkipIfCudaArchLowerThan(530)) {
+    return;
+  }
 
-        auto reordered_kv_cache = ReorderKVCache<float>(kv_cache, batch_size,
+  // Buckets for test data:
+  // batch_size: 1, >=2
+  // past_sequence_length 0, 1~30, 31~2046, >=2047 (so that total_sequence_length: 1, 2-31, 32~2047, >=2048)
+  // head_size: 32, 64, 128
+  struct MyTestCase {
+    int batch_size;
+    int past_sequence_length;
+    int hidden_size;
+  } test_cases[] = {
+      {1, 0, 768},
+      {1, 1, 768},
+      {3, 30, 384},
+      {8, 31, 1536},
+      {4, 256, 384},
+      {3, 1024, 768},
+      {2, 2046, 1536},
+      {1, 2047, 384},
+      {2, 3000, 768},
+  };
+
+  constexpr int sequence_length = 1;
+  constexpr int number_of_heads = 12;
+
+  for (MyTestCase test_case : test_cases) {
+    int batch_size = test_case.batch_size;
+    int past_sequence_length = test_case.past_sequence_length;
+    int hidden_size = test_case.hidden_size;
+
+    int head_size = (hidden_size / number_of_heads);
+    int total_sequence_length = sequence_length + past_sequence_length;
+    int max_sequence_length = past_sequence_length + 1;  // Always keep >  past_sequence_length
+
+    OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain);
+    tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(number_of_heads));
+    tester.AddAttribute<int64_t>("past_present_share_buffer", static_cast<int64_t>(1));
+
+    std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
+    std::vector<int64_t> weights_dims = {hidden_size, 3 * hidden_size};
+    std::vector<int64_t> bias_dims = {3 * hidden_size};
+    std::vector<int64_t> output_dims = {batch_size, sequence_length, hidden_size};
+
+    auto input = CreateRandom<MLFloat16>(batch_size * sequence_length * hidden_size);
+    tester.AddInput<MLFloat16>("input", input_dims, input);
+
+    auto weight = CreateRandom<MLFloat16>(hidden_size * 3 * hidden_size);
+    tester.AddInput<MLFloat16>("weight", weights_dims, weight);
+
+    auto bias = CreateRandom<MLFloat16>(3 * hidden_size);
+    tester.AddInput<MLFloat16>("bias", bias_dims, bias);
+
+    // Mask
+    tester.AddOptionalInputEdge<int32_t>();
+
+    // Past
+    std::vector<int64_t> past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size};
+    int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size;
+
+    auto kv_cache = CreateRandom<MLFloat16>(past_present_size);
+
+    auto reordered_kv_cache = ReorderKVCache<MLFloat16>(kv_cache, batch_size,
                                                         number_of_heads, past_sequence_length, head_size, max_sequence_length);
 
-        // Validate if reordering went well - by transposing and checking equality
-        int chunk_size = 16 / sizeof(float);
-        int num_chunks = head_size / chunk_size;
-        auto transposed = Transpose<float>(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size);
-        CheckEquality<float>(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks,
+    // Validate if reordering went well - by transposing and checking equality
+    int chunk_size = 16 / sizeof(MLFloat16);
+    int num_chunks = head_size / chunk_size;
+    auto transposed = Transpose<MLFloat16>(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size);
+    CheckEquality<MLFloat16>(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks,
                              max_sequence_length, past_sequence_length, chunk_size);
 
-        tester.AddInput<float>("past", past_dims, reordered_kv_cache);
+    tester.AddInput<MLFloat16>("past", past_dims, reordered_kv_cache);
 
-        // Rel
-        tester.AddOptionalInputEdge<float>();
+    // Rel
+    tester.AddOptionalInputEdge<MLFloat16>();
 
-        // Past sequence length
-        std::vector<int32_t> arr_past_sequence_len(1, past_sequence_length);
-        tester.AddInput<int32_t>("past_sequence_length", {1}, arr_past_sequence_len);
+    // Past sequence length
+    std::vector<int32_t> arr_past_sequence_len(1, past_sequence_length);
+    tester.AddInput<int32_t>("past_sequence_length", {1}, arr_past_sequence_len);
 
-        // QKV MatMul
-        auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size);
-        auto* qkv_matrix = qkv.data();
+    // QKV MatMul
+    auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size);
+    auto* qkv_matrix = qkv.data();
 
-        auto pair = MergePastKWithPresentKAndTranspose<float>(kv_cache.data(), qkv_matrix + hidden_size, batch_size,
+    auto pair = MergePastKWithPresentKAndTranspose<MLFloat16>(kv_cache.data(), qkv_matrix + hidden_size, batch_size,
                                                               number_of_heads, past_sequence_length,
                                                               max_sequence_length, head_size);
 
-        auto k_merged = pair.first;
-        auto k_transpose = pair.second;
+    auto k_merged = pair.first;
+    auto k_transpose = pair.second;
 
-        auto qk_transpose = QK_Transpose<float>(qkv_matrix, k_transpose.data(), batch_size, number_of_heads,
+    auto qk_transpose = QK_Transpose<MLFloat16>(qkv_matrix, k_transpose.data(), batch_size, number_of_heads,
                                                 total_sequence_length, head_size);
 
-        auto softmax_qk_transpose = Softmax_QK_Transpose<float>(qk_transpose.data(), batch_size, number_of_heads,
+    auto softmax_qk_transpose = Softmax_QK_Transpose<MLFloat16>(qk_transpose.data(), batch_size, number_of_heads,
                                                                 sequence_length, total_sequence_length, head_size);
 
-        auto present = MergeReorderedKVCacheWithK<float>(reordered_kv_cache, qkv_matrix + hidden_size, batch_size,
+    auto present = MergeReorderedKVCacheWithK<MLFloat16>(reordered_kv_cache, qkv_matrix + hidden_size, batch_size,
                                                          number_of_heads, past_sequence_length, max_sequence_length, head_size);
 
-        // Validate our test logic
-        // We want to validate if our merged "unordered" K is the same as
-        // the merged "ordered" K so that the QKT we do in our test code
-        // is equivalent to the QKT we do in the kernel
-        ValidateReorderedMergedKWithK<float>(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size);
+    // Validate our test logic
+    // We want to validate if our merged "unordered" K is the same as
+    // the merged "ordered" K so that the QKT we do in our test code
+    // is equivalent to the QKT we do in the kernel
+    ValidateReorderedMergedKWithK<MLFloat16>(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size);
 
-        MergeReorderedKVCacheWithV<float>(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size,
+    MergeReorderedKVCacheWithV<MLFloat16>(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size,
                                           number_of_heads, past_sequence_length, max_sequence_length, head_size);
 
-        auto output = Softmax_QK_Transpose_V<float>(softmax_qk_transpose.data(), present.data() + (past_present_size / 2),
-                                                    batch_size, number_of_heads,
-                                                    sequence_length, total_sequence_length,
-                                                    max_sequence_length, head_size);
-
-        // Output(s)
-        tester.AddOutput<float>("output", input_dims, output);
-
-        tester.AddOutput<float>("present", past_dims, present);
-
-        // Run - Regular kernel execution path
-        {
-          std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
-          execution_providers.push_back(DefaultCudaExecutionProvider());
-          tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
-        }
+    auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2),
+                                         batch_size, number_of_heads,
+                                         sequence_length, total_sequence_length,
+                                         max_sequence_length, head_size);
 
-        // Test alternate kernel path of loading more KV data "in flight"
-        {
-          ScopedEnvironmentVariables scoped_env_vars{
-              EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}};
+    // Output(s)
+    tester.AddOutput<MLFloat16>("output", input_dims, output);
+    tester.AddOutput<MLFloat16>("present", past_dims, present);
 
-          std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
-          execution_providers.push_back(DefaultCudaExecutionProvider());
+    tester.SetOutputTolerance(0.005f);
 
-          tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
-        }
-      }
+    // Run - Regular kernel execution path
+    {
+      std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+      execution_providers.push_back(DefaultCudaExecutionProvider());
+      tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
     }
-  }
-}
-
-TEST(DecoderMaskedSelfAttentionTest, Test_fp16) {
-  // The kernel is only supported on CC 5.3 or higher GPUs
-  if (NeedSkipIfCudaArchLowerThan(530)) {
-    return;
-  }
-
-  // Vary batch size
-  for (int batch_size = 1; batch_size <= 5; batch_size += 2) {
-    // Vary kv_lengths
-    for (int past_sequence_length = 1; past_sequence_length <= 3000; past_sequence_length += 150) {
-      int sequence_length = 1;
-      int number_of_heads = 12;
-
-      // Vary head_size / hidden_size
-      int hidden_sizes[3] = {384, 768, 1536};
-      for (int hidden_size : hidden_sizes) {
-        int head_size = (hidden_size / number_of_heads);
-        int total_sequence_length = sequence_length + past_sequence_length;
-        int max_sequence_length = past_sequence_length + 1;  // Always keep >  past_sequence_length
-
-        OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain);
-        tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(number_of_heads));
-        tester.AddAttribute<int64_t>("past_present_share_buffer", static_cast<int64_t>(1));
 
-        std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
-        std::vector<int64_t> weights_dims = {hidden_size, 3 * hidden_size};
-        std::vector<int64_t> bias_dims = {3 * hidden_size};
-        std::vector<int64_t> output_dims = {batch_size, sequence_length, hidden_size};
+    // Test alternate kernel path of loading more KV data "in flight"
+    {
+      ScopedEnvironmentVariables scoped_env_vars{
+          EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}};
 
-        auto input = CreateRandom<MLFloat16>(batch_size * sequence_length * hidden_size);
-        tester.AddInput<MLFloat16>("input", input_dims, input);
-
-        auto weight = CreateRandom<MLFloat16>(hidden_size * 3 * hidden_size);
-        tester.AddInput<MLFloat16>("weight", weights_dims, weight);
-
-        auto bias = CreateRandom<MLFloat16>(3 * hidden_size);
-        tester.AddInput<MLFloat16>("bias", bias_dims, bias);
-
-        // Mask
-        tester.AddOptionalInputEdge<int32_t>();
-
-        // Past
-        std::vector<int64_t> past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size};
-        int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size;
-
-        auto kv_cache = CreateRandom<MLFloat16>(past_present_size);
-
-        auto reordered_kv_cache = ReorderKVCache<MLFloat16>(kv_cache, batch_size,
-                                                            number_of_heads, past_sequence_length, head_size, max_sequence_length);
-
-        // Validate if reordering went well - by transposing and checking equality
-        int chunk_size = 16 / sizeof(MLFloat16);
-        int num_chunks = head_size / chunk_size;
-        auto transposed = Transpose<MLFloat16>(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size);
-        CheckEquality<MLFloat16>(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks,
-                                 max_sequence_length, past_sequence_length, chunk_size);
-
-        tester.AddInput<MLFloat16>("past", past_dims, reordered_kv_cache);
-
-        // Rel
-        tester.AddOptionalInputEdge<MLFloat16>();
-
-        // Past sequence length
-        std::vector<int32_t> arr_past_sequence_len(1, past_sequence_length);
-        tester.AddInput<int32_t>("past_sequence_length", {1}, arr_past_sequence_len);
-
-        // QKV MatMul
-        auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size);
-        auto* qkv_matrix = qkv.data();
-
-        auto pair = MergePastKWithPresentKAndTranspose<MLFloat16>(kv_cache.data(), qkv_matrix + hidden_size, batch_size,
-                                                                  number_of_heads, past_sequence_length,
-                                                                  max_sequence_length, head_size);
-
-        auto k_merged = pair.first;
-        auto k_transpose = pair.second;
-
-        auto qk_transpose = QK_Transpose<MLFloat16>(qkv_matrix, k_transpose.data(), batch_size, number_of_heads,
-                                                    total_sequence_length, head_size);
-
-        auto softmax_qk_transpose = Softmax_QK_Transpose<MLFloat16>(qk_transpose.data(), batch_size, number_of_heads,
-                                                                    sequence_length, total_sequence_length, head_size);
-
-        auto present = MergeReorderedKVCacheWithK<MLFloat16>(reordered_kv_cache, qkv_matrix + hidden_size, batch_size,
-                                                             number_of_heads, past_sequence_length, max_sequence_length, head_size);
-
-        // Validate our test logic
-        // We want to validate if our merged "unordered" K is the same as
-        // the merged "ordered" K so that the QKT we do in our test code
-        // is equivalent to the QKT we do in the kernel
-        ValidateReorderedMergedKWithK<MLFloat16>(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size);
-
-        MergeReorderedKVCacheWithV<MLFloat16>(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size,
-                                              number_of_heads, past_sequence_length, max_sequence_length, head_size);
-
-        auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2),
-                                             batch_size, number_of_heads,
-                                             sequence_length, total_sequence_length,
-                                             max_sequence_length, head_size);
-
-        // Output(s)
-        tester.AddOutput<MLFloat16>("output", input_dims, output);
-
-        tester.AddOutput<MLFloat16>("present", past_dims, present);
-
-        // Run - Regular kernel execution path
-        {
-          std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
-          execution_providers.push_back(DefaultCudaExecutionProvider());
-          tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
-        }
-
-        // Test alternate kernel path of loading more KV data "in flight"
-        {
-          ScopedEnvironmentVariables scoped_env_vars{
-              EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}};
-
-          std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
-          execution_providers.push_back(DefaultCudaExecutionProvider());
-          tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
-        }
-      }
+      std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+      execution_providers.push_back(DefaultCudaExecutionProvider());
+      tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
     }
   }
 }
@@ -889,4 +924,4 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp16) {
 #endif
 
 }  // namespace test
-}  // namespace onnxruntime
\ No newline at end of file
+}  // namespace onnxruntime
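
The reworked tests still validate against the same hand-rolled reference: QK^T for the single decoded token, a row-wise softmax, then a weighted sum over V. Collapsed into one NumPy function for orientation, with the layout simplified to a single batch and the 1/sqrt(head_size) scale folded into the score computation (the C++ helpers split this work across QK_Transpose, Softmax_QK_Transpose and Softmax_QK_Transpose_V):

```python
import numpy as np

def masked_decoder_attention_reference(q, k, v):
    """q: (heads, head_size); k, v: (heads, total_seq_len, head_size)."""
    head_size = q.shape[-1]
    scores = np.einsum("hd,htd->ht", q, k) / np.sqrt(head_size)
    scores -= scores.max(axis=-1, keepdims=True)  # stabilize the softmax
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return np.einsum("ht,htd->hd", probs, v)

num_heads, head_size, total_seq_len = 12, 64, 31  # past_sequence_length 30 + 1 new token
q = np.random.rand(num_heads, head_size).astype(np.float32)
k = np.random.rand(num_heads, total_seq_len, head_size).astype(np.float32)
v = np.random.rand(num_heads, total_seq_len, head_size).astype(np.float32)
print(masked_decoder_attention_reference(q, k, v).shape)  # (12, 64)
```
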
diff --git a/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc b/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc
index c70f659f1b645..0b64ea3de8ded 100644
--- a/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc
+++ b/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc
@@ -23,20 +23,85 @@ namespace onnxruntime {
 namespace test {
 
 template <typename T>
-void TestDynamicQuantizeMatMul(const std::vector<int64_t>& A_dims,
-                               std::vector<int64_t> B_dims,
-                               const std::string& reference_model,
-                               bool is_matrix_b_constant,
+static void CalculateDynamicQuantizeMatMul(const int64_t M, const int64_t N, const int64_t K,
+                                           const std::vector<float>& A_data, const std::vector<T>& B_data,
+                                           std::vector<float>& B_scale, std::vector<T>& B_zero_point,
+                                           const std::vector<float>& Bias, std::vector<float>& Y_data,
+                                           bool per_column, bool has_zp, bool has_bias) {
+  // DynamicQuantize Matrix A
+  const uint32_t num_elements = static_cast<uint32_t>(M * K);
+  std::vector<T> QuantA_data(num_elements);
+  std::vector<float> A_scale;
+  std::vector<T> A_zero_point;
+
+  // Get max and min
+  float min = std::numeric_limits<float>::max();
+  float max = std::numeric_limits<float>::lowest();
+  float qmax = static_cast<float>(std::numeric_limits<T>::max());
+  float qmin = static_cast<float>(std::numeric_limits<T>::lowest());
+
+  for (uint32_t i = 0; i < num_elements; ++i) {
+    max = std::max(A_data[i], max);
+    min = std::min(A_data[i], min);
+  }
+
+  // Adjust the maximum and minimum to include zero
+  max = std::max(max, 0.0f);
+  min = std::min(min, 0.0f);
+
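+  // Compute the asymmetric quantization parameters for A: scale maps the observed
+  // [min, max] range onto the quantized range of T, and the zero point anchors 0.0f.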
+  float scale = static_cast<float>(max - min) / (qmax - qmin);
+  T zeroPoint = std::round(std::clamp(qmin - min / scale, qmin, qmax));
+
+  A_scale.push_back(scale);
+  A_zero_point.push_back(zeroPoint);
+
+  // Quantize matrix A with the derived scale and zero point
+  for (uint32_t i = 0; i < num_elements; ++i) {
+    QuantA_data[i] = static_cast<T>(std::round((A_data[i] / scale) + zeroPoint));
+  }
+  if (!per_column) {
+    B_zero_point.resize(N, B_zero_point[0]);
+    B_scale.resize(N, B_scale[0]);
+  }
+
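+  // Reference matmul on the dequantized values; bias is added per output column when present.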
+  for (int64_t m = 0; m < M; m++) {
+    for (int64_t n = 0; n < N; n++) {
+      float sum = 0.0f;
+      for (int64_t k = 0; k < K; k++) {
+        float A_dequantized = (static_cast<int>(QuantA_data[m * K + k]) - static_cast<int>(A_zero_point[0])) * A_scale[0];
+
+        float B_dequantized = has_zp ? (static_cast<int>(B_data[k * N + n]) - static_cast<int>(B_zero_point[n])) * B_scale[n] : B_data[k * N + n] * B_scale[n];
+
+        sum += A_dequantized * B_dequantized;
+      }
+      if (has_bias) {
+        sum += Bias[n];
+      }
+      Y_data[m * N + n] = sum;
+    }
+  }
+}
+
+template <typename T>
+void TestDynamicQuantizeMatMul(bool is_matrix_b_constant,
                                bool per_column = false,
                                bool has_zp = true,
-                               bool has_bias = false) {
+                               bool has_bias = false,
+                               bool empty_input = false) {
   // create rand inputs
-  RandomValueGenerator random{};
-
+  RandomValueGenerator random{1668426375};
+
+  int64_t M = empty_input ? 1 : 4;
+  int64_t N = 128;
+  int64_t K = 128;
+  std::vector<int64_t> A_dims{empty_input ? 0 : M, K};
+  std::vector<int64_t> B_dims{K, N};
+  std::vector<int64_t> Y_dims{empty_input ? 0 : M, N};
   std::vector<float> A_data = random.Uniform<float>(A_dims, -1.0f, 1.0f);
-
   std::vector<T> B_data;
-  std::vector<int> tmp_B_data = random.Uniform<int32_t>(B_dims, std::numeric_limits<T>::min(), std::numeric_limits<T>::max());
+  std::vector<T> tmp_B_data = random.Uniform<T>(B_dims,
+                                                (std::is_same_v<T, int8_t>) ? std::numeric_limits<int8_t>::lowest() / 2 : std::numeric_limits<uint8_t>::lowest(),
+                                                std::numeric_limits<T>::max() / 2);
   std::transform(tmp_B_data.begin(), tmp_B_data.end(), std::back_inserter(B_data), [](int32_t v) -> T {
     return static_cast<T>(v);
   });
@@ -47,7 +112,9 @@ void TestDynamicQuantizeMatMul(const std::vector<int64_t>& A_dims,
   std::for_each(B_zero_point.begin(),
                 B_zero_point.end(),
                 [&random](T& zp) {
-                  zp = static_cast<T>(random.Uniform<int32_t>(std::array<int64_t, 1>{1}, std::numeric_limits<T>::min(), std::numeric_limits<T>::max())[0]);
+                  zp = static_cast<T>(random.Uniform<T>(std::array<int64_t, 1>{1},
+                                                        std::numeric_limits<T>::min(),
+                                                        std::numeric_limits<T>::max())[0]);
                 });
 
   std::vector<float> Bias = random.Uniform<float>(AsSpan({B_dims.back()}), -0.1f, 0.1f);
@@ -69,77 +136,85 @@ void TestDynamicQuantizeMatMul(const std::vector<int64_t>& A_dims,
     test.AddOptionalInputEdge<float>();
   }
 
-  test.AddReferenceOutputs(reference_model);
+  std::vector<float> Y_data(M * N);
+  CalculateDynamicQuantizeMatMul<T>(M, N, K, A_data, B_data, B_scale, B_zero_point, Bias, Y_data,
+                                    per_column, has_zp, has_bias);
+  test.AddOutput<float>("Y", Y_dims, Y_data);
+  test.SetOutputRelErr("Y", 0.02f);
   test.Run();
 }
 
-template <typename Scalar, bool HasZeroPoint, bool HasBias>
-void RunDynamicQuantizeMatMulTest(const string& model_path) {
-  std::vector<int64_t> A_dims{4, 128};
-  std::vector<int64_t> B_dims{128, 128};
-  std::vector<int64_t> Y_dims{4, 128};
-
-  TestDynamicQuantizeMatMul<Scalar>(A_dims,
-                                    B_dims,
-                                    model_path,
-                                    false,        /*is_matrix_b_constant*/
-                                    false,        /*per_column*/
-                                    HasZeroPoint, /*has_zp*/
-                                    HasBias       /*has_bias*/
+template <typename T, bool HasZeroPoint, bool HasBias>
+void RunDynamicQuantizeMatMulTest() {
+  TestDynamicQuantizeMatMul<T>(false,        /*is_matrix_b_constant*/
+                               false,        /*per_column*/
+                               HasZeroPoint, /*has_zp*/
+                               HasBias       /*has_bias*/
   );
 
-  TestDynamicQuantizeMatMul<Scalar>(A_dims,
-                                    B_dims,
-                                    model_path,
-                                    true,         /*is_matrix_b_constant*/
-                                    false,        /*per_column*/
-                                    HasZeroPoint, /*has_zp*/
-                                    HasBias       /*has_bias*/
+  TestDynamicQuantizeMatMul<T>(true,         /*is_matrix_b_constant*/
+                               false,        /*per_column*/
+                               HasZeroPoint, /*has_zp*/
+                               HasBias       /*has_bias*/
   );
 
-  TestDynamicQuantizeMatMul<Scalar>(A_dims,
-                                    B_dims,
-                                    model_path,
-                                    false,        /*is_matrix_b_constant*/
-                                    true,         /*per_column*/
-                                    HasZeroPoint, /*has_zp*/
-                                    HasBias       /*has_bias*/
+  TestDynamicQuantizeMatMul<T>(false,        /*is_matrix_b_constant*/
+                               true,         /*per_column*/
+                               HasZeroPoint, /*has_zp*/
+                               HasBias       /*has_bias*/
   );
 
-  TestDynamicQuantizeMatMul<Scalar>(A_dims,
-                                    B_dims,
-                                    model_path,
-                                    true,         /*is_matrix_b_constant*/
-                                    true,         /*per_column*/
-                                    HasZeroPoint, /*has_zp*/
-                                    HasBias       /*has_bias*/
+  TestDynamicQuantizeMatMul<T>(true,         /*is_matrix_b_constant*/
+                               true,         /*per_column*/
+                               HasZeroPoint, /*has_zp*/
+                               HasBias       /*has_bias*/
   );
 }
 
-TEST(DynamicQuantizeMatMul, HasZeroPoint_NoBias_test) {
-  RunDynamicQuantizeMatMulTest<int8_t, true, false>("testdata/dynamic_quantize_matmul_int8.onnx");
-  RunDynamicQuantizeMatMulTest<uint8_t, true, false>("testdata/dynamic_quantize_matmul_uint8.onnx");
+TEST(DynamicQuantizeMatMul, HasZeroPoint_NoBias_test_S8) {
+  RunDynamicQuantizeMatMulTest<int8_t, true, false>();
 }
 
-TEST(DynamicQuantizeMatMul, NoZeroPoint_HasBias_test) {
-  RunDynamicQuantizeMatMulTest<int8_t, false, true>("testdata/dynamic_quantize_matmul_int8_bias.onnx");
-  RunDynamicQuantizeMatMulTest<uint8_t, false, true>("testdata/dynamic_quantize_matmul_uint8_bias.onnx");
+TEST(DynamicQuantizeMatMul, HasZeroPoint_NoBias_test_U8) {
+  RunDynamicQuantizeMatMulTest<uint8_t, true, false>();
+}
+
+TEST(DynamicQuantizeMatMul, NoZeroPoint_HasBias_test_S8) {
+  RunDynamicQuantizeMatMulTest<int8_t, false, true>();
+}
+
+TEST(DynamicQuantizeMatMul, NoZeroPoint_HasBias_test_U8) {
+  RunDynamicQuantizeMatMulTest<uint8_t, false, true>();
+}
+
+TEST(DynamicQuantizeMatMul, NoZeroPoint_NoBias_test_S8) {
+  RunDynamicQuantizeMatMulTest<int8_t, false, false>();
+}
+
+TEST(DynamicQuantizeMatMul, NoZeroPoint_NoBias_test_U8) {
+  RunDynamicQuantizeMatMulTest<uint8_t, false, false>();
+}
+
+TEST(DynamicQuantizeMatMul, HasZeroPoint_HasBias_test_S8) {
+  RunDynamicQuantizeMatMulTest<int8_t, true, true>();
+}
+
+TEST(DynamicQuantizeMatMul, HasZeroPoint_HasBias_test_U8) {
+  RunDynamicQuantizeMatMulTest<uint8_t, true, true>();
 }
 
 TEST(DynamicQuantizeMatMul, UInt8_test_with_empty_input) {
-  std::vector<int64_t> A_dims{0, 128};
-  std::vector<int64_t> B_dims{128, 128};
-  std::vector<int64_t> Y_dims{0, 128};
-
-  TestDynamicQuantizeMatMul<uint8_t>(A_dims,
-                                     B_dims,
-                                     "testdata/dynamic_quantize_matmul_uint8.onnx",
-                                     false /*is_matrix_b_constant*/);
-
-  TestDynamicQuantizeMatMul<uint8_t>(A_dims,
-                                     B_dims,
-                                     "testdata/dynamic_quantize_matmul_uint8.onnx",
-                                     true /*is_matrix_b_constant*/);
+  std::vector<int64_t> A_dims{0, 2};
+  std::vector<int64_t> B_dims{2, 2};
+  std::vector<int64_t> Y_dims{0, 2};
+  OpTester test("DynamicQuantizeMatMul", 1, onnxruntime::kMSDomain);
+  test.AddInput<float>("T1", A_dims, {});
+  test.AddInput<uint8_t>("T2", B_dims, {1, 6, 0, 8});
+  test.AddInput<float>("b_scale", {1}, {1.0f});
+  test.AddInput<uint8_t>("b_zero_point", {1}, {0});
+  test.AddOptionalInputEdge<float>();
+  test.AddOutput<float>("Y", {0, 2}, {});
+  test.Run();
 }
 
 TEST(DynamicQuantizeMatMul, B_PerColumn_ND) {
diff --git a/onnxruntime/test/contrib_ops/fft_op_test.cc b/onnxruntime/test/contrib_ops/fft_op_test.cc
index 56a6466c760f6..7a6b6cca6425a 100644
--- a/onnxruntime/test/contrib_ops/fft_op_test.cc
+++ b/onnxruntime/test/contrib_ops/fft_op_test.cc
@@ -25,6 +25,7 @@ TEST(ContribOpTest, Rfft) {
   // Target values computed using PyTorch torch.fft.rfft(X, dim=-1, norm="backward")
   test.AddInput<float>("X", {4, 4}, {0.8129f, 1.3108f, -0.8790f, -1.2046f, 0.1661f, -0.9831f, 0.5879f, 0.4918f, 1.2506f, 0.7244f, -2.6260f, -1.1268f, -1.6885f, 1.0439f, -0.2595f, 1.8780f});
   test.AddOutput<float>("Y", {4, 3, 2}, {0.0400f, 0.0000f, 1.6919f, -2.5154f, -0.1722f, 0.0000f, 0.2627f, 0.0000f, -0.4218f, 1.4748f, 1.2454f, 0.0000f, -1.7779f, 0.0000f, 3.8766f, -1.8512f, -0.9730f, 0.0000f, 0.9740f, 0.0000f, -1.4290f, 0.8341f, -4.8699f, 0.0000f});
+  test.SetOutputTolerance(0.0001f);
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
 }
 
@@ -45,6 +46,7 @@ TEST(ContribOpTest, Irfft) {
   test.AddAttribute("normalized", static_cast<int64_t>(0));
   test.AddInput<float>("X", {4, 3, 2}, {0.0400f, 0.0000f, 1.6919f, -2.5154f, -0.1722f, 0.0000f, 0.2627f, 0.0000f, -0.4218f, 1.4748f, 1.2454f, 0.0000f, -1.7779f, 0.0000f, 3.8766f, -1.8512f, -0.9730f, 0.0000f, 0.9740f, 0.0000f, -1.4290f, 0.8341f, -4.8699f, 0.0000f});
   test.AddOutput<float>("Y", {4, 4}, {0.8129f, 1.3108f, -0.8790f, -1.2046f, 0.1661f, -0.9831f, 0.5879f, 0.4918f, 1.2506f, 0.7244f, -2.6260f, -1.1268f, -1.6885f, 1.0439f, -0.2595f, 1.8780f});
+  test.SetOutputTolerance(0.0001f);
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
 }
 }  // namespace test
diff --git a/onnxruntime/test/contrib_ops/gemm_fastgelu_op_test.cc b/onnxruntime/test/contrib_ops/gemm_fastgelu_op_test.cc
index a24f3b6b441e1..d9d2681dd3b3f 100644
--- a/onnxruntime/test/contrib_ops/gemm_fastgelu_op_test.cc
+++ b/onnxruntime/test/contrib_ops/gemm_fastgelu_op_test.cc
@@ -50,6 +50,8 @@ static void RunGemmFastGeluGpuTest(const std::vector<float>& input_data, const s
     tester.AddOutput<float>("Y", output_dims, output_data);
   }
 
+  tester.SetOutputTolerance(use_float16 ? 0.005f : 0.0025f);
+
   tester.Config(run_with_tunable_op)
       .RunWithConfig();
 }
@@ -154,7 +156,7 @@ TEST(GemmFastGeluTest, GemmFastGeluWithoutBiasFloat16) {
 
   RunGemmFastGeluGpuTest(input_data, weight_data, bias_data, output_data,
                          input_dims, weight_dims, bias_dims, output_dims,
-                         false);
+                         false, true);
 }
 
 TEST(GemmFastGeluTest, GemmFastGeluWithBiasFloat16) {
@@ -189,7 +191,7 @@ TEST(GemmFastGeluTest, GemmFastGeluWithBiasFloat16) {
 
   RunGemmFastGeluGpuTest(input_data, weight_data, bias_data, output_data,
                          input_dims, weight_dims, bias_dims, output_dims,
-                         true);
+                         true, true);
 }
 
 TEST(GemmFastGeluTest, GemmFastGeluWithBias_bfloat16) {
diff --git a/onnxruntime/test/contrib_ops/greedy_search_test.cc b/onnxruntime/test/contrib_ops/greedy_search_test.cc
index 1baf50c1ba616..73da82d4bb039 100644
--- a/onnxruntime/test/contrib_ops/greedy_search_test.cc
+++ b/onnxruntime/test/contrib_ops/greedy_search_test.cc
@@ -8,6 +8,10 @@
 #include "core/session/onnxruntime_cxx_api.h"
 #include "test/common/cuda_op_test_utils.h"
 
+#ifdef USE_CUDA
+#include "core/providers/cuda/cuda_provider_options.h"
+#endif
+
 extern std::unique_ptr<Ort::Env> ort_env;
 
 namespace onnxruntime {
@@ -64,9 +68,13 @@ TEST(GreedySearchTest, GptGreedySearchFp16_VocabPadded) {
 
   if (is_cuda || is_rocm) {
     Ort::SessionOptions session_options;
+#ifdef USE_CUDA
     if (is_cuda) {
-      Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
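+      // Disable TF32 so FP32 GEMM results stay close to the reference outputs used by this test.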
+      OrtCUDAProviderOptionsV2 cuda_options;
+      cuda_options.use_tf32 = false;
+      session_options.AppendExecutionProvider_CUDA_V2(cuda_options);
     }
+#endif
     if (is_rocm) {
       Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(session_options, 0));
     }
@@ -145,9 +153,13 @@ TEST(GreedySearchTest, GptGreedySearchFp32) {
 
   if (is_cuda || is_rocm) {
     Ort::SessionOptions session_options;
+#ifdef USE_CUDA
     if (is_cuda) {
-      Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
+      OrtCUDAProviderOptionsV2 cuda_options;
+      cuda_options.use_tf32 = false;
+      session_options.AppendExecutionProvider_CUDA_V2(cuda_options);
     }
+#endif
     if (is_rocm) {
       Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(session_options, 0));
     }
diff --git a/onnxruntime/test/contrib_ops/gridsample_test.cc b/onnxruntime/test/contrib_ops/gridsample_test.cc
index 1f31c2bd21f14..d970178e29ab8 100644
--- a/onnxruntime/test/contrib_ops/gridsample_test.cc
+++ b/onnxruntime/test/contrib_ops/gridsample_test.cc
@@ -32,7 +32,7 @@ TEST(GridsampleContribOpTest, gridsample_default) {
                          3.8000f, 7.9000f, 8.7000f, 9.5000f, 10.3000f, 5.3000f,
                          5.4000f, 11.1000f, 11.9000f, 12.7000f, 13.5000f, 6.9000f,
                          3.0000f, 6.1500f, 6.5500f, 6.9500f, 7.3500f, 3.7500f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider});
 }
 
 TEST(GridsampleContribOpTest, gridsample_paddingmode_zeros) {
@@ -45,7 +45,7 @@ TEST(GridsampleContribOpTest, gridsample_paddingmode_zeros) {
                         5.0000f, 5.0000f, 10.0000f, 10.0000f});
   test.AddAttribute("padding_mode", "zeros");
   test.AddOutput<float>("Y", {1, 1, 2, 4}, {0.0000f, 0.0000f, 1.7000f, 0.0000f, 0.0000f, 1.7000f, 0.0000f, 0.0000f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider});
 }
 
 TEST(GridsampleContribOpTest, gridsample_paddingmode_border) {
@@ -58,7 +58,7 @@ TEST(GridsampleContribOpTest, gridsample_paddingmode_border) {
                         5.0000f, 5.0000f, 10.0000f, 10.0000f});
   test.AddAttribute("padding_mode", "border");
   test.AddOutput<float>("Y", {1, 1, 2, 4}, {0.0000f, 0.0000f, 1.7000f, 5.0000f, 5.0000f, 1.7000f, 5.0000f, 5.0000f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider});
 }
 
 TEST(GridsampleContribOpTest, gridsample_paddingmode_reflection) {
@@ -71,7 +71,8 @@ TEST(GridsampleContribOpTest, gridsample_paddingmode_reflection) {
                         5.0000f, 5.0000f, 10.0000f, 10.0000f});
   test.AddAttribute("padding_mode", "reflection");
   test.AddOutput<float>("Y", {1, 1, 2, 4}, {2.5000f, 0.0000f, 1.7000f, 2.5000f, 2.5000f, 1.7000f, 5.0000f, 2.5000f});
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider});  // Accuracy issue for QNN
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kCudaNHWCExecutionProvider, kQnnExecutionProvider});  // Accuracy issue for QNN
 }
 
 TEST(GridsampleContribOpTest, gridsample_aligncorners_true) {
@@ -86,7 +87,7 @@ TEST(GridsampleContribOpTest, gridsample_aligncorners_true) {
   test.AddAttribute("mode", "bilinear");
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", {1, 1, 2, 4}, {0.0000f, 1.2500f, 2.0000f, 2.5000f, 2.5000f, 2.0000f, 3.7500f, 5.0000f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider});
 }
 
 TEST(GridsampleContribOpTest, gridsample_mode_bilinear) {
@@ -99,7 +100,7 @@ TEST(GridsampleContribOpTest, gridsample_mode_bilinear) {
                         0.5000f, 0.5000f, 1.0000f, 1.0000f});
   test.AddAttribute("mode", "bilinear");
   test.AddOutput<float>("Y", {1, 1, 2, 4}, {0.0000f, 0.5000f, 1.7000f, 2.5000f, 2.5000f, 1.7000f, 4.5000f, 1.2500f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider});
 }
 
 TEST(GridsampleContribOpTest, gridsample_mode_nearest) {
@@ -112,7 +113,7 @@ TEST(GridsampleContribOpTest, gridsample_mode_nearest) {
                         0.5000f, 0.5000f, 1.0000f, 1.0000f});
   test.AddAttribute("mode", "nearest");
   test.AddOutput<float>("Y", {1, 1, 2, 4}, {0.f, 0.f, 2.f, 2.f, 2.f, 2.f, 5.f, 0.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider});
 }
 
 TEST(GridsampleContribOpTest, gridsample_mode_bicubic) {
@@ -125,7 +126,8 @@ TEST(GridsampleContribOpTest, gridsample_mode_bicubic) {
                         0.5000f, 0.5000f, 1.0000f, 1.0000f});
   test.AddAttribute("mode", "bicubic");
   test.AddOutput<float>("Y", {1, 1, 2, 4}, {-0.1406f, 0.3828f, 1.7556f, 2.9688f, 2.9688f, 1.7556f, 5.1445f, 1.3906f});
-  test.Run();
+  test.SetOutputTolerance(0.0001f);
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider});
 }
 
 }  // namespace test
diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
index 84bbee35eed5a..655c4951f262d 100644
--- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
+++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
@@ -7,6 +7,7 @@
 #include "core/session/inference_session.h"
 #include "test/common/dnnl_op_test_utils.h"
 #include "test/common/tensor_op_test_utils.h"
+#include "test/common/cuda_op_test_utils.h"
 #include "test/framework/test_utils.h"
 #include "test/util/include/default_providers.h"
 #include "test/providers/provider_test_utils.h"
@@ -75,6 +76,28 @@ TEST(LayerNormTest, LayerNorm) {
   test.Run();
 }
 
+TEST(LayerNormTest, LayerNorm_BFloat16Input) {
+// prevents test from running on non-BF16-supporting hardware
+#ifdef USE_CUDA
+  int min_cuda_architecture = 530;
+  if (!HasCudaEnvironment(min_cuda_architecture)) {
+    LOGS_DEFAULT(WARNING) << "Hardware does not support BF16";
+    return;
+  }
+#endif
+  OpTester test("LayerNormalization");
+  test.AddAttribute<float>("epsilon", 1e-05f);
+
+  std::vector<int64_t> dims{1, 2, 3};
+  test.AddInput<BFloat16>("x", dims, MakeBFloat16({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}));
+  test.AddInput<BFloat16>("gamma", {3}, MakeBFloat16({1.0f, 1.0f, 1.0f}));
+  test.AddOutput<BFloat16>("output", dims, MakeBFloat16({-1.2247f, 0.0f, 1.2247f, -1.2247f, 0.0f, 1.2247f}));
+  // TRT, DNNL, OpenVINO, NNAPI, QNN and CoreML don't support this combination of data types
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider,
+            kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider});
+}
+
 TEST(LayerNormTest, LayerNorm_Scale) {
   OpTester test("LayerNormalization");
   test.AddAttribute<float>("epsilon", 1e-05f);
@@ -137,6 +160,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias) {
   test.AddInput<float>("gamma", {2}, {-0.6953f, 5.1824f});
   test.AddInput<float>("bias", {2}, {0.6435f, -0.3964f});
   test.AddOutput<float>("output", dims, {-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f});
+  test.SetOutputTolerance(0.0001f);
   test.Run();
 }
 
@@ -149,6 +173,8 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16Input) {
   test.AddInput<float>("gamma", {2}, {-0.6953f, 5.1824f});
   test.AddInput<float>("bias", {2}, {0.6435f, -0.3964f});
   test.AddOutput<float>("output", dims, {-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f});
+  test.SetOutputTolerance(0.0001f);
+
   // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
            {kTensorrtExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider,
@@ -205,6 +231,9 @@ TEST(LayerNormTest, LayerNorm17_double) {
   test.AddInput<double>("x", dims, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0});
   test.AddInput<double>("gamma", {3}, {1.0, 1.0, 1.0});
   test.AddOutput<double>("output", dims, {-1.2247, 0.0, 1.2247, -1.2247, 0.0, 1.2247});
+
+  test.SetOutputTolerance(0.0001f);
+
   // DNNL does not support double
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDnnlExecutionProvider});
 }
diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
index 2ad20eafc2ef1..d294fd4e2b0e0 100644
--- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
+++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
 
 #ifndef ORT_MINIMAL_BUILD
+#include <gsl/narrow>
 
 #include "core/common/span_utils.h"
 #include "core/framework/tensor.h"
@@ -66,7 +67,9 @@ void QuantizeDequantize(std::vector<float>& raw_vals,
 }
 
 void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accuracy_level,
-             bool has_zeropoint, bool use_float16, float fp16_abs_error = 0.02f) {
+             bool has_zeropoint, bool use_float16, bool has_g_idx = false,
+             bool zp_is_4bit = true, float fp16_abs_error = 0.02f) {
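+  // A group index implies 4-bit packed zero points, so force that layout.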
+  zp_is_4bit = zp_is_4bit || has_g_idx;
   RandomValueGenerator random{1234};
   std::vector<float> input0_vals(random.Gaussian<float>(std::vector<int64_t>({M, K}), 0.0f, 0.25f));
   std::vector<float> input1_f_vals(random.Gaussian<float>(std::vector<int64_t>({K, N}), 0.0f, 0.25f));
@@ -113,12 +116,40 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura
   test.AddAttribute<int64_t>("block_size", block_size);
   test.AddAttribute<int64_t>("bits", QBits);
   test.AddAttribute<int64_t>("accuracy_level", accuracy_level);
+  auto ceildiv = [](int64_t a, int64_t b) { return (a + b - 1) / b; };
+
   if (use_float16) {
     test.AddInput<MLFloat16>("A", {M, K}, ToFloat16(input0_vals), false);
     test.AddInput<uint8_t>("B", {q_cols, q_rows}, input1_vals, true);
     test.AddInput<MLFloat16>("scales", {static_cast<int64_t>(q_scale_size)}, ToFloat16(scales), true);
     if (has_zeropoint) {
-      test.AddInput<uint8_t>("zero_points", {static_cast<int64_t>(q_zp_size_in_bytes)}, zp, true);
+      if (zp_is_4bit) {
+        test.AddInput<uint8_t>("zero_points", {static_cast<int64_t>(q_zp_size_in_bytes)}, zp, true);
+      } else {
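+        // zp packs two 4-bit zero points per byte; unpack them to float and drop the
+        // padding entries so the count matches q_scale_size.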
+        std::vector<float> zp_f;
+        zp_f.reserve(q_zp_size_in_bytes * 2);
+        for (size_t i = 0; i < zp.size(); i++) {
+          zp_f.push_back(static_cast<float>(zp[i] & 0xf));
+          zp_f.push_back(static_cast<float>((zp[i] >> 4) & 0xf));
+        }
+        size_t ind = zp_f.size() - 1;
+        while (zp_f.size() != q_scale_size) {
+          zp_f.erase(zp_f.begin() + ind);
+          ind -= q_scale_size / N + 1;
+        }
+
+        test.AddInput<MLFloat16>("zero_points", {static_cast<int64_t>(q_scale_size)}, ToFloat16(zp_f), true);
+      }
+    } else {
+      test.AddInput<uint8_t>("", {0}, {});
+    }
+    if (has_g_idx) {
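+      // g_idx maps each element along K (padded to a multiple of block_size) to its quantization block.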
+      int K_pad = gsl::narrow<int32_t>(ceildiv(K, block_size) * block_size);
+      std::vector<int32_t> g_idx(K_pad);
+      for (int64_t i = 0; i < K_pad; i++) {
+        g_idx[i] = gsl::narrow<int32_t>(i / block_size);
+      }
+      test.AddInput<int32_t>("g_idx", {static_cast<int64_t>(K_pad)}, g_idx, true);
     }
 
     test.AddOutput<MLFloat16>("Y", {M, N}, ToFloat16(expected_vals));
@@ -132,9 +163,34 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura
     test.AddInput<uint8_t>("B", {q_cols, q_rows}, input1_vals, true);
     test.AddInput<float>("scales", {static_cast<int64_t>(q_scale_size)}, scales, true);
     if (has_zeropoint) {
-      test.AddInput<uint8_t>("zero_points", {static_cast<int64_t>(q_zp_size_in_bytes)}, zp, true);
-    }
+      if (zp_is_4bit) {
+        test.AddInput<uint8_t>("zero_points", {static_cast<int64_t>(q_zp_size_in_bytes)}, zp, true);
+      } else {
+        std::vector<float> zp_f;
+        zp_f.reserve(q_zp_size_in_bytes * 2);
+        for (size_t i = 0; i < zp.size(); i++) {
+          zp_f.push_back(static_cast<float>(zp[i] & 0xf));
+          zp_f.push_back(static_cast<float>((zp[i] >> 4) & 0xf));
+        }
+        size_t ind = zp_f.size() - 1;
+        while (zp_f.size() != q_scale_size) {
+          zp_f.erase(zp_f.begin() + ind);
+          ind -= q_scale_size / N + 1;
+        }
 
+        test.AddInput<float>("zero_points", {static_cast<int64_t>(q_scale_size)}, zp_f, true);
+      }
+    } else {
+      test.AddInput<uint8_t>("", {0}, {});
+    }
+    if (has_g_idx) {
+      int K_pad = gsl::narrow<int32_t>(ceildiv(K, block_size) * block_size);
+      std::vector<int32_t> g_idx(K_pad);
+      for (int64_t i = 0; i < K_pad; i++) {
+        g_idx[i] = gsl::narrow<int32_t>(i / block_size);
+      }
+      test.AddInput<int32_t>("g_idx", {static_cast<int64_t>(K_pad)}, g_idx, true);
+    }
     test.AddOutput<float>("Y", {M, N}, expected_vals);
     if (accuracy_level == 4) {
       test.SetOutputAbsErr("Y", 0.1f);
@@ -158,6 +214,8 @@ TEST(MatMulNBits, Float32) {
           for (auto accuracy_level : {0}) {
             RunTest(M, N, K, block_size, accuracy_level, false, false);
             RunTest(M, N, K, block_size, accuracy_level, true, false);
+            RunTest(M, N, K, block_size, accuracy_level, false, false, true);
+            RunTest(M, N, K, block_size, accuracy_level, true, false, false, false);
           }
 #endif
         }
@@ -172,8 +230,10 @@ TEST(MatMulNBits, Float16) {
     for (auto N : {1, 2, 32, 288}) {
       for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) {
         for (auto block_size : {16, 32, 64, 128}) {
-          RunTest(M, N, K, block_size, 0, false, true);
-          RunTest(M, N, K, block_size, 0, true, true);
+          for (auto has_gidx : {true, false}) {
+            RunTest(M, N, K, block_size, 0, false, true, has_gidx);
+            RunTest(M, N, K, block_size, 0, true, true, has_gidx, false);
+          }
         }
       }
     }
@@ -183,9 +243,9 @@ TEST(MatMulNBits, Float16) {
 TEST(MatMulNBits, Float16Large) {
   for (auto block_size : {16, 32, 64, 128}) {
     for (auto symmetric : {false, true}) {
-      RunTest(1, 4096, 4096, block_size, 0, symmetric, true, 0.05f);
-      RunTest(1, 4096, 11008, block_size, 0, symmetric, true, 0.05f);
-      RunTest(1, 11008, 4096, block_size, 0, symmetric, true, 0.05f);
+      RunTest(1, 4096, 4096, block_size, 0, symmetric, true, false, true, 0.05f);
+      RunTest(1, 4096, 11008, block_size, 0, symmetric, true, false, true, 0.05f);
+      RunTest(1, 11008, 4096, block_size, 0, symmetric, true, false, true, 0.05f);
     }
   }
 }
diff --git a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc
index 26ce5272d25ee..8d7629b5fda1c 100644
--- a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc
+++ b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc
@@ -23,135 +23,408 @@ using namespace std;
 namespace onnxruntime {
 namespace test {
 
-template <typename IType, typename WType>
-void TestMatMulIntegerToFloat(const std::vector<int64_t>& A_dims,
-                              std::vector<int64_t> B_dims,
-                              const std::string& reference_model,
-                              bool is_matrix_b_constant,
+template <typename IType, typename WType, typename OType>
+static void CalculateMatMulIntegerToFloat(const int64_t M, const int64_t N, const int64_t K,
+                                          const std::vector<IType>& A_data, const std::vector<OType>& A_scale,
+                                          const std::vector<IType>& A_zero_point, const std::vector<WType>& B_data,
+                                          std::vector<OType>& B_scale, std::vector<WType>& B_zero_point,
+                                          const std::vector<OType>& Bias, std::vector<float>& Y_data,
+                                          bool per_column, bool has_zp, bool has_bias) {
+  if (!per_column) {
+    B_zero_point.resize(N, B_zero_point[0]);
+    B_scale.resize(N, B_scale[0]);
+  }
+
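+  // Reference computation: dequantize A and B, multiply-accumulate in float, then add the bias when present.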
+  for (int64_t m = 0; m < M; m++) {
+    for (int64_t n = 0; n < N; n++) {
+      float sum = 0.0f;
+      for (int64_t k = 0; k < K; k++) {
+        float A_dequantized = has_zp ? (static_cast<int>(A_data[m * K + k]) - static_cast<int>(A_zero_point[0])) * A_scale[0] : A_data[m * K + k] * A_scale[0];
+        float B_dequantized = has_zp ? (static_cast<int>(B_data[k * N + n]) - static_cast<int>(B_zero_point[n])) * B_scale[n] : B_data[k * N + n] * B_scale[n];
+
+        sum += A_dequantized * B_dequantized;
+      }
+      if (has_bias) {
+        sum += Bias[n];
+      }
+      Y_data[m * N + n] = static_cast<OType>(sum);
+    }
+  }
+}
+
+template <typename IType, typename WType, typename OType>
+void TestMatMulIntegerToFloat(bool is_matrix_b_constant,
                               bool per_column = false,
                               bool has_zp = true,
                               bool has_bias = false) {
   // create rand inputs
   RandomValueGenerator random{};
-
+  int64_t M = 4;
+  int64_t N = 128;
+  int64_t K = 128;
+  std::vector<int64_t> A_dims{M, K};
+  std::vector<int64_t> B_dims{K, N};
+  std::vector<int64_t> Y_dims{M, N};
   std::vector<IType> A_data;
-  std::vector<int> tmp_A_data = random.Uniform<int32_t>(A_dims,
-                                                        std::numeric_limits<WType>::lowest(),
-                                                        std::numeric_limits<WType>::max());
-  std::transform(tmp_A_data.begin(), tmp_A_data.end(), std::back_inserter(A_data), [](int32_t v) -> WType {
+  std::vector<IType> tmp_A_data = random.Uniform<IType>(A_dims,
+                                                        std::numeric_limits<IType>::lowest(),
+                                                        std::numeric_limits<IType>::max());
+  std::transform(tmp_A_data.begin(), tmp_A_data.end(), std::back_inserter(A_data), [](int32_t v) -> IType {
     return static_cast<IType>(v);
   });
 
   std::vector<WType> B_data;
-  std::vector<int> tmp_B_data = random.Uniform<int32_t>(B_dims,
-                                                        std::numeric_limits<WType>::lowest(),
-                                                        std::numeric_limits<WType>::max());
+
+  std::vector<WType> tmp_B_data;
+  tmp_B_data = random.Uniform<WType>(B_dims,
+                                     std::is_signed<WType>::value ? std::numeric_limits<int8_t>::lowest() / 2 : std::numeric_limits<uint8_t>::lowest(),
+                                     std::numeric_limits<WType>::max() / 2);
   std::transform(tmp_B_data.begin(), tmp_B_data.end(), std::back_inserter(B_data), [](int32_t v) -> WType {
     return static_cast<WType>(v);
   });
 
-  std::vector<float> A_scale = random.Uniform<float>(AsSpan<int64_t>({1}), -0.1f, 0.1f);
+  std::vector<OType> A_scale = random.Uniform<OType>(AsSpan<int64_t>({1}), -0.1f, 0.1f);
   std::vector<IType> A_zero_point{(std::numeric_limits<IType>::lowest() + std::numeric_limits<IType>::max() + IType(2)) / 2};
 
   int64_t b_scale_zp_size = per_column ? B_dims.back() : 1;
-  std::vector<float> B_scale = random.Uniform<float>(AsSpan({b_scale_zp_size}), -0.1f, 0.1f);
+  std::vector<OType> B_scale = random.Uniform<OType>(AsSpan({b_scale_zp_size}), -0.1f, 0.1f);
 
   std::vector<WType> B_zero_point(b_scale_zp_size);
   std::for_each(B_zero_point.begin(),
                 B_zero_point.end(),
                 [&random](WType& zp) {
-                  zp = static_cast<WType>(random.Uniform<int32_t>(std::array<int64_t, 1>{1},
-                                                                  std::numeric_limits<WType>::lowest(),
-                                                                  std::numeric_limits<WType>::max())[0]);
+                  zp = static_cast<WType>(random.Uniform<WType>(std::array<int64_t, 1>{1},
+                                                                std::numeric_limits<WType>::lowest(),
+                                                                std::numeric_limits<WType>::max())[0]);
                 });
 
-  std::vector<float> Bias = random.Uniform<float>(AsSpan({B_dims.back()}), -0.1f, 0.1f);
+  std::vector<OType> Bias = random.Uniform<OType>(AsSpan({B_dims.back()}), -0.1f, 0.1f);
 
   OpTester test("MatMulIntegerToFloat", 1, onnxruntime::kMSDomain);
   test.AddInput<IType>("A", A_dims, A_data);
   test.AddInput<WType>("B", B_dims, B_data, is_matrix_b_constant);
-  test.AddInput<float>("a_scale", {1}, A_scale);
-  test.AddInput<float>("b_scale", {b_scale_zp_size}, B_scale);
+  test.AddInput<OType>("a_scale", {1}, A_scale);
+  test.AddInput<OType>("b_scale", {b_scale_zp_size}, B_scale);
 
   if (has_zp) {
     test.AddInput<IType>("a_zero_point", {1}, A_zero_point);
     test.AddInput<WType>("b_zero_point", {b_scale_zp_size}, B_zero_point);
   } else {
-    test.AddOptionalInputEdge<WType>();
+    test.AddOptionalInputEdge<IType>();
     test.AddOptionalInputEdge<WType>();
   }
 
   if (has_bias) {
-    test.AddInput<float>("bias", {B_dims.back()}, Bias);
+    test.AddInput<OType>("bias", {B_dims.back()}, Bias);
   } else {
-    test.AddOptionalInputEdge<float>();
+    test.AddOptionalInputEdge<OType>();
   }
 
-  test.AddReferenceOutputs(reference_model);
-  test.SetOutputRelErr("Y", 1e-4f);
-  test.Run();
-}
+  std::vector<float> Y_data(M * N);
+  CalculateMatMulIntegerToFloat<IType, WType, OType>(M, N, K, A_data, A_scale, A_zero_point,
+                                                     B_data, B_scale, B_zero_point, Bias, Y_data,
+                                                     per_column, has_zp, has_bias);
 
-template <typename IType, typename WType, bool HasZeroPoint, bool HasBias>
-void RunMatMulIntegerToFloatTest(const string& model_path) {
-  std::vector<int64_t> A_dims{4, 128};
-  std::vector<int64_t> B_dims{128, 128};
-  std::vector<int64_t> Y_dims{4, 128};
+  if (std::is_same_v<OType, float>) {
+    test.AddOutput<float>("Y", {M, N}, Y_data);
+    test.SetOutputAbsErr("Y", 0.001f);
+    test.SetOutputRelErr("Y", 0.02f);
+  } else {
+    test.AddOutput<MLFloat16>("Y", {M, N}, ToFloat16(Y_data));
+    test.SetOutputAbsErr("Y", 0.5f);
+  }
 
-  TestMatMulIntegerToFloat<IType, WType>(A_dims,
-                                         B_dims,
-                                         model_path,
-                                         false,        /*is_matrix_b_constant*/
-                                         false,        /*per_column*/
-                                         HasZeroPoint, /*has_zp*/
-                                         HasBias       /*has_bias*/
+  // Only DML EP supports these data type combinations for now
+  if (std::is_same_v<OType, MLFloat16> ||
+      (std::is_same_v<OType, float> &&
+       std::is_same_v<IType, int8_t> &&
+       std::is_same_v<WType, uint8_t>)) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultDmlExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  } else {
+    test.Run();
+  }
+}
+
+template <typename IType, typename WType, typename OType, bool HasZeroPoint, bool HasBias>
+void RunMatMulIntegerToFloatTest() {
+  TestMatMulIntegerToFloat<IType, WType, OType>(
+      false,        /*is_matrix_b_constant*/
+      false,        /*per_column*/
+      HasZeroPoint, /*has_zp*/
+      HasBias       /*has_bias*/
   );
 
-  TestMatMulIntegerToFloat<IType, WType>(A_dims,
-                                         B_dims,
-                                         model_path,
-                                         true,         /*is_matrix_b_constant*/
-                                         false,        /*per_column*/
-                                         HasZeroPoint, /*has_zp*/
-                                         HasBias       /*has_bias*/
+  TestMatMulIntegerToFloat<IType, WType, OType>(
+      true,         /*is_matrix_b_constant*/
+      false,        /*per_column*/
+      HasZeroPoint, /*has_zp*/
+      HasBias       /*has_bias*/
   );
 
-  TestMatMulIntegerToFloat<IType, WType>(A_dims,
-                                         B_dims,
-                                         model_path,
-                                         false,        /*is_matrix_b_constant*/
-                                         true,         /*per_column*/
-                                         HasZeroPoint, /*has_zp*/
-                                         HasBias       /*has_bias*/
+  TestMatMulIntegerToFloat<IType, WType, OType>(
+      false,        /*is_matrix_b_constant*/
+      true,         /*per_column*/
+      HasZeroPoint, /*has_zp*/
+      HasBias       /*has_bias*/
   );
 
-  TestMatMulIntegerToFloat<IType, WType>(A_dims,
-                                         B_dims,
-                                         model_path,
-                                         true,         /*is_matrix_b_constant*/
-                                         true,         /*per_column*/
-                                         HasZeroPoint, /*has_zp*/
-                                         HasBias       /*has_bias*/
+  TestMatMulIntegerToFloat<IType, WType, OType>(
+      true,         /*is_matrix_b_constant*/
+      true,         /*per_column*/
+      HasZeroPoint, /*has_zp*/
+      HasBias       /*has_bias*/
   );
 }
 
-TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_U8X8) {
-  RunMatMulIntegerToFloatTest<uint8_t, int8_t, true, false>("testdata/matmul_integer_to_float_int8.onnx");
-  RunMatMulIntegerToFloatTest<uint8_t, uint8_t, true, false>("testdata/matmul_integer_to_float_uint8.onnx");
+TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_S8S8) {
+  RunMatMulIntegerToFloatTest<int8_t, int8_t, float, true, false>();
 }
 
-TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_U8X8) {
-  RunMatMulIntegerToFloatTest<uint8_t, int8_t, false, true>("testdata/matmul_integer_to_float_int8_bias.onnx");
-  RunMatMulIntegerToFloatTest<uint8_t, uint8_t, false, true>("testdata/matmul_integer_to_float_uint8_bias.onnx");
+TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_S8S8) {
+  RunMatMulIntegerToFloatTest<int8_t, int8_t, float, false, true>();
 }
 
-TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_S8S8) {
-  RunMatMulIntegerToFloatTest<int8_t, int8_t, true, false>("testdata/matmul_integer_to_float_int8_int8.onnx");
+TEST(MatMulIntegerToFloat, NoZeroPoint_NoBias_test_S8S8) {
+  RunMatMulIntegerToFloatTest<int8_t, int8_t, float, false, false>();
 }
 
-TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_S8S8) {
-  RunMatMulIntegerToFloatTest<int8_t, int8_t, false, true>("testdata/matmul_integer_to_float_int8_int8_bias.onnx");
+TEST(MatMulIntegerToFloat, HasZeroPoint_HasBias_test_S8S8) {
+  RunMatMulIntegerToFloatTest<int8_t, int8_t, float, true, true>();
+}
+
+TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_U8U8) {
+  RunMatMulIntegerToFloatTest<uint8_t, uint8_t, float, true, false>();
+}
+
+TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_U8U8) {
+  RunMatMulIntegerToFloatTest<uint8_t, uint8_t, float, false, true>();
+}
+
+TEST(MatMulIntegerToFloat, NoZeroPoint_NoBias_test_U8U8) {
+  RunMatMulIntegerToFloatTest<uint8_t, uint8_t, float, false, false>();
+}
+
+TEST(MatMulIntegerToFloat, HasZeroPoint_HasBias_test_U8X8) {
+  RunMatMulIntegerToFloatTest<uint8_t, uint8_t, float, true, true>();
+}
+
+TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_U8S8) {
+  RunMatMulIntegerToFloatTest<uint8_t, int8_t, float, true, false>();
+}
+
+TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_U8S8) {
+  RunMatMulIntegerToFloatTest<uint8_t, int8_t, float, false, true>();
+}
+
+TEST(MatMulIntegerToFloat, NoZeroPoint_NoBias_test_U8S8) {
+  RunMatMulIntegerToFloatTest<uint8_t, int8_t, float, false, false>();
+}
+
+TEST(MatMulIntegerToFloat, HasZeroPoint_HasBias_test_U8S8) {
+  RunMatMulIntegerToFloatTest<uint8_t, int8_t, float, true, true>();
+}
+
+// DML EP supports Float16 output, and a signed A matrix with an unsigned B matrix for Float32 output
+#if defined(USE_DML)
+
+TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_S8U8) {
+  RunMatMulIntegerToFloatTest<int8_t, uint8_t, float, true, false>();
+}
+
+TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_S8U8) {
+  RunMatMulIntegerToFloatTest<int8_t, uint8_t, float, false, true>();
+}
+
+TEST(MatMulIntegerToFloat, NoZeroPoint_NoBias_test_S8U8) {
+  RunMatMulIntegerToFloatTest<int8_t, uint8_t, float, false, false>();
+}
+
+TEST(MatMulIntegerToFloat, HasZeroPoint_HasBias_test_S8U8) {
+  RunMatMulIntegerToFloatTest<int8_t, uint8_t, float, true, true>();
+}
+
+TEST(MatMulIntegerToFloat, MatMulIntegerToFloat_FP16_U8U8) {
+  OpTester test("MatMulIntegerToFloat", 1, kMSDomain);
+  int64_t M = 5;
+  int64_t N = 5;
+  int64_t K = 2;
+
+  std::vector<uint8_t> A_data = {1, 5, 2, 1, 9,
+                                 1, 1, 3, 7, 2};
+  std::vector<uint8_t> B_data = {3, 7, 2, 1, 1,
+                                 2, 1, 9, 1, 1};
+  std::vector<MLFloat16> A_scale = ToFloat16({3.0f});
+  std::vector<MLFloat16> B_scale = ToFloat16({2.0f});
+  test.AddInput<uint8_t>("A", {M, K}, A_data);
+  test.AddInput<uint8_t>("B", {K, N}, B_data);
+  std::vector<uint8_t> A_zero_point = {1};
+  std::vector<uint8_t> B_zero_point = {1};
+
+  test.AddInput<MLFloat16>("a_scale", {1}, A_scale);
+  test.AddInput<MLFloat16>("b_scale", {1}, B_scale);
+  test.AddInput<uint8_t>("a_zero_point", {1}, A_zero_point);
+  test.AddInput<uint8_t>("b_zero_point", {1}, B_zero_point);
+
+  std::vector<float> Y_data(M * N);
+  CalculateMatMulIntegerToFloat<uint8_t, uint8_t, MLFloat16>(M, N, K, A_data, A_scale, A_zero_point,
+                                                             B_data, B_scale, B_zero_point, {}, Y_data,
+                                                             false, true, false);
+
+  test.AddOutput<MLFloat16>("Y", {M, N}, ToFloat16(Y_data));
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultDmlExecutionProvider());
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+TEST(MatMulIntegerToFloat, MatMulIntegerToFloat_FP16_U8S8) {
+  OpTester test("MatMulIntegerToFloat", 1, kMSDomain);
+  int64_t M = 5;
+  int64_t N = 5;
+  int64_t K = 2;
+
+  std::vector<uint8_t> A_data = {3, 7, 2, 1, 1,
+                                 2, 1, 9, 1, 1};
+  std::vector<int8_t> B_data = {2, -1, -9, 1, 1,
+                                -1, 0, -3, 1, -4};
+  std::vector<MLFloat16> A_scale = ToFloat16({-4.0f});
+  std::vector<MLFloat16> B_scale = ToFloat16({2.0f});
+  test.AddInput<uint8_t>("A", {M, K}, A_data);
+  test.AddInput<int8_t>("B", {K, N}, B_data);
+  std::vector<uint8_t> A_zero_point = {1};
+  std::vector<int8_t> B_zero_point = {3};
+  std::vector<MLFloat16> Bias = ToFloat16({11.0f, -17.0f, 1.0f, -3.0f, 12.0f});
+
+  test.AddInput<MLFloat16>("a_scale", {1}, A_scale);
+  test.AddInput<MLFloat16>("b_scale", {1}, B_scale);
+  test.AddInput<uint8_t>("a_zero_point", {1}, A_zero_point);
+  test.AddInput<int8_t>("b_zero_point", {1}, B_zero_point);
+
+  std::vector<float> Y_data(M * N);
+  CalculateMatMulIntegerToFloat<uint8_t, int8_t, MLFloat16>(M, N, K, A_data, A_scale, A_zero_point,
+                                                            B_data, B_scale, B_zero_point, {}, Y_data,
+                                                            false, true, false);
+
+  test.AddOutput<MLFloat16>("Y", {M, N}, ToFloat16(Y_data));
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultDmlExecutionProvider());
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+TEST(MatMulIntegerToFloat, MatMulIntegerToFloat_FP16_S8S8) {
+  OpTester test("MatMulIntegerToFloat", 1, kMSDomain);
+  int64_t M = 5;
+  int64_t N = 5;
+  int64_t K = 2;
+
+  std::vector<int8_t> A_data = {3, 7, -2, 1, 1,
+                                2, -1, -9, 1, 1};
+  std::vector<int8_t> B_data = {2, -1, -9, 1, 1,
+                                -1, 0, -3, 1, -4};
+  std::vector<MLFloat16> A_scale = ToFloat16({-4.0f});
+  std::vector<MLFloat16> B_scale = ToFloat16({2.0f});
+  test.AddInput<int8_t>("A", {M, K}, A_data);
+  test.AddInput<int8_t>("B", {K, N}, B_data);
+  std::vector<int8_t> A_zero_point = {-1};
+  std::vector<int8_t> B_zero_point = {3};
+  std::vector<MLFloat16> Bias = ToFloat16({11.0f, -17.0f, 1.0f, -3.0f, 12.0f});
+
+  test.AddInput<MLFloat16>("a_scale", {1}, A_scale);
+  test.AddInput<MLFloat16>("b_scale", {1}, B_scale);
+  test.AddInput<int8_t>("a_zero_point", {1}, A_zero_point);
+  test.AddInput<int8_t>("b_zero_point", {1}, B_zero_point);
+  test.AddInput<MLFloat16>("bias", {N}, Bias);
+
+  std::vector<float> Y_data(M * N);
+  CalculateMatMulIntegerToFloat<int8_t, int8_t, MLFloat16>(M, N, K, A_data, A_scale, A_zero_point,
+                                                           B_data, B_scale, B_zero_point, Bias, Y_data,
+                                                           false, true, true);
+
+  test.AddOutput<MLFloat16>("Y", {M, N}, ToFloat16(Y_data));
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultDmlExecutionProvider());
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+TEST(MatMulIntegerToFloat, MatMulIntegerToFloat_FP16_S8U8) {
+  OpTester test("MatMulIntegerToFloat", 1, kMSDomain);
+  int64_t M = 5;
+  int64_t N = 5;
+  int64_t K = 2;
+
+  std::vector<int8_t> A_data = {3, 7, -2, 1, 1,
+                                2, -1, -9, 1, 1};
+  std::vector<uint8_t> B_data = {3, 7, 2, 1, 1,
+                                 2, 1, 9, 1, 1};
+  std::vector<MLFloat16> A_scale = ToFloat16({-4.0f});
+  std::vector<MLFloat16> B_scale = ToFloat16({2.0f});
+  test.AddInput<int8_t>("A", {M, K}, A_data);
+  test.AddInput<uint8_t>("B", {K, N}, B_data);
+  std::vector<int8_t> A_zero_point = {-1};
+  std::vector<uint8_t> B_zero_point = {1};
+  std::vector<MLFloat16> Bias = ToFloat16({11.0f, -17.0f, 1.0f, -3.0f, 12.0f});
+
+  test.AddInput<MLFloat16>("a_scale", {1}, A_scale);
+  test.AddInput<MLFloat16>("b_scale", {1}, B_scale);
+  test.AddInput<int8_t>("a_zero_point", {1}, A_zero_point);
+  test.AddInput<uint8_t>("b_zero_point", {1}, B_zero_point);
+  test.AddInput<MLFloat16>("bias", {N}, Bias);
+
+  std::vector<float> Y_data(M * N);
+  CalculateMatMulIntegerToFloat<int8_t, uint8_t, MLFloat16>(M, N, K, A_data, A_scale, A_zero_point,
+                                                            B_data, B_scale, B_zero_point, Bias, Y_data,
+                                                            false, true, true);
+
+  test.AddOutput<MLFloat16>("Y", {M, N}, ToFloat16(Y_data));
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultDmlExecutionProvider());
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+TEST(MatMulIntegerToFloat, MatMulIntegerToFloat_FP16) {
+  OpTester test("MatMulIntegerToFloat", 1, kMSDomain);
+  int64_t M = 2;
+  int64_t N = 2;
+  int64_t K = 3;
+
+  std::vector<int8_t> A_data = {11, -2, 5,
+                                -1, 3, 10};
+  std::vector<int8_t> B_data = {-13, -2,
+                                9, 55,
+                                -1, 23};
+  std::vector<MLFloat16> A_scale = ToFloat16({0.910f});
+  std::vector<MLFloat16> B_scale = ToFloat16({1.10f, 1.123f});
+
+  std::vector<int8_t> A_zero_point = {113};
+  std::vector<int8_t> B_zero_point = {98, 71};
+
+  std::vector<MLFloat16> Bias = ToFloat16({0.10f, 1.123f});
+
+  test.AddInput<int8_t>("A", {M, K}, A_data);
+  test.AddInput<int8_t>("B", {K, N}, B_data);
+
+  test.AddInput<MLFloat16>("a_scale", {}, {A_scale});
+  test.AddInput<MLFloat16>("b_scale", {N}, B_scale);
+  test.AddInput<int8_t>("a_zero_point", {}, {A_zero_point});
+  test.AddInput<int8_t>("b_zero_point", {N}, B_zero_point);
+  test.AddInput<MLFloat16>("bias", {N}, Bias);
+
+  std::vector<float> Y_data(M * N);
+  CalculateMatMulIntegerToFloat<int8_t, int8_t, MLFloat16>(M, N, K, A_data, A_scale, A_zero_point,
+                                                           B_data, B_scale, B_zero_point, Bias, Y_data,
+                                                           true, true, true);
+
+  test.AddOutput<MLFloat16>("Y", {M, N}, ToFloat16(Y_data));
+  test.SetOutputRelErr("Y", 2e-2f);
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultDmlExecutionProvider());
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
 }
+#endif
 
 TEST(MatMulIntegerToFloat, MatMulInteger_With_ZeroPoint) {
   auto test_case = [&](const std::vector<int64_t>& input_shape,
diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc
index ebb0261deefa5..263ace25ddfe0 100644
--- a/onnxruntime/test/contrib_ops/moe_test.cc
+++ b/onnxruntime/test/contrib_ops/moe_test.cc
@@ -14,6 +14,7 @@ static void RunMoETest(
     const std::vector<float>& router_probs,
     const std::vector<float>& fc1_experts_weights,
     const std::vector<float>& fc2_experts_weights,
+    const std::vector<float>& fc3_experts_weights,
     const std::vector<float>& fc1_experts_bias,
     const std::vector<float>& fc2_experts_bias,
     const std::vector<float>& output_data,
@@ -22,19 +23,23 @@ static void RunMoETest(
     int hidden_size,
     int inter_size,
     std::string activation_type,
+    int normalize_routing_weights = 0,
+    int top_k = 1,
     bool use_float16 = false) {
   int min_cuda_architecture = use_float16 ? 530 : 0;
 
   bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
   if (enable_cuda) {
     OpTester tester("MoE", 1, onnxruntime::kMSDomain);
-    tester.AddAttribute<int64_t>("k", static_cast<int64_t>(1));
+    tester.AddAttribute<int64_t>("k", static_cast<int64_t>(top_k));
     tester.AddAttribute<std::string>("activation_type", activation_type);
+    tester.AddAttribute<int64_t>("normalize_routing_weights", static_cast<int64_t>(normalize_routing_weights));
 
     std::vector<int64_t> input_dims = {num_rows, hidden_size};
     std::vector<int64_t> router_probs_dims = {num_rows, num_experts};
     std::vector<int64_t> fc1_experts_weights_dims = {num_experts, hidden_size, inter_size};
     std::vector<int64_t> fc2_experts_weights_dims = {num_experts, inter_size, hidden_size};
+    std::vector<int64_t> fc3_experts_weights_dims = fc1_experts_weights_dims;
     std::vector<int64_t> fc1_experts_bias_dims = {num_experts, inter_size};
     std::vector<int64_t> fc2_experts_bias_dims = {num_experts, hidden_size};
     std::vector<int64_t> output_dims = {num_rows, hidden_size};
@@ -43,18 +48,42 @@ static void RunMoETest(
       tester.AddInput<MLFloat16>("input", input_dims, ToFloat16(input));
       tester.AddInput<MLFloat16>("router_probs", router_probs_dims, ToFloat16(router_probs));
       tester.AddInput<MLFloat16>("fc1_experts_weights", fc1_experts_weights_dims, ToFloat16(fc1_experts_weights));
+      if (!fc1_experts_bias.empty()) {
+        tester.AddInput<MLFloat16>("fc1_experts_bias", fc1_experts_bias_dims, ToFloat16(fc1_experts_bias));
+      } else {
+        tester.AddOptionalInputEdge<MLFloat16>();
+      }
       tester.AddInput<MLFloat16>("fc2_experts_weights", fc2_experts_weights_dims, ToFloat16(fc2_experts_weights));
-      tester.AddInput<MLFloat16>("fc1_experts_bias", fc1_experts_bias_dims, ToFloat16(fc1_experts_bias));
-      tester.AddInput<MLFloat16>("fc2_experts_bias", fc2_experts_bias_dims, ToFloat16(fc2_experts_bias));
+      if (!fc2_experts_bias.empty()) {
+        tester.AddInput<MLFloat16>("fc2_experts_bias", fc2_experts_bias_dims, ToFloat16(fc2_experts_bias));
+      } else {
+        tester.AddOptionalInputEdge<MLFloat16>();
+      }
+      if (!fc3_experts_weights.empty()) {
+        tester.AddInput<MLFloat16>("fc3_experts_weights", fc3_experts_weights_dims, ToFloat16(fc3_experts_weights));
+      }
       tester.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data));
+      tester.SetOutputTolerance(0.005f);
     } else {
       tester.AddInput<float>("input", input_dims, input);
       tester.AddInput<float>("router_probs", router_probs_dims, router_probs);
       tester.AddInput<float>("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights);
+      if (!fc1_experts_bias.empty()) {
+        tester.AddInput<float>("fc1_experts_bias", fc1_experts_bias_dims, fc1_experts_bias);
+      } else {
+        tester.AddOptionalInputEdge<float>();
+      }
       tester.AddInput<float>("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights);
-      tester.AddInput<float>("fc1_experts_bias", fc1_experts_bias_dims, fc1_experts_bias);
-      tester.AddInput<float>("fc2_experts_bias", fc2_experts_bias_dims, fc2_experts_bias);
+      if (!fc2_experts_bias.empty()) {
+        tester.AddInput<float>("fc2_experts_bias", fc2_experts_bias_dims, fc2_experts_bias);
+      } else {
+        tester.AddOptionalInputEdge<float>();
+      }
+      if (!fc3_experts_weights.empty()) {
+        tester.AddInput<float>("fc3_experts_weights", fc3_experts_weights_dims, fc3_experts_weights);
+      }
       tester.AddOutput<float>("output", output_dims, output_data);
+      tester.SetOutputTolerance(0.001f);
     }
 
     std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
@@ -231,6 +260,7 @@ TEST(MoETest, MoETest_Gelu) {
              router_probs,
              fc1_experts_weights,
              fc2_experts_weights,
+             {},
              fc1_experts_bias,
              fc2_experts_bias,
              output,
@@ -409,6 +439,7 @@ TEST(MoETest, MoETest_Relu) {
              router_probs,
              fc1_experts_weights,
              fc2_experts_weights,
+             {},
              fc1_experts_bias,
              fc2_experts_bias,
              output,
@@ -419,5 +450,143 @@ TEST(MoETest, MoETest_Relu) {
              "relu");
 }
 
+TEST(MoETest, MoETest_Mixtral) {
+  int num_rows = 6;
+  int num_experts = 8;
+  int hidden_size = 4;
+  int inter_size = 8;
+
+  const std::vector<float> input = {
+      0.9212995f, 0.5282444f, -0.008228387f, -1.449332f, -0.6051824f, -0.17924511f, 0.1995587f, -1.2461947f,
+      0.86708033f, 0.19191018f, 1.1600108f, -0.008815222f, 0.8504777f, -0.84964496f, -1.4019964f, 0.17225051f,
+      0.35569248f, 1.2056456f, 1.3690308f, -0.69495815f, 1.4324434f, 0.22761835f, -1.1286871f, 1.124213f};
+  const std::vector<float> router_probs = {
+      -0.09331456f, -0.47121337f, 0.07311103f, 0.47643483f, 0.21135253f, -0.72226393f, -0.048502743f, 0.39447474f,
+      -0.9014899f, -0.36629856f, -0.23088816f, -0.099606544f, -0.45191774f, -0.30394578f, 0.6266495f, 0.67937183f,
+      0.27117345f, -0.36059442f, 0.81510246f, 0.61359257f, 0.07649982f, -0.44949868f, -0.54758865f, 0.4736983f,
+      0.21584567f, 0.21296778f, 0.093342215f, -0.09353682f, 0.61422515f, 0.19574627f, 0.0063361377f, -0.2465148f,
+      0.15675665f, -0.4546509f, 0.24447554f, 0.5921611f, -0.18192923f, -0.66116416f, -0.40265432f, 0.33475468f,
+      1.2906091f, 0.4709078f, 0.16256471f, 0.19308007f, 0.97568524f, 0.25876164f, -0.7964541f, -1.0319631f};
+  const std::vector<float> fc1_experts_weights = {
+      0.3860137f, 0.077925384f, 0.13434184f, 0.28902978f, 0.25391752f, -0.38351142f, 0.15813059f, 0.031481862f,
+      0.083209574f, 0.4039817f, -0.13558972f, -0.21858627f, -0.30475253f, 0.41026944f, -0.008697987f, -0.3412701f,
+      -0.16235226f, 0.054659843f, 0.21042877f, 0.28863233f, -0.49495423f, 0.14401567f, 0.39130414f, 0.154176f,
+      0.30897498f, -0.15768659f, 0.44641107f, 0.089463115f, -0.19318026f, 0.20710677f, -0.3552568f, -0.17219114f,
+      0.41923493f, -0.4233985f, -0.41503525f, 0.19466156f, -0.08633667f, 0.45547962f, -0.054792404f, 0.26722562f,
+      -0.09923202f, 0.3460176f, -0.49708033f, -0.41033173f, 0.10443485f, -0.39646107f, -0.37424505f, 0.1757198f,
+      0.43019837f, -0.13757241f, 0.14305532f, 0.37121457f, 0.2581259f, 0.12583363f, 0.45542932f, 0.16247797f,
+      0.15579104f, -0.19166303f, -0.109221935f, -0.36702687f, 0.40365517f, -0.21506298f, -0.36697525f, -0.2703231f,
+      -0.49740213f, -0.3486371f, 0.24005288f, -0.0048963428f, 0.20468098f, -0.09111178f, -0.1485982f, -0.088219464f,
+      0.33463532f, -0.49346995f, 0.42075223f, -0.38025302f, -0.245484f, -0.35191745f, 0.3086716f, -0.2423737f,
+      0.37881732f, -0.40608948f, 0.26193494f, -0.4283861f, -0.10062629f, -0.32670784f, -0.16040438f, -0.15297079f,
+      0.1822241f, 0.37285012f, 0.12654608f, -0.46767431f, -0.28775263f, 0.16585541f, -0.36678362f, -0.4759978f,
+      -0.34751755f, -0.3163945f, -0.3858195f, -0.38030273f, -0.06156373f, -0.04352224f, -0.4041785f, -0.335764f,
+      -0.10303855f, -0.4009425f, -0.1236487f, -0.40111196f, 0.23985302f, -0.118291676f, -0.26773083f, 0.121197104f,
+      0.3702919f, -0.34168184f, 0.33743858f, 0.24873763f, -0.23140603f, -0.25351608f, 0.48291886f, 0.13780516f,
+      0.25632292f, -0.49343884f, 0.08369112f, -0.37192065f, -0.05451995f, -0.44571918f, -0.24150735f, 0.27395487f,
+      -0.20423341f, -0.024149835f, 0.40208143f, -0.18211937f, -0.19767642f, -0.19397742f, -0.1510992f, 0.48074025f,
+      0.18377024f, -0.18288034f, 0.08111167f, 0.12729281f, 0.27861303f, 0.0076527f, 0.36356348f, -0.24359548f,
+      -0.33313757f, -0.374829f, -0.08705664f, 0.23576546f, -0.39819986f, -0.09880793f, -0.012998581f, -0.36475456f,
+      -0.32685202f, 0.29657948f, -0.4631365f, -0.06320876f, 0.31600899f, 0.060619473f, 0.39029974f, 0.401151f,
+      0.15562236f, 0.43565983f, -0.058149397f, 0.36150748f, 0.10750586f, -0.063970566f, -0.47026545f, -0.3035437f,
+      -0.38143605f, -0.4734699f, 0.31273925f, -0.43410504f, 0.07299572f, 0.47506f, 0.021913886f, -0.036100805f,
+      -0.31637233f, 0.37718338f, -0.046213806f, 0.19239199f, 0.13676548f, 0.33592474f, -0.34048676f, -0.11097133f,
+      -0.41569126f, -0.01680845f, 0.31357706f, 0.0943895f, -0.24053341f, -0.018784225f, 0.40659577f, 0.08897692f,
+      0.3793823f, -0.3271106f, 0.067666054f, -0.12331611f, -0.010209799f, -0.48908865f, 0.19195485f, -0.45211792f,
+      0.48282713f, 0.4363466f, -0.40184838f, -0.025082052f, -0.31057972f, 0.14850605f, 0.39756012f, -0.25782883f,
+      0.3181312f, 0.17685872f, -0.16694272f, -0.41516554f, -0.062004805f, -0.33060408f, -0.13665432f, -0.43781847f,
+      -0.298562f, 0.013283849f, 0.48130906f, -0.27970356f, 0.20347959f, -0.24402553f, -0.20528454f, -0.114435256f,
+      0.12556863f, -0.4344011f, 0.2868948f, 0.19894183f, -0.12849897f, -0.18726158f, -0.4850099f, -0.4352169f,
+      -0.40527463f, 0.13625044f, -0.49707252f, -0.45698053f, 0.28196156f, 0.16826987f, -0.25944453f, 0.2801003f,
+      0.21121234f, -0.04066527f, 0.45854944f, -0.17861038f, 0.18178529f, 0.17789757f, 0.34227383f, 0.26976448f,
+      0.15789884f, 0.22840887f, 0.419321f, -0.14490443f, 0.39608955f, -0.4162954f, -0.47072983f, 0.41119635f};
+  const std::vector<float> fc2_experts_weights = {
+      0.10833451f, 0.34020698f, -0.18258394f, -0.17842063f, -0.07365984f, -0.29177922f, -0.24102151f, 0.1077901f,
+      0.2932343f, -0.35068116f, 0.1875877f, 0.07474385f, -0.20955177f, -0.27660736f, -0.14290786f, -0.09014153f,
+      -0.21085852f, -0.2378315f, 0.21457997f, 0.21074237f, -0.21087126f, 0.14320332f, -0.08389844f, 0.24034885f,
+      0.31800103f, 0.12659892f, 0.20224877f, -0.2563875f, 0.11782206f, 0.29377612f, -0.27469966f, -0.18875091f,
+      0.32136288f, 0.0788243f, -0.26413083f, 0.18453442f, 0.0776935f, -0.19561274f, 0.12608862f, 0.18579696f,
+      0.045481127f, -0.17894714f, 0.27366453f, 0.13220324f, -0.3115706f, -0.016884197f, -0.3328494f, -0.062126897f,
+      0.14841764f, 0.19741052f, 0.08211302f, -0.09362138f, -0.053040292f, -0.090344846f, 0.18264277f, 0.037823465f,
+      -0.16197139f, -0.20172869f, 0.064109616f, -0.062456656f, 0.30368346f, -0.12107184f, -0.12590908f, -0.10535928f,
+      0.1978099f, 0.13119277f, 0.21948591f, -0.080250844f, -0.24614547f, 0.33202717f, 0.2645375f, -0.21193951f,
+      0.17770219f, -0.04986229f, 0.33435768f, -0.0309231f, 0.16043694f, -0.0027341924f, -0.08339601f, -0.17402375f,
+      0.2525901f, -0.0813988f, -0.2904943f, -0.14452116f, -0.27119386f, -0.2952116f, 0.0794895f, -0.11223866f,
+      0.25427446f, 0.16967128f, 0.19531254f, -0.33598322f, -0.16714293f, -0.35097876f, -0.35189477f, 0.2900932f,
+      0.26874313f, -0.1322388f, -0.330179f, 0.064027935f, 0.19688474f, -0.20129368f, 0.006225848f, 0.19252343f,
+      -0.35054854f, -0.31874785f, 0.32238203f, 0.29287276f, 0.03135616f, 0.015792634f, 0.20397249f, -0.3245995f,
+      0.21416605f, 0.15667121f, -0.2058509f, 0.23639117f, -0.032677338f, 0.07826358f, -0.04589425f, -0.24935842f,
+      -0.20834164f, 0.069915086f, -0.26063374f, 0.13239416f, 0.33705652f, -0.26813045f, -0.17056243f, 0.29919288f,
+      0.27704936f, -0.096224755f, 0.13250813f, 0.26709175f, -0.26995474f, 0.3261805f, -0.18062393f, -0.04732303f,
+      -0.02733084f, 0.050550338f, -0.2937818f, -0.19453493f, -0.34864828f, -0.20862648f, -0.19311349f, 0.17665526f,
+      -0.2894185f, -0.020016002f, 0.3409702f, -0.18320526f, 0.068286195f, 0.08490415f, 0.30223787f, -0.2386011f,
+      0.09405743f, 0.123811804f, 0.31660154f, -0.11290163f, 0.07494662f, -0.24999082f, 0.2075398f, 0.07419645f,
+      0.3327035f, -0.09647329f, 0.24138254f, -0.32546985f, 0.033594366f, 0.16555631f, 0.33516192f, -0.32619375f,
+      0.20476541f, -0.07724f, 0.018923176f, -0.21126744f, 0.2744358f, -0.23979841f, -0.30413106f, -0.3485449f,
+      0.2854276f, 0.14391156f, -0.24802732f, -0.21701548f, -0.122100174f, 0.054206114f, -0.21961808f, 0.13481297f,
+      -0.07907457f, 0.15763119f, -0.31156835f, 0.29488218f, 0.17039073f, 0.35125035f, -0.17721775f, -0.10516899f,
+      0.072144486f, -0.038529005f, -0.058253434f, 0.13062657f, -0.3312356f, -0.15963489f, -0.20129326f, 0.014987925f,
+      0.30869225f, 0.283981f, -0.057181682f, 0.15174268f, 0.22181617f, -0.19763571f, 0.28675067f, 0.0003976555f,
+      -0.34610963f, 0.2931936f, -0.26233214f, 0.19563977f, -0.16886877f, 0.022812065f, 0.080249704f, -0.2798801f,
+      0.11531327f, 0.07107194f, -0.34746924f, -0.051920194f, -0.07264093f, 0.27581826f, 0.18536879f, 0.15684144f,
+      -0.26691115f, -0.22811417f, -0.1498502f, -0.176639f, -0.25876564f, -0.16051741f, -0.0048792143f, -0.08490091f,
+      0.18136817f, 0.24729891f, 0.32358363f, -0.09566104f, 0.3074607f, -0.24191524f, -0.21220984f, -0.23039621f,
+      0.21154472f, -0.19495378f, 0.002779711f, -0.34692943f, 0.055384878f, 0.25809082f, 0.16814983f, 0.19935164f,
+      0.11652225f, 0.1115539f, -0.24407779f, 0.09392998f, 0.33556697f, 0.11422251f, 0.34336287f, -0.33113837f};
+  const std::vector<float> fc3_experts_weights = {
+      0.45783097f, -0.2863351f, 0.011728346f, -0.43760604f, 0.15407985f, 0.07818556f, 0.0013856292f, -0.34319758f,
+      -0.16871625f, 0.12490183f, -0.34154075f, -0.31836903f, -0.46634215f, -0.43996066f, -0.1860516f, -0.2917009f,
+      -0.1772582f, -0.06599659f, -0.42419833f, 0.49980444f, -0.3283869f, -0.21543652f, -0.034647882f, -0.17114872f,
+      -0.4837973f, -0.362943f, -0.27533132f, 0.09443748f, -0.16642791f, -0.2993343f, -0.33881485f, -0.39464045f,
+      0.31960344f, 0.007296145f, -0.45412838f, -0.024868786f, -0.16298121f, -0.44197202f, 0.07232875f, -0.32362783f,
+      0.42969978f, -0.029854119f, -0.18451887f, -0.30145288f, 0.16885209f, -0.30068123f, -0.12948537f, 0.36494362f,
+      -0.049498677f, 0.12020564f, 0.42106473f, -0.30590254f, 0.31881082f, -0.078908324f, 0.20685762f, -0.22735089f,
+      -0.11194843f, 0.14011681f, 0.19477749f, -0.44788343f, 0.23084867f, 0.48367476f, -0.19044077f, -0.100233376f,
+      0.4191656f, -0.4515314f, -0.3214385f, 0.016065598f, -0.4069137f, -0.17348295f, -0.43329984f, 0.33521235f,
+      -0.07843453f, -0.4865722f, -0.039011598f, -0.10605621f, 0.4192536f, 0.04063064f, 0.1984514f, 0.49294376f,
+      -0.056941032f, 0.18582922f, -0.16650558f, -0.17215621f, -0.20009357f, 0.46615022f, 0.47462142f, -0.0766145f,
+      -0.20405996f, -0.27452308f, -0.16176039f, -0.23940295f, 0.13248974f, 0.23036134f, 0.13154167f, 0.10377723f,
+      0.0070211887f, 0.29162645f, 0.34465307f, -0.4058748f, -0.13989884f, -0.12305027f, -0.2541607f, 0.4767149f,
+      0.4549045f, -0.108933926f, 0.2452516f, 0.054080307f, 0.33768386f, -0.45279485f, 0.1557768f, 0.17416143f,
+      -0.42602575f, -0.102350116f, 0.16022503f, 0.14813942f, 0.03982985f, -0.47012872f, -0.14555538f, 0.35645115f,
+      -0.1909796f, -0.20839584f, -0.28098184f, -0.23085594f, 0.022559166f, -0.23900753f, -0.19561106f, -0.24205637f,
+      0.2573983f, -0.2947166f, 0.4568925f, 0.11514187f, 0.18671238f, -0.121082425f, 0.3909887f, -0.10985571f,
+      -0.19420451f, -0.3255307f, 0.4863913f, 0.007830441f, 0.4648854f, -0.24156213f, 0.22956276f, -0.09216207f,
+      -0.29428315f, 0.26062596f, 0.14955276f, -0.036366224f, -0.12957954f, 0.08501935f, -0.36796576f, 0.041123867f,
+      0.06744653f, -0.0839923f, 0.17207885f, 0.006872058f, -0.21135789f, 0.3732242f, -0.2683524f, -0.45898575f,
+      -0.14543939f, 0.30806476f, 0.08574325f, 0.027492225f, -0.38164973f, -0.040038824f, -0.26947904f, -0.09740937f,
+      0.26697665f, -0.43565083f, 0.1359719f, 0.12271714f, 0.0149876475f, -0.44011843f, 0.26128954f, -0.42487514f,
+      -0.24668545f, 0.06113738f, -0.29119557f, 0.194273f, -0.24981815f, 0.3489496f, -0.47321397f, -0.31794417f,
+      -0.23641628f, 0.44169098f, -0.006898284f, 0.43446392f, -0.39553195f, 0.057907403f, -0.19339961f, -0.08160931f,
+      0.4979084f, -0.11149913f, 0.35366338f, -0.16032219f, -0.48278677f, 0.08397317f, 0.4008311f, 0.30288273f,
+      0.2546957f, -0.10675722f, 0.069722414f, 0.456497f, -0.19691509f, 0.49017924f, 0.41796166f, -0.2337895f,
+      -0.3635872f, -0.45445484f, -0.29122698f, -0.4339773f, 0.15762383f, 0.09782606f, -0.27986187f, -0.23860168f,
+      0.38454843f, -0.07870716f, 0.15390605f, -0.15793777f, 0.48130733f, 0.288768f, 0.45969498f, -0.4193731f,
+      -0.3218134f, -0.29914904f, -0.3426242f, 0.06931591f, -0.2633695f, -0.25429398f, 0.25366426f, -0.27700734f,
+      0.49418402f, -0.21919805f, 0.041192472f, -0.19817531f, -0.49578953f, 0.48185098f, -0.41920406f, -0.08335745f,
+      0.19111753f, -0.07547706f, 0.049694f, 0.13012594f, 0.2617172f, -0.22612399f, 0.32247066f, -0.33702326f,
+      0.20062232f, -0.09143996f, -0.063310504f, 0.1885702f, 0.11926836f, 0.3378734f, -0.45973647f, 0.48845494f};
+  const std::vector<float> output = {
+      0.026516449f, 0.04061616f, 0.04403834f, -0.13644142f, 0.038774252f, 0.024002096f, -0.061423667f, 0.034824893f,
+      -0.022858473f, 0.04693405f, -0.0120724365f, -0.028846134f, -0.0168579f, -0.07958221f, 0.048179876f, 0.053492386f,
+      -0.026292695f, -0.009724421f, -0.026503641f, 0.031220898f, 0.04189077f, 0.11775493f, -0.037770163f, -0.0790936f};
+
+  RunMoETest(input,
+             router_probs,
+             fc1_experts_weights,
+             fc2_experts_weights,
+             fc3_experts_weights,
+             {},
+             {},
+             output,
+             num_rows,
+             num_experts,
+             hidden_size,
+             inter_size,
+             "silu",
+             1, /*normalize_routing_weights*/
+             2 /*top_k*/);
+}
+
 }  // namespace test
 }  // namespace onnxruntime
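The conditional wiring added above — provide the tensor when data exists, otherwise register an optional input edge so later inputs keep their positions — is the usual OpTester pattern for optional operator inputs. Below is a minimal illustrative sketch of that pattern; the helper name is hypothetical and it assumes the OpTester API from test/providers/provider_test_utils.h.

#include <cstdint>
#include <vector>
#include "test/providers/provider_test_utils.h"  // OpTester, assumed available in this test tree

// Hypothetical helper, shown only to illustrate the pattern used in RunMoETest above: when the
// caller passes an empty vector, the optional input slot is still consumed so that subsequent
// inputs such as fc2_experts_weights keep their expected positions.
template <typename T>
static void AddInputOrOptionalEdge(onnxruntime::test::OpTester& tester, const char* name,
                                   const std::vector<int64_t>& dims, const std::vector<T>& data) {
  if (!data.empty()) {
    tester.AddInput<T>(name, dims, data);
  } else {
    tester.AddOptionalInputEdge<T>();
  }
}

Trailing optional inputs such as fc3_experts_weights can simply be omitted, which is why the hunk adds that input only when its vector is non-empty and needs no else branch.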
diff --git a/onnxruntime/test/contrib_ops/packed_attention_op_test.cc b/onnxruntime/test/contrib_ops/packed_attention_op_test.cc
index 31ef62e69bb88..09baf8def05f6 100644
--- a/onnxruntime/test/contrib_ops/packed_attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/packed_attention_op_test.cc
@@ -433,8 +433,7 @@ static void RunModelWithRandomInput(
   std::vector<int64_t> token_offset_dims{batch_size, sequence_length};
   std::vector<int64_t> cum_seq_len_dims{batch_size + 1};
 
-  // TF32 in SM >= 80 is enabled by default, need larger threshold for float when TF32 is enabled.
-  float gpu_threshold = is_float16 ? 0.15f : (HasCudaEnvironment(800) ? 0.05f : 0.005f);
+  float gpu_threshold = is_float16 ? 0.15f : 0.005f;
   gpu_threshold *= sequence_length > 1024 ? 4.0f : 1.0f;  // threshold should increase with sequence length
   bool enable_cuda = HasCudaEnvironment(is_float16 ? 530 : 0);
   if (enable_cuda) {
diff --git a/onnxruntime/test/contrib_ops/packed_multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/packed_multihead_attention_op_test.cc
index 22253955566f2..5f811c8cf35f6 100644
--- a/onnxruntime/test/contrib_ops/packed_multihead_attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/packed_multihead_attention_op_test.cc
@@ -107,6 +107,7 @@ static void RunPackedMultiHeadAttentionTest(
       }
 
       tester.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data));
+      tester.SetOutputTolerance(0.005f);
     } else {
       if (is_packed_qkv) {
         tester.AddInput<float>("query", packed_qkv_dims, query_data);
@@ -131,6 +132,7 @@ static void RunPackedMultiHeadAttentionTest(
       }
 
       tester.AddOutput<float>("output", output_dims, output_data);
+      tester.SetOutputTolerance(0.001f, 0.001f);
     }
 
     std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
diff --git a/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc b/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc
index 3af334696a97d..54dd831fe2fc2 100644
--- a/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc
@@ -20,7 +20,8 @@ namespace test {
 enum class EP : char {
   CPU,
   CUDA,
-  DNNL
+  DNNL,
+  DML
 };
 
 // input:      [batch_size, sequence_length, hidden_size]
@@ -89,11 +90,13 @@ void RunQAttention(const std::vector<float>& input_data,
     tester.AddInput<MLFloat16>("input_scale", {1}, ToFloat16({input_quant_params.scale}));
     tester.AddInput<MLFloat16>("weight_scale", {1}, ToFloat16({weight_quant_params.scale}));
     tester.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data));
+    tester.SetOutputTolerance(0.01f);
   } else {
     tester.AddInput<float>("bias", bias_dims, bias_data);
     tester.AddInput<float>("input_scale", {1}, {input_quant_params.scale});
     tester.AddInput<float>("weight_scale", {1}, {weight_quant_params.scale});
     tester.AddOutput<float>("output", output_dims, output_data);
+    tester.SetOutputTolerance(0.005f);
   }
 
   if (mask_index_data.size() > 0) {
@@ -111,7 +114,9 @@ void RunQAttention(const std::vector<float>& input_data,
     execution_providers.push_back(DefaultCudaExecutionProvider());
   } else if constexpr (ep == EP::CPU) {
     execution_providers.push_back(DefaultCpuExecutionProvider());
-  } else {  // onednn ep
+  } else if constexpr (ep == EP::DML) {
+    execution_providers.push_back(DefaultDmlExecutionProvider());
+  } else {  // onednn ep
     execution_providers.push_back(DefaultDnnlExecutionProvider());
   }
 
@@ -192,6 +197,52 @@ static void RunQAttentionDNNL(
 #endif
 }
 
+static void RunQAttentionDML(
+    const std::vector<float>& input_data,
+    const std::vector<float>& weights_data,
+    const std::vector<float>& bias_data,
+    const std::vector<int32_t>& mask_index_data,
+    const std::vector<float>& output_data,
+    int batch_size,
+    int sequence_length,
+    int hidden_size,
+    int number_of_heads,
+    bool use_special_quantize_parameter = true,
+    bool is_unidirectional = false,
+    int input_hidden_size = 0) {
+  // Return without running code if USE_DML is not defined
+#ifdef USE_DML
+  bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get());
+  if (enable_dml) {
+    quantization::Params<uint8_t> input_quant_params(/*scale=*/0.0f, /*zero_point=*/0);
+    quantization::Params<int8_t> weights_quant_params(/*scale=*/0.0f, /*zero_point=*/0);
+    if (use_special_quantize_parameter) {
+      input_quant_params.scale = 0.1f;
+      weights_quant_params.scale = 0.1f;
+      input_quant_params.zero_point = 128;
+      weights_quant_params.zero_point = 1;
+    }
+
+    RunQAttention<uint8_t, int8_t, EP::DML>(
+        input_data, weights_data, bias_data, mask_index_data, output_data, input_quant_params, weights_quant_params,
+        batch_size, sequence_length, hidden_size, number_of_heads, is_unidirectional, false, input_hidden_size);
+  }
+#else
+  ORT_UNUSED_PARAMETER(input_data);
+  ORT_UNUSED_PARAMETER(weights_data);
+  ORT_UNUSED_PARAMETER(bias_data);
+  ORT_UNUSED_PARAMETER(mask_index_data);
+  ORT_UNUSED_PARAMETER(output_data);
+  ORT_UNUSED_PARAMETER(batch_size);
+  ORT_UNUSED_PARAMETER(sequence_length);
+  ORT_UNUSED_PARAMETER(hidden_size);
+  ORT_UNUSED_PARAMETER(number_of_heads);
+  ORT_UNUSED_PARAMETER(use_special_quantize_parameter);
+  ORT_UNUSED_PARAMETER(is_unidirectional);
+  ORT_UNUSED_PARAMETER(input_hidden_size);
+#endif
+}
+
 static void RunQAttentionU8U8(
     const std::vector<float>& input_data,
     const std::vector<float>& weights_data,
@@ -272,6 +323,9 @@ static void RunQAttentionAll(
   RunQAttentionDNNL(input_data, weight_data, bias_data, mask_index_data, output_data,
                     batch_size, sequence_length, hidden_size, number_of_heads,
                     use_special_quantize_parameter, is_unidirectional, input_hidden_size);
+  RunQAttentionDML(input_data, weight_data, bias_data, mask_index_data, output_data,
+                   batch_size, sequence_length, hidden_size, number_of_heads,
+                   use_special_quantize_parameter, is_unidirectional, input_hidden_size);
 }
 
 // ONEDNN EP only supports 2D raw mask
@@ -859,8 +913,8 @@ void TestQuantizedAttentionPastState(int64_t batch,
   std::vector<int64_t> input_dims{batch, seq_len, hidden_size};
   std::vector<InputT> input_data = random.Gaussian<InputT>(input_dims, input_mean, static_cast<InputT>(input_range / 6), input_min, input_max);
 
-  constexpr WeightT weight_min = std::numeric_limits<WeightT>::min();
-  constexpr WeightT weight_max = std::numeric_limits<WeightT>::max();
+  constexpr WeightT weight_min = std::is_same_v<WeightT, int8_t> ? std::numeric_limits<int8_t>::min() / 2 : std::numeric_limits<WeightT>::min();
+  constexpr WeightT weight_max = std::numeric_limits<WeightT>::max() / 2;
   constexpr int32_t weight_range = weight_max - weight_min;
 
   std::vector<WeightT> weight_zero_point(weight_scale_zp_size);
diff --git a/onnxruntime/test/contrib_ops/sampling_test.cc b/onnxruntime/test/contrib_ops/sampling_test.cc
index 733bc9f01fd11..d987a1cae427d 100644
--- a/onnxruntime/test/contrib_ops/sampling_test.cc
+++ b/onnxruntime/test/contrib_ops/sampling_test.cc
@@ -8,6 +8,10 @@
 #include "core/session/onnxruntime_cxx_api.h"
 #include "test/common/cuda_op_test_utils.h"
 
+#ifdef USE_CUDA
+#include "core/providers/cuda/cuda_provider_options.h"
+#endif
+
 extern std::unique_ptr<Ort::Env> ort_env;
 
 namespace onnxruntime {
@@ -65,7 +69,10 @@ TEST(SamplingTest, Gpt2Sampling_GPU) {
     LOGS_DEFAULT(WARNING) << "Hardware NOT support current architecture";
     return;
   }
-  Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
+
+  OrtCUDAProviderOptionsV2 cuda_options;
+  cuda_options.use_tf32 = false;
+  session_options.AppendExecutionProvider_CUDA_V2(cuda_options);
 #else  // USE_ROCM
   OrtROCMProviderOptions rocm_options;
   // TODO - verify the default settings
diff --git a/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc b/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc
index fefd5722054de..ea8537f243f5d 100644
--- a/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc
+++ b/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc
@@ -114,16 +114,21 @@ TEST(SkipGroupNormTest, SkipGroupNorm_with_bias) {
 
   int min_cuda_architecture = 530;
   bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
+  bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get());
 
   std::array<int, 2> channels_last_values = {-1, 1};
 
   for (const int channels_last : channels_last_values) {
-    if (enable_cuda) {
+    if (enable_cuda || enable_rocm) {
       std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
       if (enable_cuda && channels_last != 0) {
         execution_providers.push_back(DefaultCudaExecutionProvider());
       }
 
+      if (enable_rocm && channels_last != 0) {
+        execution_providers.push_back(DefaultRocmExecutionProvider());
+      }
+
       // Don't run the test if no providers are supported
       if (execution_providers.empty()) {
         continue;
@@ -230,6 +235,7 @@ TEST(SkipGroupNormTest, SkipGroupNorm_no_bias_broadcast_skip) {
 
   int min_cuda_architecture = 530;
   bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
+  bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get());
 
   std::array<bool, 2> has_add_out_values = {true, false};
   std::array<int, 2> skip_dims = {2, 4};
@@ -237,12 +243,16 @@ TEST(SkipGroupNormTest, SkipGroupNorm_no_bias_broadcast_skip) {
   constexpr int channels_last = 1;
   for (const int skip_dim : skip_dims) {
     for (const bool has_add_out : has_add_out_values) {
-      if (enable_cuda) {
+      if (enable_cuda || enable_rocm) {
         std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
         if (enable_cuda && channels_last != 0) {
           execution_providers.push_back(DefaultCudaExecutionProvider());
         }
 
+        if (enable_rocm && channels_last != 0) {
+          execution_providers.push_back(DefaultRocmExecutionProvider());
+        }
+
         // Don't run the test if no providers are supported
         if (execution_providers.empty()) {
           continue;
diff --git a/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h
new file mode 100644
index 0000000000000..6ea8b55505214
--- /dev/null
+++ b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h
@@ -0,0 +1,203 @@
+/**
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License.
+ *
+ * Module Name:
+ *    blkq4_fp16_quant_sm80.h
+ *
+ * Abstract:
+ *   Oracle computation for blockwise 4b quantization for the fp16
+ *   gemm kernel targeting Ampere GPUs. This is used for testing
+ *   the cuda kernel implementation in
+ *   (test/providers/cuda/test_cases)
+ *   and the cuda op prepack code in (test/optimizer).
+ */
+
+#pragma once
+
+#include "core/util/matrix_layout.h"
+#include "core/common/common.h"
+
+namespace onnxruntime {
+namespace test {
+
+static inline void sm80_prepack_weights_ref(
+    int rows,
+    int columns,
+    const MatrixRef<uint8_t const, ColumnMajorLayout, true>& tensor_weight,
+    const MatrixRef<uint8_t, ColumnMajorLayout, true>& tensor_weight_prepacked) {
+  ORT_ENFORCE(tensor_weight.shape()[0] == rows / 2 && tensor_weight.shape()[1] == columns,
+              "Unexpected tensor_weight shape! Expected: (", rows / 2, ", ", columns, "), Got: (",
+              tensor_weight.shape()[0], ", ", tensor_weight.shape()[1], ").");
+  ORT_ENFORCE(tensor_weight_prepacked.shape()[0] == rows && tensor_weight_prepacked.shape()[1] == columns / 2,
+              "tensor_weight_prepacked shape is not compatible with prepacked weight shape");
+
+  auto t0_base = make_Position(0, 0);
+  auto t1_base = make_Position(4, 0);
+  auto t2_base = make_Position(0, 8);
+  auto t3_base = make_Position(4, 8);
+  for (int col_dtile = 0; col_dtile < columns / 16; ++col_dtile) {
+    for (int row_dtile = 0; row_dtile < rows / 16; ++row_dtile) {
+      // Packing from a 8x16 tile to a 16x8 tile
+      auto dtile_base = make_Position(row_dtile * 8, col_dtile * 16);
+      auto packed_tile_base = make_Position(row_dtile * 16, col_dtile * 8);
+      for (int col = 0; col < 8; ++col) {
+        for (int row = 0; row < 4; ++row) {
+          auto cord = make_Position(row, col);
+          auto packed_cord = packed_tile_base + make_Position(row * 4, col);  // packed tile is 16x8
+          uint8_t buf[4];
+          buf[0] = tensor_weight.at(dtile_base + t0_base + cord);
+          buf[1] = tensor_weight.at(dtile_base + t1_base + cord);
+          buf[2] = tensor_weight.at(dtile_base + t2_base + cord);
+          buf[3] = tensor_weight.at(dtile_base + t3_base + cord);
+
+          // [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7] so that each pair of adjacent weights
+          // ends up in different b16 registers at the same positions. This makes it easier to convert to
+          // fp16x2 format in a b32 register
+
+          tensor_weight_prepacked.at(packed_cord) = (buf[0] & 0x0f) | ((buf[1] & 0x0f) << 4);
+          tensor_weight_prepacked.at(packed_cord + make_Position(1, 0)) = (buf[2] & 0x0f) | ((buf[3] & 0x0f) << 4);
+          tensor_weight_prepacked.at(packed_cord + make_Position(2, 0)) = ((buf[0] & 0xf0) >> 4) | (buf[1] & 0xf0);
+          tensor_weight_prepacked.at(packed_cord + make_Position(3, 0)) = ((buf[2] & 0xf0) >> 4) | (buf[3] & 0xf0);
+        }
+      }
+    }
+  }
+}
+
+template <
+    typename ScaleElementT,
+    typename Layout,
+    typename QuantBlocking>
+inline void sm80_prepack_quant_scales_ref(
+    int rows,
+    int columns,
+    const MatrixRef<ScaleElementT const, Layout, true>& tensor_scale,
+    const MatrixRef<ScaleElementT, Layout, true>& tensor_scale_prepacked) {
+  ORT_ENFORCE(tensor_scale.shape()[0] == (rows / QuantBlocking::kRow) && tensor_scale.shape()[1] == (columns / QuantBlocking::kColumn),
+              "Unexpected tensor_scale shape! Expected: (",
+              rows / QuantBlocking::kRow, ", ", columns / QuantBlocking::kColumn, ")");
+  ORT_ENFORCE(tensor_scale_prepacked.shape() == tensor_scale.shape());
+
+  // Only prepacking scale and offset tensors for an often-used special case:
+  //    16b gemm (2 elements per 32b register, operand tile shape 8x8)
+  //    2 B operand tiles per mma instruction stacked on k dimension
+  //    (1,n) quantization blocking
+  if constexpr (sizeof(ScaleElementT) != 2 || QuantBlocking::kRow != 1) {
+    ORT_THROW("sm80_prepack_quant_scales_ref should only be called for row-wise block quantization on 16b float values.");
+  }
+
+  // In the Ampere tensor op, each operand B tile is 8 x 8. In a warp of 32 threads, each thread
+  // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use an
+  // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension,
+  // as shown below (T stands for thread):
+  // T0, T4, T8, T12
+  // T1, T5, T9, T13
+  // T2, T6, T10, T14
+  // T3, T7, T11, T15
+  // T0, T4, T8, T12
+  // T1, T5, T9, T13
+  // T2, T6, T10, T14
+  // T3, T7, T11, T15
+  //
+  // We need to deliver quantization scale and offset elements to the corresponding threads,
+  // so we can perform dequantization efficiently. With a column-major layout, each thread
+  // needs two separate loads for an mma instruction, due to the tile fragment layout shown
+  // above. To reduce the number of loads, we rearrange each column as below, so we can use
+  // a single load to load fragments for two tiles:
+  // T0        T0
+  // T1        T0
+  // T2        T1
+  // T3   =>   T1
+  // T0        T2
+  // T1        T2
+  // T2        T3
+  // T3        T3
+
+  for (int col = 0; col < tensor_scale.shape()[1]; ++col) {
+    for (int row_blk = 0; row_blk < tensor_scale.shape()[0]; row_blk += 16) {
+      for (int thread_id = 0; thread_id < 4; thread_id++) {
+        const int dst_idx = row_blk + thread_id * 4;
+        const int src_idx = row_blk + thread_id * 2;
+        tensor_scale_prepacked.at(dst_idx + 0, col) = tensor_scale.at(src_idx + 0, col);
+        tensor_scale_prepacked.at(dst_idx + 1, col) = tensor_scale.at(src_idx + 1, col);
+        tensor_scale_prepacked.at(dst_idx + 2, col) = tensor_scale.at(src_idx + 8, col);
+        tensor_scale_prepacked.at(dst_idx + 3, col) = tensor_scale.at(src_idx + 9, col);
+      }
+    }
+  }
+}
+
+template <typename Layout, typename QuantBlocking>
+inline void sm80_prepack_quant_offsets_ref(
+    int rows,
+    int columns,
+    MatrixRef<uint8_t const, Layout, true> tensor_offset,
+    MatrixRef<uint8_t, Layout, true> tensor_offset_prepacked) {
+  const auto meta_shape = make_Position(rows / QuantBlocking::kRow, columns / QuantBlocking::kColumn);
+  const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]);
+  ORT_ENFORCE(tensor_offset_prepacked.shape() == meta_shape,
+              "Unexpected tensor_offset_prepacked shape (",
+              tensor_offset_prepacked.shape()[0], ",", tensor_offset_prepacked.shape()[1],
+              ")! Expected: (", meta_shape[0], ", ", meta_shape[1], ")");
+  ORT_ENFORCE(tensor_offset.shape() == zp_shape,
+              "Unexpected tensor_offset shape (",
+              tensor_offset.shape()[0], ",", tensor_offset.shape()[1],
+              ")! Expected: (", zp_shape[0], ", ", zp_shape[1], ")");
+
+  // Only prepacking scale and offset tensors for an often-used special case:
+  //    16b gemm (2 elements per 32b register, operand tile shape 8x8)
+  //    2 B operand tiles per mma instruction stacked on k dimension
+  //    (1,n) quantization blocking
+  if constexpr (QuantBlocking::kRow != 1) {
+    ORT_THROW("sm80_prepack_quant_offsets_ref should only be called for row-wise block quantization.");
+  }
+  // In the Ampere tensor op, each operand B tile is 8 x 8. In a warp of 32 threads, each thread
+  // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use an
+  // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension,
+  // as shown below (T stands for thread):
+  // T0, T4, T8, T12
+  // T1, T5, T9, T13
+  // T2, T6, T10, T14
+  // T3, T7, T11, T15
+  // T0, T4, T8, T12
+  // T1, T5, T9, T13
+  // T2, T6, T10, T14
+  // T3, T7, T11, T15
+  //
+  // We need to deliver quantization scale and offset elements to the corresponding threads,
+  // so we can perform dequantization efficiently. With a column-major layout, each thread
+  // needs two separate loads for an mma instruction, due to the tile fragment layout shown
+  // above. To reduce the number of loads, we rearrange each column as below, so we can use
+  // a single load to load fragments for two tiles:
+  // T0        T0
+  // T1        T0
+  // T2        T1
+  // T3   =>   T1
+  // T0        T2
+  // T1        T2
+  // T2        T3
+  // T3        T3
+  if (tensor_offset_prepacked.good()) {
+    for (int col = 0; col < tensor_offset_prepacked.shape()[1]; ++col) {
+      for (int row_blk = 0; row_blk < tensor_offset_prepacked.shape()[0]; row_blk += 16) {
+        for (int thread_id = 0; thread_id < 4; thread_id++) {
+          const int dst_idx = row_blk + thread_id * 4;
+          const int src_idx = row_blk + thread_id * 2;
+          // [a, b, c, d] => [a, c, b, d] so that adjacent weights are in their own
+          // 16b element: [a, x, b, x] and [x, c, x, d], which makes it easier to
+          // convert to fp16x2 format in a b32 register
+          uint8_t pair01 = tensor_offset.at(src_idx / 2, col);
+          uint8_t pair89 = tensor_offset.at((src_idx + 8) / 2, col);
+          tensor_offset_prepacked.at(dst_idx + 0, col) = pair01 & 0xf;
+          tensor_offset_prepacked.at(dst_idx + 1, col) = pair89 & 0xf;
+          tensor_offset_prepacked.at(dst_idx + 2, col) = pair01 >> 4;
+          tensor_offset_prepacked.at(dst_idx + 3, col) = pair89 >> 4;
+        }
+      }
+    }
+  }
+}
+
+}  // namespace test
+}  // namespace onnxruntime
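The nibble shuffle that sm80_prepack_weights_ref documents ([0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7]) is easiest to see on one group of four input bytes. The standalone sketch below reproduces just the packing arithmetic of the four stores above, independent of the MatrixRef plumbing; it is illustrative only.

#include <cstdint>
#include <cstdio>

int main() {
  // Eight 4-bit weights w0..w7 packed two per byte, low nibble first: {w0,w1}, {w2,w3}, {w4,w5}, {w6,w7}.
  const std::uint8_t buf[4] = {0x10, 0x32, 0x54, 0x76};
  std::uint8_t packed[4];
  packed[0] = (buf[0] & 0x0f) | ((buf[1] & 0x0f) << 4);  // low nibbles of buf[0], buf[1]  -> w0, w2
  packed[1] = (buf[2] & 0x0f) | ((buf[3] & 0x0f) << 4);  // low nibbles of buf[2], buf[3]  -> w4, w6
  packed[2] = ((buf[0] & 0xf0) >> 4) | (buf[1] & 0xf0);  // high nibbles of buf[0], buf[1] -> w1, w3
  packed[3] = ((buf[2] & 0xf0) >> 4) | (buf[3] & 0xf0);  // high nibbles of buf[2], buf[3] -> w5, w7
  for (std::uint8_t b : packed) {
    // Prints the nibble sequence 0 2 / 4 6 / 1 3 / 5 7.
    std::printf("%x %x\n", static_cast<unsigned>(b & 0x0f), static_cast<unsigned>(b >> 4));
  }
  return 0;
}

Each of these four bytes corresponds to one of the stores at packed_cord + (0..3, 0) in the reference function, so adjacent weights land in different b16 lanes exactly as the inline comment describes.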
diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc
index d7b1de5c930c5..3e0d94e94e48c 100644
--- a/onnxruntime/test/framework/allocation_planner_test.cc
+++ b/onnxruntime/test/framework/allocation_planner_test.cc
@@ -1974,6 +1974,74 @@ TEST_F(PlannerTest, TestCpuIf) {
     ASSERT_TRUE(exe_plan[1]->steps_[6]->ToString().substr(0, WaitOnEPStep.size()) == WaitOnEPStep);
   }
 }
+
+// model looks like:
+//                                                 |-----------> Gather
+//                                                 |-----------> Gather
+//                                                 |-----------> Gather
+//                                                 |-----------> Gather
+// Shape ----------------> Reshape --> Shape ------------------> Reshape
+//                           ^                                     ^
+// InstanceNormalization ----|         InstanceNormalization ------|
+//
+// Python script to create this model:
+// def CreateModelFor19480():
+//    #shape->reshape->shape->reshape, 4 gather
+//    graphNodes = []
+//    graphNodes.append(h.make_node('Shape', inputs=['shape_input'], outputs=['9']))
+//    graphNodes.append(h.make_node('InstanceNormalization', inputs=['in0_input', 'scale0', 'B0'], outputs=['8']))
+//    graphNodes.append(h.make_node('Reshape', inputs=['8', '9'], outputs=['Reshape15_output']))
+//    graphNodes.append(h.make_node('Shape', inputs=['Reshape15_output'], outputs=['281']))
+//    graphNodes.append(h.make_node('InstanceNormalization', inputs=['in1_input', 'scale1', 'B1'], outputs=['293']))
+//    graphNodes.append(h.make_node('Reshape', inputs=['293', '281'], outputs=['output0']))
+//    graphNodes.append(h.make_node('Gather', inputs=['281', 'indices1'], outputs=['output1']))
+//    graphNodes.append(h.make_node('Gather', inputs=['281', 'indices2'], outputs=['output2']))
+//    graphNodes.append(h.make_node('Gather', inputs=['281', 'indices3'], outputs=['output3']))
+//    graphNodes.append(h.make_node('Gather', inputs=['281', 'indices4'], outputs=['output4']))
+//    g = h.make_graph(graphNodes, 'issue_19480',
+//                     [h.make_tensor_value_info('shape_input', tp.FLOAT, ['batch', 128, None, None]),
+//                      h.make_tensor_value_info('in0_input', tp.FLOAT, ['batch', 32, None]),
+//                      h.make_tensor_value_info('scale0', tp.FLOAT, [32]),
+//                      h.make_tensor_value_info('B0', tp.FLOAT, [32]),
+//                      h.make_tensor_value_info('in1_input', tp.FLOAT, ['batch', 32, None]),
+//                      h.make_tensor_value_info('scale1', tp.FLOAT, [32]),
+//                      h.make_tensor_value_info('B1', tp.FLOAT, [32]),
+//                      h.make_tensor_value_info('indices1', tp.INT32, []),
+//                      h.make_tensor_value_info('indices2', tp.INT32, []),
+//                      h.make_tensor_value_info('indices3', tp.INT32, []),
+//                      h.make_tensor_value_info('indices4', tp.INT32, [])],
+//                     [h.make_tensor_value_info('output0', tp.FLOAT, None),
+//                      h.make_tensor_value_info('output1', tp.INT64, None),
+//                      h.make_tensor_value_info('output2', tp.INT64, None),
+//                      h.make_tensor_value_info('output3', tp.INT64, None),
+//                      h.make_tensor_value_info('output4', tp.INT64, None)])
+//    model = h.make_model(g, opset_imports=[h.make_operatorsetid("", 17)], producer_name='producer_name')
+//    onnx.save(model, 'issue_19480.onnx')
+//
+TEST(AllocationPlannerTest, ReusedInputCrossDifferentStreams) {
+  SessionOptions sess_opt;
+  sess_opt.graph_optimization_level = TransformerLevel::Default;
+
+  InferenceSession sess(sess_opt, GetEnvironment(), ORT_TSTR("./testdata/multi_stream_models/issue_19480.onnx"));
+  auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider());
+  status = sess.Load();
+  status = sess.Initialize();
+  ASSERT_TRUE(status.IsOK()) << "No crash";
+  const SequentialExecutionPlan* plan = sess.GetSessionState().GetExecutionPlan();
+  ASSERT_EQ(plan->allocation_plan[14].alloc_kind, AllocKind::kReuse) << "The input of reshape and gather will reuse the output of shape";
+
+  int gather_count = 0;
+  for (size_t i = 0; i < plan->execution_plan[1]->steps_.size(); i++) {
+    if (strstr(typeid(*(plan->execution_plan[1]->steps_[i])).name(), "LaunchKernelStep")) {
+      const Node* node = sess.GetSessionState().GetGraphViewer().GetNode(plan->execution_plan[1]->steps_[i]->GetNodeIndex());
+      if (node->OpType() == "Gather")
+        gather_count++;
+      else
+        FAIL() << "CPU stream should contain only gather ops";
+    }
+  }
+  ASSERT_EQ(gather_count, 4) << "4 gather ops are all placed in CPU stream";
+}
 #endif
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc
index 60effda9ec772..d0520ebbcba5a 100644
--- a/onnxruntime/test/framework/inference_session_test.cc
+++ b/onnxruntime/test/framework/inference_session_test.cc
@@ -2944,6 +2944,11 @@ TEST(InferenceSessionTests, GlobalThreadPoolWithDenormalAsZero) {
 }
 
 // test inter thread pool with setting denormal as zero
+#if !defined(__APPLE__)
+// TODO (hasesh): Debug this test failure on macOS 12 with Xcode 14.2.
+// It seemingly passes on macOS 13 with Xcode 15.x, but we had to drop down to macOS 12
+// because, at the time of writing, macOS 13 images were making CI/Packaging pipelines
+// very unstable.
 TEST(InferenceSessionTests, InterThreadPoolWithDenormalAsZero) {
   if constexpr (!SessionOptions::DEFAULT_USE_PER_SESSION_THREADS) {
     GTEST_SKIP() << "Skipping the test";
@@ -3001,6 +3006,7 @@ TEST(InferenceSessionTests, InterThreadPoolWithDenormalAsZero) {
   VerifyThreadPoolWithDenormalAsZero(session2.GetIntraOpThreadPoolToUse(), false);
   VerifyThreadPoolWithDenormalAsZero(session2.GetInterOpThreadPoolToUse(), false);
 }
+#endif
 
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/framework/shape_inference_test.cc b/onnxruntime/test/framework/shape_inference_test.cc
index bfabcd567803b..f5258760eb20d 100644
--- a/onnxruntime/test/framework/shape_inference_test.cc
+++ b/onnxruntime/test/framework/shape_inference_test.cc
@@ -5,13 +5,16 @@
 #include <unordered_map>
 
 #include "gtest/gtest.h"
+#include "core/common/span_utils.h"
 #include "core/graph/model.h"
+#include "core/session/onnxruntime_cxx_api.h"
 #include "test/framework/model_builder_utils.h"
+#include "test/util/include/asserts.h"
 #include "test/util/include/test_utils.h"
+#include "test/util/include/inference_session_wrapper.h"
 #include "test/test_environment.h"
 
 using namespace ONNX_NAMESPACE;
-using namespace std;
 
 namespace onnxruntime {
 namespace test {
@@ -22,7 +25,7 @@ class ShapeInferenceTest : public ::testing::Test {
  protected:
   onnxruntime::Model model_;
   int node_count_;
-  std::unordered_map<string, std::unique_ptr<onnxruntime::NodeArg>> name_to_arg_;
+  std::unordered_map<std::string, std::unique_ptr<onnxruntime::NodeArg>> name_to_arg_;
 
  public:
   ShapeInferenceTest() : model_("Test", false, DefaultLoggingManager().DefaultLogger()), node_count_(0) {}
@@ -73,5 +76,91 @@ TEST_F(ShapeInferenceTest, BasicTest) {
   CheckShapeEquality(InputShape(node), OutputShape(node));
 }
 
+namespace {
+struct MyCustomKernelWithOptionalInput {
+  MyCustomKernelWithOptionalInput(const OrtKernelInfo* /*info*/) {
+  }
+
+  OrtStatusPtr ComputeV2(OrtKernelContext* /* context */) const {
+    return nullptr;
+  }
+};
+
+struct MyCustomOpWithOptionalInput : Ort::CustomOpBase<MyCustomOpWithOptionalInput,
+                                                       MyCustomKernelWithOptionalInput,
+                                                       true> {
+  explicit MyCustomOpWithOptionalInput(const char* provider) : provider_(provider) {}
+
+  OrtStatusPtr CreateKernelV2(const OrtApi& /* api */, const OrtKernelInfo* info, void** kernel) const {
+    *kernel = new MyCustomKernelWithOptionalInput(info);
+    return nullptr;
+  };
+
+  const char* GetName() const { return "FooBar"; };
+  const char* GetExecutionProviderType() const { return provider_; };
+
+  size_t GetInputTypeCount() const { return 3; };
+  ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
+  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const {
+    // The second input (index == 1) is optional
+    if (index == 1)
+      return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;
+
+    return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
+  }
+
+  size_t GetOutputTypeCount() const { return 1; };
+  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
+  OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
+    return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
+  }
+
+ private:
+  const char* provider_;
+};
+
+const ORTCHAR_T* const OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2 = ORT_TSTR("testdata/foo_bar_2.onnx");
+
+}  // namespace
+
+// The custom op output type inference function quits if it
+// encounters an output that is optional and absent.
+// It quits without any errors or logging. We want to make sure
+// that inference proceeds for all of the outputs when the model has optional inputs that are absent.
+TEST(ShapeInferenceCustomOpTest, custom_op_optional_input_inference_test) {
+  MyCustomOpWithOptionalInput custom_op{onnxruntime::kCpuExecutionProvider};
+
+  const auto& env = GetEnvironment();
+
+  Ort::CustomOpDomain op_domain("test");
+  op_domain.Add(&custom_op);
+
+  std::initializer_list<OrtCustomOpDomain*> op_domains = {static_cast<OrtCustomOpDomain*>(op_domain)};
+
+  SessionOptions sess_opts;
+  sess_opts.inter_op_param.thread_pool_size = 1;
+  sess_opts.intra_op_param.thread_pool_size = 1;
+
+  InferenceSessionWrapper session{sess_opts, env, OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2};
+  ASSERT_STATUS_OK(session.AddCustomOpDomains(AsSpan(op_domains)));
+
+  ASSERT_STATUS_OK(session.Load());
+  ASSERT_STATUS_OK(session.Initialize());
+
+  const onnxruntime::Model& model = session.GetModel();
+  const auto& graph = model.MainGraph();
+  const auto& nodes = graph.Nodes();
+  for (const auto& node : nodes) {
+    if (node.OpType() == "FooBar") {
+      // check inferred shapes
+      const auto* node_arg = node.OutputDefs()[0];
+      const auto* type_proto = node_arg->TypeAsProto();
+      ASSERT_NE(nullptr, type_proto);
+      ASSERT_EQ(ONNX_NAMESPACE::TypeProto::ValueCase::kTensorType, type_proto->value_case());
+      ASSERT_EQ(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, type_proto->tensor_type().elem_type());
+    }
+  }
+}
+
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp
index b7b453415838a..04f5947e1371c 100644
--- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp
+++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp
@@ -5,26 +5,30 @@
 #include "mlas_qnbit.h"
 
 #include <memory>
+#include <sstream>
 #include <stdexcept>
 #include <vector>
 
 #include "benchmark/benchmark.h"
 
 #include "bench_util.h"
-#include "core/util/thread_utils.h"
 #include "core/common/narrow.h"
+#include "core/util/thread_utils.h"
+#include "core/platform/env_var_utils.h"
 
 using onnxruntime::narrow;
 
 template <size_t BlkBitWidth>
-void SQNBITGEMM(benchmark::State& state) {
-  const auto BlkLen = narrow<size_t>(state.range(0));
-  const auto M = narrow<size_t>(state.range(1));
-  const auto N = narrow<size_t>(state.range(2));
-  const auto K = narrow<size_t>(state.range(3));
-  const auto Threads = narrow<size_t>(state.range(4));
-  const auto Symmetric = narrow<bool>(state.range(5));
-  const auto ComputeType = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(state.range(6));
+void RunSQNBitGemmBenchmark(size_t BlkLen,
+                            size_t M, size_t N, size_t K,
+                            size_t Threads,
+                            bool Symmetric,
+                            MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
+                            benchmark::State& state) {
+  if (!MlasIsSQNBitGemmAvailable(BlkBitWidth, BlkLen, ComputeType)) {
+    state.SkipWithMessage("SQNBitGemm is not available with the given configuration on the current machine.");
+    return;
+  }
 
   size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes;
   MlasBlockwiseQuantizedBufferSizes(
@@ -88,28 +92,57 @@ void SQNBITGEMM(benchmark::State& state) {
   }
 }
 
-static void SQ4BitGemmArgs(benchmark::internal::Benchmark* b) {
-  constexpr size_t BlkBitWidth = 4;
+template <size_t BlkBitWidth>
+void SQNBITGEMM(benchmark::State& state) {
+  const auto BlkLen = narrow<size_t>(state.range(0));
+  const auto M = narrow<size_t>(state.range(1));
+  const auto N = narrow<size_t>(state.range(2));
+  const auto K = narrow<size_t>(state.range(3));
+  const auto Threads = narrow<size_t>(state.range(4));
+  const auto Symmetric = narrow<bool>(state.range(5));
+  const auto ComputeType = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(state.range(6));
+
+  RunSQNBitGemmBenchmark<BlkBitWidth>(BlkLen, M, N, K, Threads, Symmetric, ComputeType, state);
+}
+
+// This benchmark gets its arguments from environment variables.
+template <size_t BlkBitWidth>
+void SQNBITGEMM_ENV(benchmark::State& state) {
+  using onnxruntime::ParseEnvironmentVariableWithDefault;
+
+  const auto BlkLen = ParseEnvironmentVariableWithDefault<size_t>("ORT_SQNBITGEMM_BLKLEN", 32);
+  const auto M = ParseEnvironmentVariableWithDefault<size_t>("ORT_SQNBITGEMM_M", 1);
+  const auto N = ParseEnvironmentVariableWithDefault<size_t>("ORT_SQNBITGEMM_N", 4096);
+  const auto K = ParseEnvironmentVariableWithDefault<size_t>("ORT_SQNBITGEMM_K", 4096);
+  const auto Threads = ParseEnvironmentVariableWithDefault<size_t>("ORT_SQNBITGEMM_THREADS", 1);
+  const auto Symmetric = ParseEnvironmentVariableWithDefault<bool>("ORT_SQNBITGEMM_SYMMETRIC", true);
+  const auto ComputeType = ParseEnvironmentVariableWithDefault<int32_t>("ORT_SQNBITGEMM_COMPUTE_TYPE",
+                                                                        static_cast<int32_t>(CompFp32));
+
+  RunSQNBitGemmBenchmark<BlkBitWidth>(BlkLen, M, N, K, Threads, Symmetric,
+                                      static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(ComputeType),
+                                      state);
+
+  std::ostringstream s;
+  s << "BlkBitWidth:" << BlkBitWidth << "/BlkLen:" << BlkLen
+    << "/M:" << M << "/N:" << N << "/K:" << K
+    << "/Threads:" << Threads << "/Symmetric:" << Symmetric << "/ComputeType:" << ComputeType;
+  state.SetLabel(s.str());
+}
 
+static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) {
   b->ArgNames({"BlkLen", "M", "N", "K", "Threads", "Symmetric", "ComputeType"});
 
-  ArgsProductWithFilter(b,
-
-                        {{16, 32, 64, 128, 256},                   // BlkLen
-                         {1, 1024, 2048},                          // M
-                         {4096, 11008},                            // N
-                         {4096, 11008},                            // K
-                         {1, 8},                                   // Threads
-                         {int64_t{false}, int64_t{true}},          // Symmetric
-                         {int64_t{CompFp32}, int64_t{CompInt8}}},  // ComputeType
-
-                        [&](const std::vector<int64_t>& args) {
-                          return MlasIsSQNBitGemmAvailable(
-                              // BlkBitWidth, BlkLen
-                              BlkBitWidth, narrow<size_t>(args[0]),
-                              // ComputeType
-                              static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(args[6]));
-                        });
+  b->ArgsProduct({
+      {16, 32, 64, 128, 256},                  // BlkLen
+      {1, 1024, 2048},                         // M
+      {4096, 11008},                           // N
+      {4096, 11008},                           // K
+      {1, 8},                                  // Threads
+      {int64_t{false}, int64_t{true}},         // Symmetric
+      {int64_t{CompFp32}, int64_t{CompInt8}},  // ComputeType
+  });
 }
 
-BENCHMARK(SQNBITGEMM<4>)->Apply(SQ4BitGemmArgs)->UseRealTime();
+BENCHMARK(SQNBITGEMM<4>)->Apply(SQNBitGemmArgs)->UseRealTime();
+BENCHMARK(SQNBITGEMM_ENV<4>)->UseRealTime();
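SQNBITGEMM_ENV reads its problem configuration through onnxruntime::ParseEnvironmentVariableWithDefault, so one registered benchmark can be re-pointed at arbitrary shapes without recompiling. The sketch below mirrors that read-or-default behavior using only the standard library; GetEnvOrDefault is a hypothetical stand-in, not the ORT helper itself.

#include <cstdlib>
#include <sstream>

// Hypothetical stand-in mirroring the intent of ParseEnvironmentVariableWithDefault: return the
// parsed value of an environment variable, or the provided default when the variable is unset
// or cannot be parsed.
template <typename T>
T GetEnvOrDefault(const char* name, T default_value) {
  const char* raw = std::getenv(name);
  if (raw == nullptr) {
    return default_value;
  }
  std::istringstream stream(raw);
  T parsed{};
  stream >> parsed;
  return stream.fail() ? default_value : parsed;
}

For example, exporting ORT_SQNBITGEMM_M=1024 before running the MLAS benchmark binary changes only that dimension, while the remaining parameters fall back to the defaults listed in SQNBITGEMM_ENV.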
diff --git a/onnxruntime/test/mlas/bench/bench_util.cpp b/onnxruntime/test/mlas/bench/bench_util.cpp
index d57564615b04e..6b59b7e01b46f 100644
--- a/onnxruntime/test/mlas/bench/bench_util.cpp
+++ b/onnxruntime/test/mlas/bench/bench_util.cpp
@@ -22,30 +22,3 @@ std::vector<float> RandomVectorUniform(std::vector<int64_t> shape, float min_val
   }
   return RandomVectorUniform(static_cast<size_t>(sz), min_value, max_value);
 }
-
-void ArgsProductWithFilter(benchmark::internal::Benchmark* bench,
-                           const std::vector<std::vector<int64_t>>& arglists,
-                           std::function<bool(const std::vector<int64_t>& args)> include_filter) {
-  std::vector<std::size_t> indices(arglists.size(), 0);
-  const std::size_t total = std::accumulate(
-      std::begin(arglists), std::end(arglists), std::size_t{1},
-      [](const std::size_t res, const std::vector<int64_t>& arglist) {
-        return res * arglist.size();
-      });
-  std::vector<int64_t> args;
-  args.reserve(arglists.size());
-  for (std::size_t i = 0; i < total; i++) {
-    for (std::size_t arg = 0; arg < arglists.size(); arg++) {
-      args.push_back(arglists[arg][indices[arg]]);
-    }
-    if (include_filter(args)) {
-      bench->Args(args);
-    }
-    args.clear();
-
-    std::size_t arg = 0;
-    do {
-      indices[arg] = (indices[arg] + 1) % arglists[arg].size();
-    } while (indices[arg++] == 0 && arg < arglists.size());
-  }
-}
diff --git a/onnxruntime/test/mlas/bench/bench_util.h b/onnxruntime/test/mlas/bench/bench_util.h
index ee2ec42d0f755..f96dd5c673b3d 100644
--- a/onnxruntime/test/mlas/bench/bench_util.h
+++ b/onnxruntime/test/mlas/bench/bench_util.h
@@ -8,12 +8,6 @@
 #include <functional>
 #include <random>
 
-// Specifies benchmark arguments from the cartesian product of `arglists`, like Benchmark::ArgsProduct().
-// `include_filter` is called to determine whether a given set of arguments should be included.
-void ArgsProductWithFilter(benchmark::internal::Benchmark* bench,
-                           const std::vector<std::vector<int64_t>>& arglists,
-                           std::function<bool(const std::vector<int64_t>& args)> include_filter);
-
 template <typename ElementType>
 std::vector<ElementType> RandomVectorUniform(
     size_t N,
diff --git a/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp b/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp
index 484a9a22429d5..969997d2b84ec 100644
--- a/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp
+++ b/onnxruntime/test/mlas/unittest/test_fp16_activation.cpp
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
 
 #include "test_fp16.h"
+#include <iomanip>
 
 #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
 
diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc
index 57c2061883736..09c8ae213bad2 100644
--- a/onnxruntime/test/onnx/TestCase.cc
+++ b/onnxruntime/test/onnx/TestCase.cc
@@ -267,12 +267,12 @@ void LoopDataFile(int test_data_pb_fd, bool is_input, const TestModelInfo& model
 }  // namespace
 
 #if !defined(ORT_MINIMAL_BUILD)
-std::unique_ptr<TestModelInfo> TestModelInfo::LoadOnnxModel(_In_ const PATH_CHAR_TYPE* model_url) {
+std::unique_ptr<TestModelInfo> TestModelInfo::LoadOnnxModel(const std::filesystem::path& model_url) {
   return std::make_unique<OnnxModelInfo>(model_url);
 }
 #endif
 
-std::unique_ptr<TestModelInfo> TestModelInfo::LoadOrtModel(_In_ const PATH_CHAR_TYPE* model_url) {
+std::unique_ptr<TestModelInfo> TestModelInfo::LoadOrtModel(const std::filesystem::path& model_url) {
   return std::make_unique<OnnxModelInfo>(model_url, true);
 }
 
@@ -290,7 +290,7 @@ class OnnxTestCase : public ITestCase {
   mutable std::vector<std::string> debuginfo_strings_;
   mutable onnxruntime::OrtMutex m_;
 
-  std::vector<std::basic_string<PATH_CHAR_TYPE>> test_data_dirs_;
+  std::vector<std::filesystem::path> test_data_dirs_;
 
   std::string GetDatasetDebugInfoString(size_t dataset_id) const override {
     std::lock_guard<OrtMutex> l(m_);
@@ -343,7 +343,7 @@ class OnnxTestCase : public ITestCase {
 
   size_t GetDataCount() const override { return test_data_dirs_.size(); }
   const std::string& GetNodeName() const override { return model_info_->GetNodeName(); }
-  const PATH_CHAR_TYPE* GetModelUrl() const override { return model_info_->GetModelUrl(); }
+  const std::filesystem::path& GetModelUrl() const override { return model_info_->GetModelUrl(); }
   const std::string& GetTestCaseName() const override { return test_case_name_; }
   std::string GetTestCaseVersion() const override { return model_info_->GetNominalOpsetVersion(); }
 
@@ -396,7 +396,14 @@ static std::string trim_str(const std::string& in) {
   return s;
 }
 
-static bool read_config_file(const std::basic_string<PATH_CHAR_TYPE>& path, std::map<std::string, std::string>& fc) {
+/**
+ * @brief Read a text file in which each line is a key-value pair separated by ':'.
+ * @param path File path
+ * @param fc Output key-value pairs
+ * @return True on success. False if the file doesn't exist or could not be read.
+ */
+static bool ReadConfigFile(const std::filesystem::path& path, std::map<std::string, std::string>& fc) {
+  if (!std::filesystem::exists(path)) return false;
   std::ifstream infile(path);
   if (!infile.good()) {
     return false;
@@ -474,10 +481,10 @@ void OnnxTestCase::LoadTestData(size_t id, onnxruntime::test::HeapBuffer& b,
     ORT_THROW("index out of bound");
   }
 
-  PATH_STRING_TYPE test_data_pb = ConcatPathComponent(
-      test_data_dirs_[id], (is_input ? ORT_TSTR("inputs.pb") : ORT_TSTR("outputs.pb")));
+  std::filesystem::path test_data_pb =
+      test_data_dirs_[id] / (is_input ? ORT_TSTR("inputs.pb") : ORT_TSTR("outputs.pb"));
   int test_data_pb_fd;
-  auto st = Env::Default().FileOpenRd(test_data_pb, test_data_pb_fd);
+  auto st = Env::Default().FileOpenRd(test_data_pb.string(), test_data_pb_fd);
   if (st.IsOK()) {  // has an all-in-one input file
     std::ostringstream oss;
     {
@@ -504,21 +511,23 @@ void OnnxTestCase::LoadTestData(size_t id, onnxruntime::test::HeapBuffer& b,
 
   std::vector<PATH_STRING_TYPE> test_data_pb_files;
 
-  const PATH_STRING_TYPE& dir_path = test_data_dirs_[id];
-  LoopDir(dir_path,
-          [&test_data_pb_files, &dir_path, is_input](const PATH_CHAR_TYPE* filename, OrtFileType f_type) -> bool {
-            if (filename[0] == '.') return true;
-            if (f_type != OrtFileType::TYPE_REG) return true;
-            std::basic_string<PATH_CHAR_TYPE> filename_str = filename;
-            if (!HasExtensionOf(filename_str, ORT_TSTR("pb"))) return true;
-            const std::basic_string<PATH_CHAR_TYPE> file_prefix =
-                is_input ? ORT_TSTR("input_") : ORT_TSTR("output_");
-            if (!filename_str.compare(0, file_prefix.length(), file_prefix)) {
-              std::basic_string<PATH_CHAR_TYPE> p = ConcatPathComponent(dir_path, filename_str);
-              test_data_pb_files.push_back(p);
-            }
-            return true;
-          });
+  std::filesystem::path dir_fs_path = test_data_dirs_[id];
+  if (!std::filesystem::exists(dir_fs_path)) return;
+
+  for (auto const& dir_entry : std::filesystem::directory_iterator(dir_fs_path)) {
+    if (!dir_entry.is_regular_file()) continue;
+    const std::filesystem::path& path = dir_entry.path();
+    if (!path.filename().has_extension()) {
+      continue;
+    }
+    if (path.filename().extension().compare(ORT_TSTR(".pb")) != 0) continue;
+    const std::basic_string<PATH_CHAR_TYPE> file_prefix =
+        is_input ? ORT_TSTR("input_") : ORT_TSTR("output_");
+    auto filename_str = path.filename().native();
+    if (filename_str.compare(0, file_prefix.length(), file_prefix) == 0) {
+      test_data_pb_files.push_back(path.native());
+    }
+  }
 
   SortFileNames(test_data_pb_files);
 
@@ -691,11 +700,13 @@ void OnnxTestCase::ConvertTestData(const ONNX_NAMESPACE::OptionalProto& test_dat
 OnnxTestCase::OnnxTestCase(const std::string& test_case_name, _In_ std::unique_ptr<TestModelInfo> model,
                            double default_per_sample_tolerance, double default_relative_per_sample_tolerance)
     : test_case_name_(test_case_name), model_info_(std::move(model)) {
-  std::basic_string<PATH_CHAR_TYPE> test_case_dir = model_info_->GetDir();
-
+  std::filesystem::path test_case_dir = model_info_->GetDir();
+  if (!std::filesystem::exists(test_case_dir)) {
+    ORT_THROW("test case dir doesn't exist");
+  }
   // parse config
-  std::basic_string<PATH_CHAR_TYPE> config_path =
-      ConcatPathComponent(test_case_dir, ORT_TSTR("config.txt"));
+  std::filesystem::path config_path =
+      test_case_dir / ORT_TSTR("config.txt");
   /* Note: protobuf-lite doesn't support reading protobuf files as text-format. Config.txt is exactly that.
      That's the reason I've to parse the file in a different way to read the configs. Currently
      this affects 2 tests - fp16_tiny_yolov2 and fp16_inception_v1. It's not clear why we've to use protobuf
@@ -705,7 +716,7 @@ OnnxTestCase::OnnxTestCase(const std::string& test_case_name, _In_ std::unique_p
   per_sample_tolerance_ = default_per_sample_tolerance;
   relative_per_sample_tolerance_ = default_relative_per_sample_tolerance;
   post_processing_ = false;
-  if (read_config_file(config_path, fc)) {
+  if (ReadConfigFile(config_path, fc)) {
     if (fc.count("per_sample_tolerance") > 0) {
       per_sample_tolerance_ = stod(fc["per_sample_tolerance"]);
     }
@@ -716,16 +727,11 @@ OnnxTestCase::OnnxTestCase(const std::string& test_case_name, _In_ std::unique_p
       post_processing_ = fc["post_processing"] == "true";
     }
   }
-
-  LoopDir(test_case_dir, [&test_case_dir, this](const PATH_CHAR_TYPE* filename, OrtFileType f_type) -> bool {
-    if (filename[0] == '.') return true;
-    if (f_type == OrtFileType::TYPE_DIR) {
-      std::basic_string<PATH_CHAR_TYPE> p = ConcatPathComponent(test_case_dir, filename);
-      test_data_dirs_.push_back(p);
-      debuginfo_strings_.push_back(ToUTF8String(p));
-    }
-    return true;
-  });
+  for (auto const& dir_entry : std::filesystem::directory_iterator(test_case_dir)) {
+    if (!dir_entry.is_directory()) continue;
+    test_data_dirs_.push_back(dir_entry.path());
+    debuginfo_strings_.push_back(ToUTF8String(dir_entry.path().string()));
+  }
 }
 
 void LoadTests(const std::vector<std::basic_string<PATH_CHAR_TYPE>>& input_paths,
@@ -737,20 +743,23 @@ void LoadTests(const std::vector<std::basic_string<PATH_CHAR_TYPE>>& input_paths
                const std::function<void(std::unique_ptr<ITestCase>)>& process_function) {
   std::vector<std::basic_string<PATH_CHAR_TYPE>> paths(input_paths);
   while (!paths.empty()) {
-    std::basic_string<PATH_CHAR_TYPE> node_data_root_path = paths.back();
+    std::filesystem::path node_data_root_path = paths.back();
     paths.pop_back();
-    std::basic_string<PATH_CHAR_TYPE> my_dir_name = GetLastComponent(node_data_root_path);
-    LoopDir(node_data_root_path, [&](const PATH_CHAR_TYPE* filename, OrtFileType f_type) -> bool {
-      if (filename[0] == '.') return true;
-      if (f_type == OrtFileType::TYPE_DIR) {
-        std::basic_string<PATH_CHAR_TYPE> p = ConcatPathComponent(node_data_root_path, filename);
-        paths.push_back(p);
-        return true;
+    if (!std::filesystem::exists(node_data_root_path)) continue;
+    std::filesystem::path my_dir_name = node_data_root_path.filename();
+    for (auto const& dir_entry : std::filesystem::directory_iterator(node_data_root_path)) {
+      if (dir_entry.is_directory()) {
+        paths.push_back(dir_entry.path());
+        continue;
       }
-
-      std::basic_string<PATH_CHAR_TYPE> filename_str = filename;
-      bool is_onnx_format = HasExtensionOf(filename_str, ORT_TSTR("onnx"));
-      bool is_ort_format = HasExtensionOf(filename_str, ORT_TSTR("ort"));
+      if (!dir_entry.is_regular_file()) continue;
+      std::filesystem::path filename_str = dir_entry.path().filename();
+      if (filename_str.empty() || filename_str.native()[0] == ORT_TSTR('.')) {
+        // Ignore hidden files.
+        continue;
+      }
+      bool is_onnx_format = filename_str.has_extension() && (filename_str.extension().compare(ORT_TSTR(".onnx")) == 0);
+      bool is_ort_format = filename_str.has_extension() && (filename_str.extension().compare(ORT_TSTR(".ort")) == 0);
       bool is_valid_model = false;
 
 #if !defined(ORT_MINIMAL_BUILD)
@@ -759,42 +768,40 @@ void LoadTests(const std::vector<std::basic_string<PATH_CHAR_TYPE>>& input_paths
 
       is_valid_model = is_valid_model || is_ort_format;
       if (!is_valid_model)
-        return true;
+        continue;
 
-      std::basic_string<PATH_CHAR_TYPE> test_case_name = my_dir_name;
+      std::basic_string<PATH_CHAR_TYPE> test_case_name = my_dir_name.native();
       if (test_case_name.compare(0, 5, ORT_TSTR("test_")) == 0) test_case_name = test_case_name.substr(5);
 
       if (!whitelisted_test_cases.empty() && std::find(whitelisted_test_cases.begin(), whitelisted_test_cases.end(),
                                                        test_case_name) == whitelisted_test_cases.end()) {
-        return true;
+        continue;
       }
-      if (disabled_tests.find(test_case_name) != disabled_tests.end()) return true;
-
-      std::basic_string<PATH_CHAR_TYPE> p = ConcatPathComponent(node_data_root_path, filename_str);
+      if (disabled_tests.find(test_case_name) != disabled_tests.end()) continue;
 
       std::unique_ptr<TestModelInfo> model_info;
 
       if (is_onnx_format) {
 #if !defined(ORT_MINIMAL_BUILD)
-        model_info = TestModelInfo::LoadOnnxModel(p.c_str());
+        model_info = TestModelInfo::LoadOnnxModel(dir_entry.path());
 #else
         ORT_THROW("onnx model is not supported in this build");
 #endif
       } else if (is_ort_format) {
-        model_info = TestModelInfo::LoadOrtModel(p.c_str());
+        model_info = TestModelInfo::LoadOrtModel(dir_entry.path());
       } else {
         ORT_NOT_IMPLEMENTED(ToUTF8String(filename_str), " is not supported");
       }
 
       auto test_case_dir = model_info->GetDir();
-      auto test_case_name_in_log = test_case_name + ORT_TSTR(" in ") + test_case_dir;
+      auto test_case_name_in_log = test_case_name + ORT_TSTR(" in ") + test_case_dir.native();
 
 #if !defined(ORT_MINIMAL_BUILD) && !defined(USE_QNN)
       // to skip some models like *-int8 or *-qdq
       if ((reinterpret_cast<OnnxModelInfo*>(model_info.get()))->HasDomain(ONNX_NAMESPACE::AI_ONNX_TRAINING_DOMAIN) ||
           (reinterpret_cast<OnnxModelInfo*>(model_info.get()))->HasDomain(ONNX_NAMESPACE::AI_ONNX_PREVIEW_TRAINING_DOMAIN)) {
         fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " as it has training domain");
-        return true;
+        continue;
       }
 #endif
 
@@ -809,7 +816,7 @@ void LoadTests(const std::vector<std::basic_string<PATH_CHAR_TYPE>>& input_paths
       });
       if (!has_test_data) {
         fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " due to no test data");
-        return true;
+        continue;
       }
 
       if (broken_tests) {
@@ -820,7 +827,7 @@ void LoadTests(const std::vector<std::basic_string<PATH_CHAR_TYPE>>& input_paths
             (opset_version == TestModelInfo::unknown_version || iter->broken_opset_versions_.empty() ||
              iter->broken_opset_versions_.find(opset_version) != iter->broken_opset_versions_.end())) {
           fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " due to broken_tests");
-          return true;
+          continue;
         }
       }
 
@@ -829,7 +836,7 @@ void LoadTests(const std::vector<std::basic_string<PATH_CHAR_TYPE>>& input_paths
           std::string keyword = *iter2;
           if (ToUTF8String(test_case_name).find(keyword) != std::string::npos) {
             fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " as it is in broken test keywords");
-            return true;
+            continue;
           }
         }
       }
@@ -841,8 +848,7 @@ void LoadTests(const std::vector<std::basic_string<PATH_CHAR_TYPE>>& input_paths
                                                         tolerances.relative(tolerance_key));
       fprintf(stdout, "Load Test Case: %s\n", ToUTF8String(test_case_name_in_log).c_str());
       process_function(std::move(l));
-      return true;
-    });
+    }
   }
 }
 
diff --git a/onnxruntime/test/onnx/TestCase.h b/onnxruntime/test/onnx/TestCase.h
index 96b0b5f6f7c08..0cb92056d378e 100644
--- a/onnxruntime/test/onnx/TestCase.h
+++ b/onnxruntime/test/onnx/TestCase.h
@@ -6,6 +6,7 @@
 #include <mutex>
 #include <unordered_map>
 #include <unordered_set>
+#include <filesystem>
 #include <core/common/common.h>
 #include <core/common/status.h>
 #include <core/platform/path_lib.h>
@@ -31,7 +32,7 @@ class ITestCase {
   virtual void LoadTestData(size_t id, onnxruntime::test::HeapBuffer& b,
                             std::unordered_map<std::string, Ort::Value>& name_data_map,
                             bool is_input) const = 0;
-  virtual const PATH_CHAR_TYPE* GetModelUrl() const = 0;
+  virtual const std::filesystem::path& GetModelUrl() const = 0;
   virtual const std::string& GetNodeName() const = 0;
   virtual const ONNX_NAMESPACE::ValueInfoProto* GetInputInfoFromModel(size_t i) const = 0;
   virtual const ONNX_NAMESPACE::ValueInfoProto* GetOutputInfoFromModel(size_t i) const = 0;
@@ -50,14 +51,9 @@ class ITestCase {
 
 class TestModelInfo {
  public:
-  virtual const PATH_CHAR_TYPE* GetModelUrl() const = 0;
-  virtual std::basic_string<PATH_CHAR_TYPE> GetDir() const {
-    std::basic_string<PATH_CHAR_TYPE> test_case_dir;
-    auto st = onnxruntime::GetDirNameFromFilePath(GetModelUrl(), test_case_dir);
-    if (!st.IsOK()) {
-      ORT_THROW("GetDirNameFromFilePath failed");
-    }
-    return test_case_dir;
+  virtual const std::filesystem::path& GetModelUrl() const = 0;
+  virtual std::filesystem::path GetDir() const {
+    return GetModelUrl().parent_path();
   }
   virtual const std::string& GetNodeName() const = 0;
   virtual const ONNX_NAMESPACE::ValueInfoProto* GetInputInfoFromModel(size_t i) const = 0;
@@ -70,10 +66,10 @@ class TestModelInfo {
   virtual ~TestModelInfo() = default;
 
 #if !defined(ORT_MINIMAL_BUILD)
-  static std::unique_ptr<TestModelInfo> LoadOnnxModel(_In_ const PATH_CHAR_TYPE* model_url);
+  static std::unique_ptr<TestModelInfo> LoadOnnxModel(const std::filesystem::path& model_url);
 #endif
 
-  static std::unique_ptr<TestModelInfo> LoadOrtModel(_In_ const PATH_CHAR_TYPE* model_url);
+  static std::unique_ptr<TestModelInfo> LoadOrtModel(const std::filesystem::path& model_url);
 
   static const std::string unknown_version;
 };
diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc
index aca609cf94270..0d55fd19b918a 100644
--- a/onnxruntime/test/onnx/main.cc
+++ b/onnxruntime/test/onnx/main.cc
@@ -25,6 +25,10 @@
 #include "core/session/onnxruntime_session_options_config_keys.h"
 #include "nlohmann/json.hpp"
 
+#ifdef USE_CUDA
+#include "core/providers/cuda/cuda_provider_options.h"
+#endif
+
 using namespace onnxruntime;
 
 namespace {
@@ -64,6 +68,8 @@ void usage() {
       "\t    [QNN only] [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. \n"
       "\t    Options are '0', '68', '69', '73', '75'. Defaults to '0' (none). \n"
       "\t    [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n"
+      "\t    [QNN only] [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n"
+      "\t    Otherwise, it will be fp32 precision. Only works for float32 model. Defaults to '0' (with FP32 precision.). \n"
       "\t [Usage]: -e <provider_name> -i '<key1>|<value1> <key2>|<value2>' \n\n"
       "\t [Example] [For QNN EP] -e qnn -i \"profiling_level|detailed backend_path|/folderpath/libQnnCpu.so\" \n\n"
       "\t    [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n"
@@ -339,11 +345,6 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
     logging_level = ORT_LOGGING_LEVEL_VERBOSE;
   }
 
-  if (concurrent_session_runs > 1 && repeat_count > 1) {
-    fprintf(stderr, "when you use '-r [repeat]', please set '-c' to 1\n");
-    usage();
-    return -1;
-  }
   argc -= optind;
   argv += optind;
   if (argc < 1) {
@@ -404,12 +405,15 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
 
     if (enable_tensorrt) {
 #ifdef USE_TENSORRT
-      OrtCUDAProviderOptions cuda_options;
+      Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Tensorrt(sf, device_id));
+#ifdef USE_CUDA
+      OrtCUDAProviderOptionsV2 cuda_options;
       cuda_options.device_id = device_id;
       cuda_options.do_copy_in_default_stream = true;
+      cuda_options.use_tf32 = false;
       // TODO: Support arena configuration for users of test runner
-      Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Tensorrt(sf, device_id));
-      sf.AppendExecutionProvider_CUDA(cuda_options);
+      sf.AppendExecutionProvider_CUDA_V2(cuda_options);
+#endif
 #else
       fprintf(stderr, "TensorRT is not supported in this build");
       return -1;
@@ -427,10 +431,11 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
     }
     if (enable_cuda) {
 #ifdef USE_CUDA
-      OrtCUDAProviderOptions cuda_options;
+      OrtCUDAProviderOptionsV2 cuda_options;
       cuda_options.do_copy_in_default_stream = true;
+      cuda_options.use_tf32 = false;
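+      // Note: TF32 is disabled here so float32 tests are computed in full fp32 rather than with
+      // reduced-precision TF32 tensor-core math.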
       // TODO: Support arena configuration for users of test runner
-      sf.AppendExecutionProvider_CUDA(cuda_options);
+      sf.AppendExecutionProvider_CUDA_V2(cuda_options);
 #else
       fprintf(stderr, "CUDA is not supported in this build");
       return -1;
@@ -525,11 +530,20 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
             std::string str = str_stream.str();
             ORT_THROW("Wrong value for htp_arch. select from: " + str);
           }
+        } else if (key == "enable_htp_fp16_precision") {
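+          // Boolean-style option: only "0" (fp32, the default) and "1" (fp16) are accepted.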
+          std::unordered_set<std::string> supported_options = {"0", "1"};
+          if (supported_options.find(value) == supported_options.end()) {
+            std::ostringstream str_stream;
+            std::copy(supported_options.begin(), supported_options.end(),
+                      std::ostream_iterator<std::string>(str_stream, ","));
+            std::string str = str_stream.str();
+            ORT_THROW("Wrong value for enable_htp_fp16_precision. select from: " + str);
+          }
         } else {
           ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path',
 'profiling_level', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode',
 'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority',
-'soc_model', 'htp_arch', 'device_id'])");
+'soc_model', 'htp_arch', 'device_id', 'enable_htp_fp16_precision'])");
         }
 
         qnn_options[key] = value;
diff --git a/onnxruntime/test/onnx/microbenchmark/activation.cc b/onnxruntime/test/onnx/microbenchmark/activation.cc
index cf859facf4765..69ee72996365e 100644
--- a/onnxruntime/test/onnx/microbenchmark/activation.cc
+++ b/onnxruntime/test/onnx/microbenchmark/activation.cc
@@ -11,6 +11,7 @@
 #include "core/framework/node_index_info.h"
 #include "core/framework/execution_frame.h"
 #include "contrib_ops/cpu/activations.h"
+#include "core/providers/cpu/tensor/gelu.h"
 #include "core/providers/cpu/activation/activations.h"
 #include <onnx/defs/attr_proto_util.h>
 #include <benchmark/benchmark.h>
@@ -182,7 +183,7 @@ static void RunSingleNode(const std::string& op_name, const std::string& domain,
 }
 
 static void BM_GeluCompute(benchmark::State& state) {
-  RunSingleNode<contrib::Gelu<float>>("Gelu", kMSDomain, {}, state);
+  RunSingleNode<Gelu<float>>("Gelu", kMSDomain, {}, state);
 }
 
 BENCHMARK(BM_GeluCompute)
diff --git a/onnxruntime/test/onnx/onnx_model_info.cc b/onnxruntime/test/onnx/onnx_model_info.cc
index d6afa99382e61..f23012aee9fd2 100644
--- a/onnxruntime/test/onnx/onnx_model_info.cc
+++ b/onnxruntime/test/onnx/onnx_model_info.cc
@@ -14,7 +14,7 @@
 
 using namespace onnxruntime;
 
-OnnxModelInfo::OnnxModelInfo(_In_ const PATH_CHAR_TYPE* model_url, bool is_ort_model)
+OnnxModelInfo::OnnxModelInfo(const std::filesystem::path& model_url, bool is_ort_model)
     : model_url_(model_url) {
   if (is_ort_model) {
     InitOrtModelInfo(model_url);
@@ -38,7 +38,7 @@ static void RepeatedPtrFieldToVector(const ::google::protobuf::RepeatedPtrField<
   }
 }
 
-void OnnxModelInfo::InitOnnxModelInfo(_In_ const PATH_CHAR_TYPE* model_url) {  // parse model
+void OnnxModelInfo::InitOnnxModelInfo(const std::filesystem::path& model_url) {  // parse model
   int model_fd;
   auto st = Env::Default().FileOpenRd(model_url, model_fd);
   if (!st.IsOK()) {
@@ -50,7 +50,9 @@ void OnnxModelInfo::InitOnnxModelInfo(_In_ const PATH_CHAR_TYPE* model_url) {  /
   const bool parse_result = model_pb.ParseFromZeroCopyStream(&input) && input.GetErrno() == 0;
   if (!parse_result) {
     (void)Env::Default().FileClose(model_fd);
-    ORT_THROW("Failed to load model because protobuf parsing failed.");
+    std::ostringstream oss;
+    oss << "Failed to load model from " << model_url << " because protobuf parsing failed.";
+    ORT_THROW(oss.str());
   }
   (void)Env::Default().FileClose(model_fd);
   {
@@ -91,7 +93,7 @@ void OnnxModelInfo::InitOnnxModelInfo(_In_ const PATH_CHAR_TYPE* model_url) {  /
 
 #endif  // #if !defined(ORT_MINIMAL_BUILD)
 
-void OnnxModelInfo::InitOrtModelInfo(_In_ const PATH_CHAR_TYPE* model_url) {
+void OnnxModelInfo::InitOrtModelInfo(const std::filesystem::path& model_url) {
   std::vector<uint8_t> bytes;
   size_t num_bytes = 0;
   const auto model_location = ToWideString(model_url);
diff --git a/onnxruntime/test/onnx/onnx_model_info.h b/onnxruntime/test/onnx/onnx_model_info.h
index a0aa27df64a94..48e297376aff5 100644
--- a/onnxruntime/test/onnx/onnx_model_info.h
+++ b/onnxruntime/test/onnx/onnx_model_info.h
@@ -13,16 +13,16 @@ class OnnxModelInfo : public TestModelInfo {
   std::vector<ONNX_NAMESPACE::ValueInfoProto> input_value_info_;
   std::vector<ONNX_NAMESPACE::ValueInfoProto> output_value_info_;
   std::unordered_map<std::string, int64_t> domain_to_version_;
-  const std::basic_string<PATH_CHAR_TYPE> model_url_;
+  const std::filesystem::path model_url_;
 
 #if !defined(ORT_MINIMAL_BUILD)
-  void InitOnnxModelInfo(_In_ const PATH_CHAR_TYPE* model_url);
+  void InitOnnxModelInfo(const std::filesystem::path& model_url);
 #endif
 
-  void InitOrtModelInfo(_In_ const PATH_CHAR_TYPE* model_url);
+  void InitOrtModelInfo(const std::filesystem::path& model_url);
 
  public:
-  OnnxModelInfo(_In_ const PATH_CHAR_TYPE* model_url, bool is_ort_model = false);
+  OnnxModelInfo(const std::filesystem::path& path, bool is_ort_model = false);
   bool HasDomain(const std::string& name) const {
     return domain_to_version_.find(name) != domain_to_version_.end();
   }
@@ -32,7 +32,7 @@ class OnnxModelInfo : public TestModelInfo {
     return iter == domain_to_version_.end() ? -1 : iter->second;
   }
 
-  const PATH_CHAR_TYPE* GetModelUrl() const override { return model_url_.c_str(); }
+  const std::filesystem::path& GetModelUrl() const override { return model_url_; }
   std::string GetNominalOpsetVersion() const override { return onnx_nominal_opset_vesion_; }
 
   const std::string& GetNodeName() const override { return node_name_; }
diff --git a/onnxruntime/test/onnx/testcase_request.cc b/onnxruntime/test/onnx/testcase_request.cc
index 9ca8273ac907b..9d653571ca2ec 100644
--- a/onnxruntime/test/onnx/testcase_request.cc
+++ b/onnxruntime/test/onnx/testcase_request.cc
@@ -36,7 +36,7 @@ bool TestCaseRequestContext::SetupSession() {
   ORT_TRY {
     const auto* test_case_name = test_case_.GetTestCaseName().c_str();
     session_opts_.SetLogId(test_case_name);
-    Ort::Session session{env_, test_case_.GetModelUrl(), session_opts_};
+    Ort::Session session{env_, test_case_.GetModelUrl().native().c_str(), session_opts_};
     session_ = std::move(session);
     LOGF_DEFAULT(INFO, "Testing %s\n", test_case_name);
     return true;
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index bf02c1741725f..97f1feaaa612d 100755
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -65,6 +65,7 @@
 #include "core/optimizer/relu_clip_fusion.h"
 #include "core/optimizer/reshape_fusion.h"
 #include "core/optimizer/rule_based_graph_transformer.h"
+#include "core/optimizer/shape_input_merge.h"
 #include "core/optimizer/slice_elimination.h"
 #include "core/optimizer/unsqueeze_elimination.h"
 #include "core/optimizer/utils.h"
@@ -4879,6 +4880,53 @@ TEST_F(GraphTransformationTests, FastGeluFusionWithCastsTest3) {
   ASSERT_TRUE(op_to_count["com.microsoft.FastGelu"] == 1);
 }
 
+TEST_F(GraphTransformationTests, CseWithConstantOfShape) {
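+  // Two identical Shape -> ConstantOfShape chains feed the Muls; CSE is expected to deduplicate
+  // them into a single chain (the post checker expects one Shape and one ConstantOfShape).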
+  auto build_test_case = [&](ModelTestBuilder& builder) {
+    std::vector<std::variant<int64_t, std::string>> input_shape;
+    input_shape.reserve(4);
+    input_shape.emplace_back("dim0");
+    input_shape.emplace_back(512);
+    input_shape.emplace_back(16);
+    input_shape.emplace_back("dim3");
+    auto* input_arg = builder.MakeSymbolicInput<float>(input_shape);
+    auto* shape_out_1 = builder.MakeIntermediate();
+    auto* shape_out_2 = builder.MakeIntermediate();
+    auto* constant_of_shape_out_1 = builder.MakeIntermediate();
+    auto* constant_of_shape_out_2 = builder.MakeIntermediate();
+    auto* mul_out_1 = builder.MakeIntermediate();
+    auto* mul_out_2 = builder.MakeOutput();
+    builder.AddNode("Shape", {input_arg}, {shape_out_1});
+    builder.AddNode("Shape", {input_arg}, {shape_out_2});
+    TensorProto value_tensor;
+    value_tensor.add_dims(1);
+    float value = 2.333f;
+    value_tensor.set_raw_data(reinterpret_cast<const char*>(&value), sizeof(float));
+    value_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+    builder.AddNode("ConstantOfShape", {shape_out_1}, {constant_of_shape_out_1}).AddAttribute("value", value_tensor);
+    builder.AddNode("ConstantOfShape", {shape_out_2}, {constant_of_shape_out_2}).AddAttribute("value", value_tensor);
+    builder.AddNode("Mul", {input_arg, constant_of_shape_out_1}, {mul_out_1});
+    builder.AddNode("Mul", {mul_out_1, constant_of_shape_out_2}, {mul_out_2});
+  };
+
+  auto pre_graph_checker = [&](Graph& graph) {
+    auto op_count_map = CountOpsInGraph(graph);
+    TEST_RETURN_IF_NOT(op_count_map["Shape"] == 2);
+    TEST_RETURN_IF_NOT(op_count_map["ConstantOfShape"] == 2);
+    return Status::OK();
+  };
+
+  auto post_graph_checker = [&](Graph& graph) {
+    auto op_count_map = CountOpsInGraph(graph);
+    TEST_RETURN_IF_NOT(op_count_map["Shape"] == 1);
+    TEST_RETURN_IF_NOT(op_count_map["ConstantOfShape"] == 1);
+    return Status::OK();
+  };
+
+  std::unique_ptr<GraphTransformer> transformer = std::make_unique<CommonSubexpressionElimination>();
+  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1,
+                                        1, pre_graph_checker, post_graph_checker));
+}
+
 TEST_F(GraphTransformationTests, QuickGelu) {
   // Sigmoid(x*alpha)*x, float
   {
@@ -5679,6 +5727,24 @@ TEST_F(GraphTransformationTests, MatMulIntegerToFloatTest) {
   EXPECT_EQ(op_to_count["Add"], 1);
 }
 
+#ifdef USE_DML
+TEST_F(GraphTransformationTests, MatMulIntegerToFloat16Test) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/matmul_integer_to_float16_int8.onnx";
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+
+  for (auto& node : graph.Nodes()) {
+    node.SetExecutionProviderType(kDmlExecutionProvider);
+  }
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<MatMulIntegerToFloatFusion>(), TransformerLevel::Level2));
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 1);
+}
+#endif  // USE_DML
+
 #endif
 
 #ifndef DISABLE_CONTRIB_OPS
@@ -7058,13 +7124,13 @@ TEST_F(GraphTransformationTests, ConstantSharing_ShouldNotShareForGraphOutput) {
   }
 }
 
-TEST_F(GraphTransformationTests, GatherToSplitFusion) {
+TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_AllGather) {
   auto build_test_case = [&](ModelTestBuilder& builder) {
     auto* data_arg = builder.MakeInput<float>({{54}});
     auto* shape_arg = builder.MakeInput<int64_t>({{4}});
     auto* reshape_out = builder.MakeIntermediate<float>({{2, 3, 3, 3}});
     auto* gather_index_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(0)});
-    auto* gather_index_2 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(1)});
+    auto* gather_index_2 = builder.MakeInitializer<int64_t>({1}, {static_cast<int64_t>(1)});
     auto* gather_index_3 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(2)});
     auto* gather_out_1 = builder.MakeIntermediate();
     auto* gather_out_2 = builder.MakeIntermediate();
@@ -7081,7 +7147,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) {
     builder.AddNode("Gather", {reshape_out, gather_index_3}, {gather_out_3})
         .AddAttribute("axis", static_cast<int64_t>(2));
     builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
-    builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
+    builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2})
+        .AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
     builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
   };
 
@@ -7090,27 +7157,16 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) {
     return Status::OK();
   };
 
-  // OpSet-12
+  // OpSet-12, not supported
   {
     auto post_graph_checker = [&](Graph& graph) {
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
-      for (auto& node : graph.Nodes()) {
-        if (node.OpType() == "Split") {
-          auto& attrs = node.GetAttributes();
-          TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
-          TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
-        } else if (node.OpType() == "Squeeze") {
-          auto& attrs = node.GetAttributes();
-          TEST_RETURN_IF_NOT(attrs.find("axes") != attrs.end());
-          TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axes").ints().at(0)));
-        }
-      }
+      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
+      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 0);
+      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0);
       return Status::OK();
     };
 
-    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
+    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
     ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer),
                                           TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
   }
@@ -7120,7 +7176,7 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) {
     auto post_graph_checker = [&](Graph& graph) {
       TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
       TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
+      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 2);
       for (auto& node : graph.Nodes()) {
         if (node.OpType() == "Split") {
           auto& attrs = node.GetAttributes();
@@ -7139,249 +7195,140 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) {
       return Status::OK();
     };
 
-    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
+    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
     ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
                                           TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
   }
-
-  // OpSet-18
-  {
-    auto post_graph_checker = [&](Graph& graph) {
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
-      for (auto& node : graph.Nodes()) {
-        if (node.OpType() == "Split") {
-          auto& attrs = node.GetAttributes();
-          TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
-          TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
-        } else if (node.OpType() == "Squeeze") {
-          const NodeArg& input_arg = *(node.InputDefs()[1]);
-          const ONNX_NAMESPACE::TensorProto* tensor_proto =
-              graph_utils::GetConstantInitializer(graph, input_arg.Name());
-          TEST_RETURN_IF_NOT(tensor_proto != nullptr);
-          Initializer init_const{*tensor_proto, graph.ModelPath()};
-          TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64);
-          TEST_RETURN_IF_NOT(2 == static_cast<int>(*(init_const.data<int64_t>())));
-        }
-      }
-      return Status::OK();
-    };
-
-    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
-    ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer),
-                                          TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
-  }
 }
 
-TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) {
+TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_AllSlice_GraphInput) {
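+  // Three Slices on axis 2 (size 8) cover [0,2), [2,6) and [6,8) with no gap or overlap,
+  // so they are expected to be fused into a single Split.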
   auto build_test_case = [&](ModelTestBuilder& builder) {
-    auto* data_arg = builder.MakeInput<float>({{54}});
-    auto* shape_arg = builder.MakeInput<int64_t>({{4}});
-    auto* reshape_out = builder.MakeIntermediate<float>({{2, 3, 3, 3}});
-    auto* gather_index_1 = builder.MakeInitializer<int64_t>({1}, {static_cast<int64_t>(0)});
-    auto* gather_index_2 = builder.MakeInitializer<int64_t>({1}, {static_cast<int64_t>(1)});
-    auto* gather_index_3 = builder.MakeInitializer<int64_t>({1}, {static_cast<int64_t>(2)});
-    auto* gather_out_1 = builder.MakeIntermediate();
-    auto* gather_out_2 = builder.MakeIntermediate();
-    auto* gather_out_3 = builder.MakeIntermediate();
+    auto* data_arg = builder.MakeInput<float>({{2, 3, 8, 3}});
+    auto* starts_1 = builder.MakeInitializer<int64_t>({1}, {0});
+    auto* ends_1 = builder.MakeInitializer<int64_t>({1}, {2});
+    auto* axes_1 = builder.MakeInitializer<int64_t>({1}, {2});
+    auto* steps_1 = builder.MakeInitializer<int64_t>({1}, {1});
+    auto* starts_2 = builder.MakeInitializer<int64_t>({1}, {2});
+    auto* ends_2 = builder.MakeInitializer<int64_t>({1}, {-2});
+    auto* axes_2 = builder.MakeInitializer<int64_t>({1}, {-2});
+    auto* steps_2 = builder.MakeInitializer<int64_t>({1}, {1});
+    auto* starts_3 = builder.MakeInitializer<int64_t>({1}, {-2});
+    auto* ends_3 = builder.MakeInitializer<int64_t>({1}, {16});
+    auto* axes_3 = builder.MakeInitializer<int64_t>({1}, {2});
+    auto* slice_out_1 = builder.MakeIntermediate();
+    auto* slice_out_2 = builder.MakeIntermediate();
+    auto* slice_out_3 = builder.MakeIntermediate();
     auto* transpose_out_1 = builder.MakeOutput();
     auto* transpose_out_2 = builder.MakeOutput();
     auto* transpose_out_3 = builder.MakeOutput();
 
-    builder.AddNode("Reshape", {data_arg, shape_arg}, {reshape_out});
-    builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1})
-        .AddAttribute("axis", static_cast<int64_t>(2));
-    builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2})
-        .AddAttribute("axis", static_cast<int64_t>(-2));
-    builder.AddNode("Gather", {reshape_out, gather_index_3}, {gather_out_3})
-        .AddAttribute("axis", static_cast<int64_t>(2));
-    builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
-    builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
-    builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
+    builder.AddNode("Slice", {data_arg, starts_1, ends_1, axes_1, steps_1}, {slice_out_1});
+    builder.AddNode("Slice", {data_arg, starts_2, ends_2, axes_2, steps_2}, {slice_out_2});
+    builder.AddNode("Slice", {data_arg, starts_3, ends_3, axes_3}, {slice_out_3});
+    builder.AddNode("Transpose", {slice_out_1}, {transpose_out_1})
+        .AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
+    builder.AddNode("Transpose", {slice_out_2}, {transpose_out_2})
+        .AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
+    builder.AddNode("Transpose", {slice_out_3}, {transpose_out_3})
+        .AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
   };
 
   auto pre_graph_checker = [&](Graph& graph) {
-    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 3);
     return Status::OK();
   };
 
-  // OpSet-12
-  {
-    auto post_graph_checker = [&](Graph& graph) {
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0);
-      for (auto& node : graph.Nodes()) {
-        if (node.OpType() == "Split") {
-          auto& attrs = node.GetAttributes();
-          TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
-          TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
-        }
-      }
-      return Status::OK();
-    };
-
-    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
-    ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer),
-                                          TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
-  }
-
-  // OpSet-14
-  {
-    auto post_graph_checker = [&](Graph& graph) {
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0);
-      for (auto& node : graph.Nodes()) {
-        if (node.OpType() == "Split") {
-          auto& attrs = node.GetAttributes();
-          TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
-          TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
-        }
-      }
-      return Status::OK();
-    };
-
-    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
-    ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
-                                          TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
-  }
-
-  // OpSet-18
-  {
-    auto post_graph_checker = [&](Graph& graph) {
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0);
-      for (auto& node : graph.Nodes()) {
-        if (node.OpType() == "Split") {
-          auto& attrs = node.GetAttributes();
-          TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
-          TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
-        }
+  auto post_graph_checker = [&](Graph& graph) {
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0);
+    for (auto& node : graph.Nodes()) {
+      if (node.OpType() == "Split") {
+        auto& attrs = node.GetAttributes();
+        TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
+        TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
       }
-      return Status::OK();
-    };
+    }
+    return Status::OK();
+  };
 
-    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
-    ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer),
-                                          TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
-  }
+  std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
+  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1,
+                                        1, pre_graph_checker, post_graph_checker));
 }
 
-TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Input) {
+TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Combined) {
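+  // Gather(index 5), Slice [6,8), Gather(index 4) and Slice [0,4) on axis 1 (size 8) together
+  // cover the whole axis exactly once, so they are expected to be fused into one Split, with a
+  // Squeeze inserted only for the scalar-index Gather.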
   auto build_test_case = [&](ModelTestBuilder& builder) {
-    auto* data_arg = builder.MakeInput<float>({{2, 3, 3, 3}});
-    auto* gather_index_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(0)});
-    auto* gather_index_2 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(1)});
-    auto* gather_index_3 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(2)});
+    auto* data_arg = builder.MakeInput<float>({{144}});
+    auto* shape_arg = builder.MakeInput<int64_t>({{4}});
+    auto* reshape_out = builder.MakeIntermediate<float>({{2, 8, 3, 3}});
+    auto* gather_index_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(5)});
+    auto* starts_2 = builder.MakeInitializer<int64_t>({1}, {6});
+    auto* ends_2 = builder.MakeInitializer<int64_t>({1}, {8});
+    auto* axes_2 = builder.MakeInitializer<int64_t>({1}, {-3});
+    auto* steps_2 = builder.MakeInitializer<int64_t>({1}, {1});
+    auto* gather_index_3 = builder.MakeInitializer<int64_t>({1}, {static_cast<int64_t>(4)});
+    auto* starts_4 = builder.MakeInitializer<int64_t>({1}, {-16});
+    auto* ends_4 = builder.MakeInitializer<int64_t>({1}, {4});
+    auto* axes_4 = builder.MakeInitializer<int64_t>({1}, {1});
     auto* gather_out_1 = builder.MakeIntermediate();
-    auto* gather_out_2 = builder.MakeIntermediate();
+    auto* slice_out_2 = builder.MakeIntermediate();
     auto* gather_out_3 = builder.MakeIntermediate();
+    auto* slice_out_4 = builder.MakeIntermediate();
     auto* transpose_out_1 = builder.MakeOutput();
     auto* transpose_out_2 = builder.MakeOutput();
     auto* transpose_out_3 = builder.MakeOutput();
+    auto* transpose_out_4 = builder.MakeOutput();
 
-    builder.AddNode("Gather", {data_arg, gather_index_1}, {gather_out_1}).AddAttribute("axis", static_cast<int64_t>(2));
-    builder.AddNode("Gather", {data_arg, gather_index_2}, {gather_out_2})
-        .AddAttribute("axis", static_cast<int64_t>(-2));
-    builder.AddNode("Gather", {data_arg, gather_index_3}, {gather_out_3}).AddAttribute("axis", static_cast<int64_t>(2));
+    builder.AddNode("Reshape", {data_arg, shape_arg}, {reshape_out});
+    builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1})
+        .AddAttribute("axis", static_cast<int64_t>(1));
+    builder.AddNode("Slice", {reshape_out, starts_2, ends_2, axes_2, steps_2}, {slice_out_2});
+    builder.AddNode("Gather", {reshape_out, gather_index_3}, {gather_out_3})
+        .AddAttribute("axis", static_cast<int64_t>(-3));
+    builder.AddNode("Slice", {reshape_out, starts_4, ends_4, axes_4}, {slice_out_4});
     builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
-    builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
-    builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
+    builder.AddNode("Transpose", {slice_out_2}, {transpose_out_2})
+        .AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
+    builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3})
+        .AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
+    builder.AddNode("Transpose", {slice_out_4}, {transpose_out_4})
+        .AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
   };
 
   auto pre_graph_checker = [&](Graph& graph) {
-    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 2);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 2);
     return Status::OK();
   };
 
-  // OpSet-12
-  {
-    auto post_graph_checker = [&](Graph& graph) {
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
-      for (auto& node : graph.Nodes()) {
-        if (node.OpType() == "Split") {
-          auto& attrs = node.GetAttributes();
-          TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
-          TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
-        } else if (node.OpType() == "Squeeze") {
-          auto& attrs = node.GetAttributes();
-          TEST_RETURN_IF_NOT(attrs.find("axes") != attrs.end());
-          TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axes").ints().at(0)));
-        }
-      }
-      return Status::OK();
-    };
-
-    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
-    ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer),
-                                          TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
-  }
-
-  // OpSet-14
-  {
-    auto post_graph_checker = [&](Graph& graph) {
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
-      for (auto& node : graph.Nodes()) {
-        if (node.OpType() == "Split") {
-          auto& attrs = node.GetAttributes();
-          TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
-          TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
-        } else if (node.OpType() == "Squeeze") {
-          const NodeArg& input_arg = *(node.InputDefs()[1]);
-          const ONNX_NAMESPACE::TensorProto* tensor_proto =
-              graph_utils::GetConstantInitializer(graph, input_arg.Name());
-          TEST_RETURN_IF_NOT(tensor_proto != nullptr);
-          Initializer init_const{*tensor_proto, graph.ModelPath()};
-          TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64);
-          TEST_RETURN_IF_NOT(2 == static_cast<int>(*(init_const.data<int64_t>())));
-        }
-      }
-      return Status::OK();
-    };
-
-    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
-    ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
-                                          TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
-  }
-
-  // OpSet-18
-  {
-    auto post_graph_checker = [&](Graph& graph) {
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
-      TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
-      for (auto& node : graph.Nodes()) {
-        if (node.OpType() == "Split") {
-          auto& attrs = node.GetAttributes();
-          TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
-          TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
-        } else if (node.OpType() == "Squeeze") {
-          const NodeArg& input_arg = *(node.InputDefs()[1]);
-          const ONNX_NAMESPACE::TensorProto* tensor_proto =
-              graph_utils::GetConstantInitializer(graph, input_arg.Name());
-          TEST_RETURN_IF_NOT(tensor_proto != nullptr);
-          Initializer init_const{*tensor_proto, graph.ModelPath()};
-          TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64);
-          TEST_RETURN_IF_NOT(2 == static_cast<int>(*(init_const.data<int64_t>())));
-        }
+  auto post_graph_checker = [&](Graph& graph) {
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 1);
+    for (auto& node : graph.Nodes()) {
+      if (node.OpType() == "Split") {
+        auto& attrs = node.GetAttributes();
+        TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
+        TEST_RETURN_IF_NOT(1 == static_cast<int>(attrs.at("axis").i()));
+      } else if (node.OpType() == "Squeeze") {
+        const NodeArg& input_arg = *(node.InputDefs()[1]);
+        const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name());
+        TEST_RETURN_IF_NOT(tensor_proto != nullptr);
+        Initializer init_const{*tensor_proto, graph.ModelPath()};
+        TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64);
+        TEST_RETURN_IF_NOT(1 == static_cast<int>(*(init_const.data<int64_t>())));
       }
-      return Status::OK();
-    };
+    }
+    return Status::OK();
+  };
 
-    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
-    ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer),
-                                          TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
-  }
+  std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
+  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1,
+                                        1, pre_graph_checker, post_graph_checker));
 }
 
-TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Initializer) {
+TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Consume_Initializer) {
   auto build_test_case = [&](ModelTestBuilder& builder) {
     auto* data_arg = builder.MakeInitializer<float>({2, 3, 3, 3}, std::vector<float>(54));
     auto* gather_index_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(0)});
@@ -7429,31 +7376,31 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Initializer) {
     return Status::OK();
   };
 
-  std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
+  std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
   ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1,
                                         1, pre_graph_checker, post_graph_checker));
 }
 
-TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) {
+TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Invalid) {
   auto pre_graph_checker = [&](Graph& graph) {
-    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] > 0 || CountOpsInGraph(graph)["Slice"] > 0);
     return Status::OK();
   };
   auto post_graph_checker = [&](Graph& graph) {
-    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] > 0 || CountOpsInGraph(graph)["Slice"] > 0);
     TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 0);
     TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0);
     return Status::OK();
   };
 
-  // Invalid shape.
+  // The Gathers do not cover all elements of the target dimension, so no fusion should happen.
   {
     auto build_test_case = [&](ModelTestBuilder& builder) {
       auto* data_arg = builder.MakeInput<float>({{72}});
-      auto* shape_arg = builder.MakeInput<int64_t>({{1}});
+      auto* shape_arg = builder.MakeInput<int64_t>({{4}});
       auto* reshape_out = builder.MakeIntermediate<float>({{2, 3, 4, 3}});
       auto* gather_index_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(0)});
-      auto* gather_index_2 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(1)});
+      auto* gather_index_2 = builder.MakeInitializer<int64_t>({1}, {static_cast<int64_t>(1)});
       auto* gather_index_3 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(2)});
       auto* gather_out_1 = builder.MakeIntermediate();
       auto* gather_out_2 = builder.MakeIntermediate();
@@ -7466,63 +7413,65 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) {
       builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1})
           .AddAttribute("axis", static_cast<int64_t>(2));
       builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2})
-          .AddAttribute("axis", static_cast<int64_t>(2));
+          .AddAttribute("axis", static_cast<int64_t>(-2));
       builder.AddNode("Gather", {reshape_out, gather_index_3}, {gather_out_3})
           .AddAttribute("axis", static_cast<int64_t>(2));
       builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1})
           .AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
       builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2})
-          .AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
+          .AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
       builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3})
           .AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
     };
 
-    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
-    ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer),
+    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
+    ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
                                           TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
   }
 
-  // Invalid Gather indices.
+  // The Slice ranges overlap on the target axis, so no fusion should happen.
   {
     auto build_test_case = [&](ModelTestBuilder& builder) {
-      auto* data_arg = builder.MakeInput<float>({{54}});
-      auto* shape_arg = builder.MakeInput<int64_t>({{1}});
-      auto* reshape_out = builder.MakeIntermediate<float>({{2, 3, 3, 3}});
-      auto* gather_index_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(0)});
-      auto* gather_index_2 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(1)});
-      auto* gather_index_3 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(1)});
-      auto* gather_out_1 = builder.MakeIntermediate();
-      auto* gather_out_2 = builder.MakeIntermediate();
-      auto* gather_out_3 = builder.MakeIntermediate();
+      auto* data_arg = builder.MakeInput<float>({{2, 3, 8, 3}});
+      auto* starts_1 = builder.MakeInitializer<int64_t>({1}, {0});
+      auto* ends_1 = builder.MakeInitializer<int64_t>({1}, {3});
+      auto* axes_1 = builder.MakeInitializer<int64_t>({1}, {2});
+      auto* steps_1 = builder.MakeInitializer<int64_t>({1}, {1});
+      auto* starts_2 = builder.MakeInitializer<int64_t>({1}, {2});
+      auto* ends_2 = builder.MakeInitializer<int64_t>({1}, {-2});
+      auto* axes_2 = builder.MakeInitializer<int64_t>({1}, {-2});
+      auto* steps_2 = builder.MakeInitializer<int64_t>({1}, {1});
+      auto* starts_3 = builder.MakeInitializer<int64_t>({1}, {-2});
+      auto* ends_3 = builder.MakeInitializer<int64_t>({1}, {16});
+      auto* axes_3 = builder.MakeInitializer<int64_t>({1}, {2});
+      auto* slice_out_1 = builder.MakeIntermediate();
+      auto* slice_out_2 = builder.MakeIntermediate();
+      auto* slice_out_3 = builder.MakeIntermediate();
       auto* transpose_out_1 = builder.MakeOutput();
       auto* transpose_out_2 = builder.MakeOutput();
       auto* transpose_out_3 = builder.MakeOutput();
 
-      builder.AddNode("Reshape", {data_arg, shape_arg}, {reshape_out});
-      builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1})
-          .AddAttribute("axis", static_cast<int64_t>(2));
-      builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2})
-          .AddAttribute("axis", static_cast<int64_t>(2));
-      builder.AddNode("Gather", {reshape_out, gather_index_3}, {gather_out_3})
-          .AddAttribute("axis", static_cast<int64_t>(2));
-      builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1})
-          .AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
-      builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2})
-          .AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
-      builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3})
-          .AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
+      builder.AddNode("Slice", {data_arg, starts_1, ends_1, axes_1, steps_1}, {slice_out_1});
+      builder.AddNode("Slice", {data_arg, starts_2, ends_2, axes_2, steps_2}, {slice_out_2});
+      builder.AddNode("Slice", {data_arg, starts_3, ends_3, axes_3}, {slice_out_3});
+      builder.AddNode("Transpose", {slice_out_1}, {transpose_out_1})
+          .AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
+      builder.AddNode("Transpose", {slice_out_2}, {transpose_out_2})
+          .AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
+      builder.AddNode("Transpose", {slice_out_3}, {transpose_out_3})
+          .AddAttribute("perm", std::vector<int64_t>{0, 2, 1, 3});
     };
 
-    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
-    ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
+    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
+    ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer),
                                           TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
   }
 
-  // Invalid Gather axis.
+  // Invalid axis.
   {
     auto build_test_case = [&](ModelTestBuilder& builder) {
       auto* data_arg = builder.MakeInput<float>({{54}});
-      auto* shape_arg = builder.MakeInput<int64_t>({{1}});
+      auto* shape_arg = builder.MakeInput<int64_t>({{4}});
       auto* reshape_out = builder.MakeIntermediate<float>({{2, 3, 3, 3}});
       auto* gather_index_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(0)});
       auto* gather_index_2 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(1)});
@@ -7549,7 +7498,7 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) {
           .AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
     };
 
-    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
+    std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
     ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
                                           TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
   }
@@ -7642,5 +7591,79 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) {
   }
 }
 
+TEST_F(GraphTransformationTests, ShapeInputMerge) {
+  auto build_test_case = [&](ModelTestBuilder& builder) {
+    std::vector<std::variant<int64_t, std::string>> input_shape;
+    input_shape.reserve(5);
+    input_shape.emplace_back("dim0");
+    input_shape.emplace_back(512);
+    input_shape.emplace_back(1);
+    input_shape.emplace_back(1536);
+    input_shape.emplace_back("dim4");
+    auto* input_arg = builder.MakeSymbolicInput<float>(input_shape);
+    auto* neg_out = builder.MakeIntermediate();
+    auto* axes_initializer = builder.MakeInitializer<int64_t>({1}, {static_cast<int64_t>(2)});
+    auto* squeeze_out = builder.MakeIntermediate();
+    auto* cast_out = builder.MakeIntermediate();
+    auto* unsqueeze_out = builder.MakeOutput();
+    auto* shape_1_out = builder.MakeOutput();
+    auto* shape_2_out = builder.MakeOutput();
+    auto* shape_3_out = builder.MakeOutput();
+    auto* shape_4_out = builder.MakeOutput();
+    auto* shape_5_out = builder.MakeOutput();
+    builder.AddNode("Neg", {input_arg}, {neg_out});
+    builder.AddNode("Squeeze", {neg_out, axes_initializer}, {squeeze_out});
+    builder.AddNode("Cast", {squeeze_out}, {cast_out}).AddAttribute("to", static_cast<int64_t>(10));
+    builder.AddNode("Unsqueeze", {cast_out, axes_initializer}, {unsqueeze_out});
+    builder.AddNode("Shape", {input_arg}, {shape_1_out});
+    builder.AddNode("Shape", {neg_out}, {shape_2_out});
+    builder.AddNode("Shape", {squeeze_out}, {shape_3_out});
+    builder.AddNode("Shape", {cast_out}, {shape_4_out});
+    builder.AddNode("Shape", {unsqueeze_out}, {shape_5_out});
+  };
+
+  auto pre_graph_checker = [&](Graph& graph) {
+    InlinedHashMap<std::string, int> ref_count;
+    for (auto& node : graph.Nodes()) {
+      if (node.OpType() == "Shape") {
+        std::string name = node.InputDefs()[0]->Name();
+        if (ref_count.find(name) == ref_count.end()) {
+          ref_count[name] = 1;
+        } else {
+          ref_count[name]++;
+        }
+      }
+    }
+    TEST_RETURN_IF_NOT(ref_count.size() == 5);
+    return Status::OK();
+  };
+
+  auto post_graph_checker = [&](Graph& graph) {
+    InlinedHashMap<std::string, int> ref_count;
+    for (auto& node : graph.Nodes()) {
+      if (node.OpType() == "Shape") {
+        std::string name = node.InputDefs()[0]->Name();
+        if (ref_count.find(name) == ref_count.end()) {
+          ref_count[name] = 1;
+        } else {
+          ref_count[name]++;
+        }
+      }
+    }
+    TEST_RETURN_IF_NOT(ref_count.size() == 2);
+    int sum = 0, mul = 1;
+    for (auto& entry : ref_count) {
+      sum += entry.second;
+      mul *= entry.second;
+    }
+    TEST_RETURN_IF_NOT(sum == 5 && mul == 6);
+    return Status::OK();
+  };
+
+  std::unique_ptr<GraphTransformer> transformer = std::make_unique<ShapeInputMerge>();
+  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1,
+                                        1, pre_graph_checker, post_graph_checker));
+}
+
 }  // namespace test
 }  // namespace onnxruntime
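For the ShapeInputMerge test above, the `sum == 5 && mul == 6` assertion in post_graph_checker is a compact way of requiring that the five Shape nodes reference exactly two distinct inputs with reference counts {2, 3}: input_arg, neg_out and unsqueeze_out share the symbolic shape {dim0, 512, 1, 1536, dim4}, while squeeze_out and cast_out share {dim0, 512, 1536, dim4}. A minimal standalone sketch of the equivalent explicit check (hypothetical counts, not part of the patch):

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  // Hypothetical per-input reference counts gathered the same way as in post_graph_checker:
  // three Shape nodes read the 5-D shape, two read the squeezed 4-D shape.
  std::vector<int> counts = {3, 2};
  std::sort(counts.begin(), counts.end());
  // Equivalent to "sum == 5 && mul == 6": {2, 3} is the only multiset of positive
  // integers whose sum is 5 and whose product is 6.
  assert(counts == (std::vector<int>{2, 3}));
  return 0;
}
```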
diff --git a/onnxruntime/test/optimizer/nhwc_transformer_test.cc b/onnxruntime/test/optimizer/nhwc_transformer_test.cc
index c254d340cdcb8..e6f0a259805e5 100644
--- a/onnxruntime/test/optimizer/nhwc_transformer_test.cc
+++ b/onnxruntime/test/optimizer/nhwc_transformer_test.cc
@@ -518,7 +518,7 @@ TEST(NhwcTransformerTests, ConvMixTensorRanks) {
 
 #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
 
-std::vector<MLFloat16> randomfp16(const std::vector<int64_t>& shape, MLFloat16 min, MLFloat16 max) {
+static std::vector<MLFloat16> ARangeOfFP16Values(const std::vector<int64_t>& shape, MLFloat16 min, MLFloat16 max) {
   std::vector<MLFloat16> val(detail::SizeFromDims(shape));
   float start = min.ToFloat();
   float end = max.ToFloat();
@@ -534,22 +534,22 @@ std::vector<MLFloat16> randomfp16(const std::vector<int64_t>& shape, MLFloat16 m
   return val;
 }
 
-template <>
-NodeArg* ModelTestBuilder::MakeInput<MLFloat16>(const std::vector<int64_t>& shape, MLFloat16 min, MLFloat16 max) {
-  return MakeInput<MLFloat16>(shape, randomfp16(shape, min, max));
+static NodeArg* MakeInputARangeFP16(ModelTestBuilder& builder, const std::vector<int64_t>& shape,
+                                    MLFloat16 min, MLFloat16 max) {
+  return builder.MakeInput<MLFloat16>(shape, ARangeOfFP16Values(shape, min, max));
 }
 
-template <>
-NodeArg* ModelTestBuilder::MakeInitializer(const std::vector<int64_t>& shape, MLFloat16 min, MLFloat16 max) {
-  return MakeInitializer(shape, randomfp16(shape, min, max));
+static NodeArg* MakeInitializerARangeFP16(ModelTestBuilder& builder, const std::vector<int64_t>& shape,
+                                          MLFloat16 min, MLFloat16 max) {
+  return builder.MakeInitializer<MLFloat16>(shape, ARangeOfFP16Values(shape, min, max));
 }
 
 TEST(NhwcTransformerTests, ConvFp16) {
   auto test_case = [&](const std::vector<int64_t>& input_shape, const std::vector<int64_t>& weights_shape) {
     auto build_test_case = [&](ModelTestBuilder& builder) {
-      auto* input_arg = builder.MakeInput<MLFloat16>(input_shape, MLFloat16(-1.5f), MLFloat16(1.5f));
+      auto* input_arg = MakeInputARangeFP16(builder, input_shape, MLFloat16(-1.5f), MLFloat16(1.5f));
       auto* output_arg = builder.MakeOutput();
-      auto* weight_arg = builder.MakeInitializer(weights_shape, MLFloat16(-1.5f), MLFloat16(1.5f));
+      auto* weight_arg = MakeInitializerARangeFP16(builder, weights_shape, MLFloat16(-1.5f), MLFloat16(1.5f));
 
       builder.AddConvNode(input_arg, weight_arg, output_arg);
     };
@@ -575,10 +575,10 @@ TEST(NhwcTransformerTests, ConvFp16) {
 TEST(NhwcTransformerTests, ConvMaxPoolFp16) {
   auto test_case = [&](const std::vector<int64_t>& input_shape, const std::vector<int64_t>& weights_shape) {
     auto build_test_case = [&](ModelTestBuilder& builder) {
-      auto* input_arg = builder.MakeInput<MLFloat16>(input_shape, MLFloat16(-1.5f), MLFloat16(1.5f));
+      auto* input_arg = MakeInputARangeFP16(builder, input_shape, MLFloat16(-1.5f), MLFloat16(1.5f));
       auto* conv_output_arg = builder.MakeIntermediate();
       auto* output_arg = builder.MakeOutput();
-      auto* conv_weight_arg = builder.MakeInitializer(weights_shape, MLFloat16(-1.5f), MLFloat16(1.5f));
+      auto* conv_weight_arg = MakeInitializerARangeFP16(builder, weights_shape, MLFloat16(-1.5f), MLFloat16(1.5f));
 
       builder.AddConvNode(input_arg, conv_weight_arg, conv_output_arg);
       Node& pool_node = builder.AddNode("MaxPool", {conv_output_arg}, {output_arg});
@@ -609,13 +609,13 @@ TEST(NhwcTransformerTests, ConvMaxPoolFp16) {
 
 TEST(NhwcTransformerTests, ConvGlobalAveragePoolFp16) {
   auto build_test_case = [&](ModelTestBuilder& builder) {
-    auto* input_arg = builder.MakeInput<MLFloat16>({1, 23, 13, 13}, MLFloat16(-1.5f), MLFloat16(1.5f));
+    auto* input_arg = MakeInputARangeFP16(builder, {1, 23, 13, 13}, MLFloat16(-1.5f), MLFloat16(1.5f));
     auto* conv1_output_arg = builder.MakeIntermediate();
     auto* conv2_output_arg = builder.MakeIntermediate();
     auto* gavgpool1_output_arg = builder.MakeIntermediate();
     auto* output_arg = builder.MakeOutput();
-    auto* conv1_weight_arg = builder.MakeInitializer<MLFloat16>({30, 23, 3, 3}, MLFloat16(-1.5f), MLFloat16(1.5f));
-    auto* conv2_weight_arg = builder.MakeInitializer<MLFloat16>({16, 30, 1, 1}, MLFloat16(-1.5f), MLFloat16(1.5f));
+    auto* conv1_weight_arg = MakeInitializerARangeFP16(builder, {30, 23, 3, 3}, MLFloat16(-1.5f), MLFloat16(1.5f));
+    auto* conv2_weight_arg = MakeInitializerARangeFP16(builder, {16, 30, 1, 1}, MLFloat16(-1.5f), MLFloat16(1.5f));
 
     Node& conv1_node = builder.AddConvNode(input_arg, conv1_weight_arg, conv1_output_arg);
     conv1_node.AddAttribute("pads", std::vector<int64_t>{1, 1, 1, 1});
@@ -640,13 +640,13 @@ TEST(NhwcTransformerTests, ConvGlobalAveragePoolFp16) {
 
 TEST(NhwcTransformerTests, ConvAveragePoolFp16) {
   auto build_test_case = [&](ModelTestBuilder& builder) {
-    auto* input_arg = builder.MakeInput<MLFloat16>({1, 23, 13, 13}, MLFloat16(-1.5f), MLFloat16(1.5f));
+    auto* input_arg = MakeInputARangeFP16(builder, {1, 23, 13, 13}, MLFloat16(-1.5f), MLFloat16(1.5f));
     auto* conv1_output_arg = builder.MakeIntermediate();
     auto* conv2_output_arg = builder.MakeIntermediate();
     auto* avgpool1_output_arg = builder.MakeIntermediate();
     auto* output_arg = builder.MakeOutput();
-    auto* conv1_weight_arg = builder.MakeInitializer<MLFloat16>({30, 23, 3, 3}, MLFloat16(-1.5f), MLFloat16(1.5f));
-    auto* conv2_weight_arg = builder.MakeInitializer<MLFloat16>({16, 30, 3, 3}, MLFloat16(-1.5f), MLFloat16(1.5f));
+    auto* conv1_weight_arg = MakeInitializerARangeFP16(builder, {30, 23, 3, 3}, MLFloat16(-1.5f), MLFloat16(1.5f));
+    auto* conv2_weight_arg = MakeInitializerARangeFP16(builder, {16, 30, 3, 3}, MLFloat16(-1.5f), MLFloat16(1.5f));
 
     Node& conv1_node = builder.AddConvNode(input_arg, conv1_weight_arg, conv1_output_arg);
     conv1_node.AddAttribute("pads", std::vector<int64_t>{1, 1, 1, 1});
diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc
index 13333f1558cc6..fbd5c9b5a137b 100644
--- a/onnxruntime/test/optimizer/qdq_transformer_test.cc
+++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
 
 #include "core/framework/compute_capability.h"
+#include "core/framework/node_unit.h"
 #include "core/graph/model.h"
 #include "core/graph/onnx_protobuf.h"
 #include "core/mlas/inc/mlas.h"
@@ -9,7 +10,6 @@
 #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
 #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
 #include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
-#include "core/optimizer/utils.h"
 #include "core/providers/partitioning_utils.h"
 #include "core/session/onnxruntime_session_options_config_keys.h"
 #include "core/session/environment.h"
@@ -30,10 +30,6 @@
 #pragma warning(disable : 4127)
 #endif  // #if defined(_MSC_VER)
 
-#ifdef USE_NNAPI
-#include "core/providers/shared/node_unit/node_unit.h"
-#endif  // #ifdef USE_NNAPI
-
 struct QDQOpKeys {
   const char* quantize_linear;
   const char* dequantize_linear;
@@ -3243,14 +3239,14 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) {
     ASSERT_EQ(std::vector<NodeIndex>({4}), qdq_group.q_nodes);
   }
 
-// The function GetAllNodeUnits is enabled for NNAPI EP only for now
-#ifdef USE_NNAPI
+// The function GetAllNodeUnits is used by the NNAPI, XNNPACK and QNN EPs
+#if defined(USE_NNAPI) || defined(USE_QNN) || defined(USE_XNNPACK)
   {
     // Get all the NodeUnits in the graph_viewer
     std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
     std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
 
-    std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(whole_graph_viewer);
+    std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(whole_graph_viewer);
 
     // We should get a single QDQ Node unit in the result
     ASSERT_EQ(1, node_unit_holder.size());
@@ -3288,7 +3284,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) {
     verify_io_def(qdq_node_unit.Inputs()[2], *whole_graph_viewer.GetNode(2));   // DQ_bias
     verify_io_def(qdq_node_unit.Outputs()[0], *whole_graph_viewer.GetNode(4));  // Q_output
   }
-#endif  // #ifdef USE_NNAPI
+#endif  // defined(USE_NNAPI) || defined(USE_QNN) || defined(USE_XNNPACK)
 
   // Create a graph viewer covers part of the graph
   // Make sure the qdq conv selector will fail for the partial graph
diff --git a/onnxruntime/test/perftest/TFModelInfo.cc b/onnxruntime/test/perftest/TFModelInfo.cc
deleted file mode 100644
index 82f5359545b4d..0000000000000
--- a/onnxruntime/test/perftest/TFModelInfo.cc
+++ /dev/null
@@ -1,58 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "TFModelInfo.h"
-
-#include <memory>
-
-#include <core/platform/env.h>
-
-std::unique_ptr<TestModelInfo> TFModelInfo::Create(_In_ const PATH_CHAR_TYPE* model_url) {
-  std::unique_ptr<TFModelInfo> model_info = std::make_unique<TFModelInfo>();
-
-  model_info->model_url_ = model_url;
-  std::basic_string<PATH_CHAR_TYPE> meta_file_path = model_url;
-  meta_file_path.append(ORT_TSTR(".meta"));
-  const onnxruntime::Env& env = onnxruntime::Env::Default();
-  size_t len;
-  auto status = env.GetFileLength(meta_file_path.c_str(), len);
-  if (!status.IsOK()) {
-    ORT_THROW(status.ErrorMessage());
-  }
-  std::string file_content;
-  file_content.resize(len);
-  auto buffer_span = gsl::make_span(&file_content[0], file_content.size());
-  status = onnxruntime::Env::Default().ReadFileIntoBuffer(meta_file_path.c_str(), 0, len, buffer_span);
-  if (!status.IsOK()) {
-    ORT_THROW(status.ErrorMessage());
-  }
-  // this string is not null terminated
-  std::istringstream is{file_content};
-
-  std::string line;
-  while (std::getline(is, line)) {
-    size_t line_len = 0;
-    if (!line.empty() && line.back() == '\n') {
-      line_len = line.length() - 1;
-      if (line_len > 0 && line[line_len - 1] == '\r') {
-        --line_len;
-      }
-      line.resize(line_len);
-    }
-    if (line.empty()) continue;
-    if (line.compare(0, 6, "input=") == 0) {
-      model_info->input_names_.push_back(line.substr(6));
-    } else if (line.compare(0, 7, "output=") == 0) {
-      model_info->output_names_.push_back(line.substr(7));
-    } else {
-      ORT_THROW("unknown line:", line.size());
-    }
-  }
-
-  return model_info;
-}
-
-int TFModelInfo::GetInputCount() const { return static_cast<int>(input_names_.size()); }
-int TFModelInfo::GetOutputCount() const { return static_cast<int>(output_names_.size()); }
-const std::string& TFModelInfo::GetInputName(size_t i) const { return input_names_[i]; }
-const std::string& TFModelInfo::GetOutputName(size_t i) const { return output_names_[i]; }
diff --git a/onnxruntime/test/perftest/TFModelInfo.h b/onnxruntime/test/perftest/TFModelInfo.h
deleted file mode 100644
index 2ca60010e300b..0000000000000
--- a/onnxruntime/test/perftest/TFModelInfo.h
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#pragma once
-
-#include "TestCase.h"
-#include <string>
-#include <vector>
-
-class TFModelInfo : public TestModelInfo {
- public:
-  const PATH_CHAR_TYPE* GetModelUrl() const override { return model_url_.c_str(); }
-
-  const std::string& GetNodeName() const override { return node_name_; }
-  const ONNX_NAMESPACE::ValueInfoProto* GetInputInfoFromModel(size_t) const override { return nullptr; }
-  const ONNX_NAMESPACE::ValueInfoProto* GetOutputInfoFromModel(size_t) const override { return nullptr; }
-
-  int GetInputCount() const override;
-  int GetOutputCount() const override;
-  const std::string& GetInputName(size_t i) const override;
-  const std::string& GetOutputName(size_t i) const override;
-  ~TFModelInfo() override = default;
-
-  static std::unique_ptr<TestModelInfo> Create(_In_ const PATH_CHAR_TYPE* model_url);
-  TFModelInfo() = default;
-
- private:
-  std::basic_string<PATH_CHAR_TYPE> model_url_;
-  std::vector<std::string> input_names_;
-  std::vector<std::string> output_names_;
-  std::string node_name_;
-};
diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc
index 7cfbe0a84e3e6..93e44fd8e8d2d 100644
--- a/onnxruntime/test/perftest/command_args_parser.cc
+++ b/onnxruntime/test/perftest/command_args_parser.cc
@@ -64,16 +64,22 @@ namespace perftest {
       "\t    Refer to onnxruntime_session_options_config_keys.h for valid keys and values. \n"
       "\t    [Example] -C \"session.disable_cpu_ep_fallback|1 ep.context_enable|1\" \n"
       "\t-i: Specify EP specific runtime options as key value pairs. Different runtime options available are: \n"
+      "\t    [Usage]: -e <provider_name> -i '<key1>|<value1> <key2>|<value2>'\n"
+      "\n"
       "\t    [DML only] [performance_preference]: DML device performance preference, options: 'default', 'minimum_power', 'high_performance', \n"
       "\t    [DML only] [device_filter]: DML device filter, options: 'any', 'gpu', 'npu', \n"
       "\t    [DML only] [disable_metacommands]: Options: 'true', 'false', \n"
       "\t    [DML only] [enable_dynamic_graph_fusion]: Options: 'true', 'false', \n"
+      "\t    [DML only] [enable_graph_serialization]: Options: 'true', 'false', \n"
+      "\n"
       "\t    [OpenVINO only] [device_type]: Overrides the accelerator hardware type and precision with these values at runtime.\n"
       "\t    [OpenVINO only] [device_id]: Selects a particular hardware device for inference.\n"
       "\t    [OpenVINO only] [enable_npu_fast_compile]: Optionally enabled to speeds up the model's compilation on NPU device targets.\n"
       "\t    [OpenVINO only] [num_of_threads]: Overrides the accelerator hardware type and precision with these values at runtime.\n"
       "\t    [OpenVINO only] [cache_dir]: Explicitly specify the path to dump and load the blobs(Model caching) or cl_cache (Kernel Caching) files feature. If blob files are already present, it will be directly loaded.\n"
       "\t    [OpenVINO only] [enable_opencl_throttling]: Enables OpenCL queue throttling for GPU device(Reduces the CPU Utilization while using GPU) \n"
+      "\t    [Example] [For OpenVINO EP] -e openvino -i \"device_type|CPU_FP32 enable_npu_fast_compile|true num_of_threads|5 enable_opencl_throttling|true cache_dir|\"<path>\"\"\n"
+      "\n"
       "\t    [QNN only] [backend_path]: QNN backend path. e.g '/folderpath/libQnnHtp.so', '/folderpath/libQnnCpu.so'.\n"
       "\t    [QNN only] [profiling_level]: QNN profiling level, options: 'basic', 'detailed', default 'off'.\n"
       "\t    [QNN only] [rpc_control_latency]: QNN rpc control latency. default to 10.\n"
@@ -88,9 +94,10 @@ namespace perftest {
       "\t    [QNN only] [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. \n"
       "\t    Options are '0', '68', '69', '73', '75'. Defaults to '0' (none). \n"
       "\t    [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n"
-      "\t [Usage]: -e <provider_name> -i '<key1>|<value1> <key2>|<value2>'\n\n"
-      "\t [Example] [For OpenVINO EP] -e openvino -i \"device_type|CPU_FP32 enable_npu_fast_compile|true num_of_threads|5 enable_opencl_throttling|true cache_dir|\"<path>\"\"\n"
-      "\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n\n"
+      "\t    [QNN only] [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n"
+      "\t    Otherwise, it will be fp32 precision. Only works for float32 model. Defaults to '0' (with FP32 precision.). \n"
+      "\t    [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n"
+      "\n"
       "\t    [TensorRT only] [trt_max_partition_iterations]: Maximum iterations for TensorRT parser to get capability.\n"
       "\t    [TensorRT only] [trt_min_subgraph_size]: Minimum size of TensorRT subgraphs.\n"
       "\t    [TensorRT only] [trt_max_workspace_size]: Set TensorRT maximum workspace size in byte.\n"
@@ -107,20 +114,23 @@ namespace perftest {
       "\t    [TensorRT only] [trt_force_sequential_engine_build]: Force TensorRT engines to be built sequentially.\n"
       "\t    [TensorRT only] [trt_context_memory_sharing_enable]: Enable TensorRT context memory sharing between subgraphs.\n"
       "\t    [TensorRT only] [trt_layer_norm_fp32_fallback]: Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow.\n"
-      "\t [Usage]: -e <provider_name> -i '<key1>|<value1> <key2>|<value2>'\n\n"
-      "\t [Example] [For TensorRT EP] -e tensorrt -i 'trt_fp16_enable|true trt_int8_enable|true trt_int8_calibration_table_name|calibration.flatbuffers trt_int8_use_native_calibration_table|false trt_force_sequential_engine_build|false'\n"
+      "\t    [Example] [For TensorRT EP] -e tensorrt -i 'trt_fp16_enable|true trt_int8_enable|true trt_int8_calibration_table_name|calibration.flatbuffers trt_int8_use_native_calibration_table|false trt_force_sequential_engine_build|false'\n"
+      "\n"
       "\t    [NNAPI only] [NNAPI_FLAG_USE_FP16]: Use fp16 relaxation in NNAPI EP..\n"
       "\t    [NNAPI only] [NNAPI_FLAG_USE_NCHW]: Use the NCHW layout in NNAPI EP.\n"
       "\t    [NNAPI only] [NNAPI_FLAG_CPU_DISABLED]: Prevent NNAPI from using CPU devices.\n"
       "\t    [NNAPI only] [NNAPI_FLAG_CPU_ONLY]: Using CPU only in NNAPI EP.\n"
-      "\t [Usage]: -e <provider_name> -i '<key1> <key2>'\n\n"
-      "\t [Example] [For NNAPI EP] -e nnapi -i \" NNAPI_FLAG_USE_FP16 NNAPI_FLAG_USE_NCHW NNAPI_FLAG_CPU_DISABLED \"\n"
+      "\t    [Example] [For NNAPI EP] -e nnapi -i \"NNAPI_FLAG_USE_FP16 NNAPI_FLAG_USE_NCHW NNAPI_FLAG_CPU_DISABLED\"\n"
+      "\n"
+      "\t    [CoreML only] [COREML_FLAG_CREATE_MLPROGRAM]: Create an ML Program model instead of Neural Network.\n"
+      "\t    [Example] [For CoreML EP] -e coreml -i \"COREML_FLAG_CREATE_MLPROGRAM\"\n"
+      "\n"
       "\t    [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n"
       "\t    [SNPE only] [priority]: execution priority, options: 'low', 'normal'. \n"
       "\t    [SNPE only] [buffer_type]: options: 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. default: ITENSOR'. \n"
       "\t    [SNPE only] [enable_init_cache]: enable SNPE init caching feature, set to 1 to enabled it. Disabled by default. \n"
-      "\t [Usage]: -e <provider_name> -i '<key1>|<value1> <key2>|<value2>' \n\n"
-      "\t [Example] [For SNPE EP] -e snpe -i \"runtime|CPU priority|low\" \n\n"
+      "\t    [Example] [For SNPE EP] -e snpe -i \"runtime|CPU priority|low\" \n\n"
+      "\n"
       "\t-T [Set intra op thread affinities]: Specify intra op thread affinity string\n"
       "\t [Example]: -T 1,2;3,4;5,6 or -T 1-2;3-4;5-6 \n"
       "\t\t Use semicolon to separate configuration between threads.\n"
@@ -128,6 +138,7 @@ namespace perftest {
       "\t\t The number of affinities must be equal to intra_op_num_threads - 1\n\n"
       "\t-D [Disable thread spinning]: disable spinning entirely for thread owned by onnxruntime intra-op thread pool.\n"
       "\t-Z [Force thread to stop spinning between runs]: disallow thread from spinning during runs to reduce cpu usage.\n"
+      "\t-n [Exit after session creation]: allow user to measure session creation time to measure impact of enabling any initialization optimizations.\n"
       "\t-h: help\n");
 }
 #ifdef _WIN32
@@ -190,7 +201,7 @@ static bool ParseSessionConfigs(const std::string& configs_string,
 
 /*static*/ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]) {
   int ch;
-  while ((ch = getopt(argc, argv, ORT_TSTR("b:m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqz"))) != -1) {
+  while ((ch = getopt(argc, argv, ORT_TSTR("m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqzn"))) != -1) {
     switch (ch) {
       case 'f': {
         std::basic_string<ORTCHAR_T> dim_name;
@@ -219,9 +230,6 @@ static bool ParseSessionConfigs(const std::string& configs_string,
           return false;
         }
         break;
-      case 'b':
-        test_config.backend = optarg;
-        break;
       case 'p':
         test_config.run_config.profile_file = optarg;
         break;
@@ -373,6 +381,9 @@ static bool ParseSessionConfigs(const std::string& configs_string,
       case 'Z':
         test_config.run_config.disable_spinning_between_run = true;
         break;
+      case 'n':
+        test_config.run_config.exit_after_session_creation = true;
+        break;
       case '?':
       case 'h':
       default:
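The restructured -i help text above documents EP runtime options as a single string of space-separated key|value pairs (flag-only keys for NNAPI/CoreML). As an illustration of that convention only, not the tool's actual parser (which lives in ort_test_session.cc), here is a minimal self-contained sketch with a hypothetical helper name:

```cpp
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Splits "-i" style input such as "backend_path|/folderpath/libQnnCpu.so profiling_level|basic"
// into a key/value map. Hypothetical helper for illustration only.
std::unordered_map<std::string, std::string> ParseEpOptions(const std::string& options_string) {
  std::unordered_map<std::string, std::string> options;
  std::istringstream ss(options_string);
  std::string token;
  while (ss >> token) {
    const size_t pos = token.find('|');
    if (pos == std::string::npos || pos == 0 || pos == token.size() - 1) {
      throw std::invalid_argument("Use a '|' to separate the key and value: " + token);
    }
    options[token.substr(0, pos)] = token.substr(pos + 1);
  }
  return options;
}

int main() {
  for (const auto& kv : ParseEpOptions("backend_path|/folderpath/libQnnCpu.so profiling_level|basic")) {
    std::cout << kv.first << " = " << kv.second << "\n";
  }
  return 0;
}
```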
diff --git a/onnxruntime/test/perftest/main.cc b/onnxruntime/test/perftest/main.cc
index 36f08167c2217..43bf54963cabb 100644
--- a/onnxruntime/test/perftest/main.cc
+++ b/onnxruntime/test/perftest/main.cc
@@ -43,6 +43,13 @@ int real_main(int argc, char* argv[]) {
   }
   std::random_device rd;
   perftest::PerformanceRunner perf_runner(env, test_config, rd);
+
+  // Exit if the user enabled the -n option so that session creation time can be measured
+  if (test_config.run_config.exit_after_session_creation) {
+    perf_runner.LogSessionCreationTime();
+    return 0;
+  }
+
   auto status = perf_runner.Run();
   if (!status.IsOK()) {
     printf("Run failed:%s\n", status.ErrorMessage().c_str());
diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc
index 87506c7240578..9743ed18a6cc0 100644
--- a/onnxruntime/test/perftest/ort_test_session.cc
+++ b/onnxruntime/test/perftest/ort_test_session.cc
@@ -18,6 +18,7 @@
 
 #ifdef USE_DML
 #include "core/providers/dml/dml_provider_factory.h"
+#include "core/providers/dml/dml_session_options_config_keys.h"
 #endif
 
 #ifdef _WIN32
@@ -246,7 +247,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
       if (key == "device_type") {
         std::set<std::string> ov_supported_device_types = {"CPU_FP32", "CPU_FP16", "GPU_FP32",
                                                            "GPU.0_FP32", "GPU.1_FP32", "GPU_FP16",
-                                                           "GPU.0_FP16", "GPU.1_FP16"};
+                                                           "GPU.0_FP16", "GPU.1_FP16", "NPU"};
         if (ov_supported_device_types.find(value) != ov_supported_device_types.end()) {
           ov_options[key] = value;
         } else if (value.find("HETERO:") == 0) {
@@ -259,7 +260,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
           ORT_THROW(
               "[ERROR] [OpenVINO] You have selected a wrong configuration value for the key 'device_type'. "
               "Select from 'CPU_FP32', 'CPU_FP16', 'GPU_FP32', 'GPU.0_FP32', 'GPU.1_FP32', 'GPU_FP16', "
-              "'GPU.0_FP16', 'GPU.1_FP16' or from"
+              "'GPU.0_FP16', 'GPU.1_FP16', 'NPU' or from"
               " HETERO/MULTI/AUTO options available. \n");
         }
       } else if (key == "device_id") {
@@ -381,11 +382,20 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
           std::string str = str_stream.str();
           ORT_THROW("Wrong value for htp_arch. select from: " + str);
         }
+      } else if (key == "enable_htp_fp16_precision") {
+        std::unordered_set<std::string> supported_options = {"0", "1"};
+        if (supported_options.find(value) == supported_options.end()) {
+          std::ostringstream str_stream;
+          std::copy(supported_options.begin(), supported_options.end(),
+                    std::ostream_iterator<std::string>(str_stream, ","));
+          std::string str = str_stream.str();
+          ORT_THROW("Wrong value for enable_htp_fp16_precision. select from: " + str);
+        }
       } else {
         ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path',
 'profiling_level', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode',
 'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority', 'soc_model',
-'htp_arch', 'device_id'])");
+'htp_arch', 'device_id', 'enable_htp_fp16_precision'])");
       }
 
       qnn_options[key] = value;
@@ -467,7 +477,10 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
         nnapi_flags |= NNAPI_FLAG_CPU_ONLY;
       } else if (key.empty()) {
       } else {
-        ORT_THROW("[ERROR] [NNAPI] wrong key type entered. Choose from the following runtime key options that are available for NNAPI. ['NNAPI_FLAG_USE_FP16', 'NNAPI_FLAG_USE_NCHW', 'NNAPI_FLAG_CPU_DISABLED', 'NNAPI_FLAG_CPU_ONLY'] \n");
+        ORT_THROW(
+            "[ERROR] [NNAPI] wrong key type entered. Choose from the following runtime key options "
+            "that are available for NNAPI. "
+            "['NNAPI_FLAG_USE_FP16', 'NNAPI_FLAG_USE_NCHW', 'NNAPI_FLAG_CPU_DISABLED', 'NNAPI_FLAG_CPU_ONLY'] \n");
       }
     }
     Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nnapi(session_options, nnapi_flags));
@@ -475,10 +488,31 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
     ORT_THROW("NNAPI is not supported in this build\n");
 #endif
   } else if (provider_name_ == onnxruntime::kCoreMLExecutionProvider) {
+#ifdef __APPLE__
 #ifdef USE_COREML
-    Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, 0));
+    uint32_t coreml_flags = 0;
+    std::string ov_string = performance_test_config.run_config.ep_runtime_config_string;
+    std::istringstream ss(ov_string);
+
+    std::string key;
+    while (ss >> key) {
+      if (key == "COREML_FLAG_CREATE_MLPROGRAM") {
+        coreml_flags |= COREML_FLAG_CREATE_MLPROGRAM;
+        std::cout << "Enabling ML Program.\n";
+      } else if (key.empty()) {
+      } else {
+        ORT_THROW(
+            "[ERROR] [CoreML] wrong key type entered. Choose from the following runtime key options "
+            "that are available for CoreML. ['COREML_FLAG_CREATE_MLPROGRAM'] \n");
+      }
+    }
+    // Append the CoreML EP with the flags parsed above (e.g. COREML_FLAG_CREATE_MLPROGRAM)
+    Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, coreml_flags));
 #else
-    ORT_THROW("COREML is not supported in this build\n");
+    ORT_THROW("CoreML is not supported in this build\n");
+#endif
+#else
+    ORT_THROW("COREML is not supported on this platform.\n");
 #endif
   } else if (provider_name_ == onnxruntime::kDmlExecutionProvider) {
 #ifdef USE_DML
@@ -542,6 +576,15 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
               "[ERROR] [DML] You have selcted wrong value for the key 'enable_dynamic_graph_fusion'. "
               "Select from 'true' or 'false' \n");
         }
+      } else if (key == "enable_graph_serialization") {
+        std::set<std::string> ov_supported_values = {"true", "True", "false", "False"};
+        if (ov_supported_values.find(value) != ov_supported_values.end()) {
+          session_options.AddConfigEntry(kOrtSessionOptionsConfigEnableGraphSerialization, value.data());
+        } else {
+          ORT_THROW(
+              "[ERROR] [DML] You have selcted wrong value for the key 'enable_graph_serialization'. "
+              "Select from 'true' or 'false' \n");
+        }
       }
     }
     session_options.AppendExecutionProvider("DML", dml_options);
@@ -615,7 +658,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
       std::string value(token.substr(pos + 1));
       vitisai_session_options[key] = value;
     }
-    session_options.AppendExecutionProvider("VitisAI", vitisai_session_options);
+    session_options.AppendExecutionProvider_VitisAI(vitisai_session_options);
 #else
     ORT_THROW("VitisAI is not supported in this build\n");
 #endif
diff --git a/onnxruntime/test/perftest/performance_runner.cc b/onnxruntime/test/perftest/performance_runner.cc
index 9f2cbcf6a21f1..08d77008dc25c 100644
--- a/onnxruntime/test/perftest/performance_runner.cc
+++ b/onnxruntime/test/perftest/performance_runner.cc
@@ -10,12 +10,8 @@
 #include <iostream>
 
 #include "TestCase.h"
-#include "TFModelInfo.h"
 #include "utils.h"
 #include "ort_test_session.h"
-#ifdef HAVE_TENSORFLOW
-#include "tf_test_session.h"
-#endif
 using onnxruntime::Status;
 
 // TODO: Temporary, while we bring up the threadpool impl...
@@ -115,6 +111,11 @@ void PerformanceResult::DumpToFile(const std::basic_string<ORTCHAR_T>& path, boo
   }
 }
 
+void PerformanceRunner::LogSessionCreationTime() {
+  std::chrono::duration<double> session_create_duration = session_create_end_ - session_create_start_;
+  std::cout << "\nSession creation time cost: " << session_create_duration.count() << " s\n";
+}
+
 Status PerformanceRunner::Run() {
   if (!Initialize()) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "failed to initialize.");
@@ -255,47 +256,25 @@ Status PerformanceRunner::ForkJoinRepeat() {
 }
 
 static std::unique_ptr<TestModelInfo> CreateModelInfo(const PerformanceTestConfig& performance_test_config_) {
-  if (CompareCString(performance_test_config_.backend.c_str(), ORT_TSTR("ort")) == 0) {
-    const auto& file_path = performance_test_config_.model_info.model_file_path;
+  const auto& file_path = performance_test_config_.model_info.model_file_path;
 #if !defined(ORT_MINIMAL_BUILD)
-    if (HasExtensionOf(file_path, ORT_TSTR("onnx"))) {
-      return TestModelInfo::LoadOnnxModel(performance_test_config_.model_info.model_file_path.c_str());
-    }
-#endif
-
-    if (HasExtensionOf(file_path, ORT_TSTR("ort"))) {
-      return TestModelInfo::LoadOrtModel(performance_test_config_.model_info.model_file_path.c_str());
-    }
-
-    ORT_NOT_IMPLEMENTED(ToUTF8String(file_path), " is not supported");
+  if (HasExtensionOf(file_path, ORT_TSTR("onnx"))) {
+    return TestModelInfo::LoadOnnxModel(performance_test_config_.model_info.model_file_path.c_str());
   }
+#endif
 
-  if (CompareCString(performance_test_config_.backend.c_str(), ORT_TSTR("tf")) == 0) {
-    return TFModelInfo::Create(performance_test_config_.model_info.model_file_path.c_str());
+  if (HasExtensionOf(file_path, ORT_TSTR("ort"))) {
+    return TestModelInfo::LoadOrtModel(performance_test_config_.model_info.model_file_path.c_str());
   }
 
-  ORT_NOT_IMPLEMENTED(ToUTF8String(performance_test_config_.backend), " is not supported");
-}
-
-static std::unique_ptr<TestSession> CreateSession(Ort::Env& env, std::random_device& rd,
-                                                  const PerformanceTestConfig& performance_test_config_,
-                                                  const TestModelInfo& test_model_info) {
-  if (CompareCString(performance_test_config_.backend.c_str(), ORT_TSTR("ort")) == 0) {
-    return std::make_unique<OnnxRuntimeTestSession>(env, rd, performance_test_config_, test_model_info);
-  }
-#ifdef HAVE_TENSORFLOW
-  if (CompareCString(performance_test_config_.backend.c_str(), ORT_TSTR("tf")) == 0) {
-    return new TensorflowTestSession(rd, performance_test_config_, test_model_info);
-  }
-#endif
-  ORT_NOT_IMPLEMENTED(ToUTF8String(performance_test_config_.backend), " is not supported");
+  ORT_NOT_IMPLEMENTED(ToUTF8String(file_path), " is not supported");
 }
 
 PerformanceRunner::PerformanceRunner(Ort::Env& env, const PerformanceTestConfig& test_config, std::random_device& rd)
     : performance_test_config_(test_config),
       test_model_info_(CreateModelInfo(test_config)) {
   session_create_start_ = std::chrono::high_resolution_clock::now();
-  session_ = CreateSession(env, rd, test_config, *test_model_info_);
+  session_ = std::make_unique<OnnxRuntimeTestSession>(env, rd, performance_test_config_, *test_model_info_);
   session_create_end_ = std::chrono::high_resolution_clock::now();
 }
 
diff --git a/onnxruntime/test/perftest/performance_runner.h b/onnxruntime/test/perftest/performance_runner.h
index da2df9c39f44c..cb1cb661550a7 100644
--- a/onnxruntime/test/perftest/performance_runner.h
+++ b/onnxruntime/test/perftest/performance_runner.h
@@ -46,6 +46,8 @@ class PerformanceRunner {
   ~PerformanceRunner();
   Status Run();
 
+  void LogSessionCreationTime();
+
   inline const PerformanceResult& GetResult() const { return performance_result_; }
 
   inline void SerializeResult() const {
diff --git a/onnxruntime/test/perftest/test_configuration.h b/onnxruntime/test/perftest/test_configuration.h
index 5a49414a49004..70a6b12690d5d 100644
--- a/onnxruntime/test/perftest/test_configuration.h
+++ b/onnxruntime/test/perftest/test_configuration.h
@@ -63,13 +63,13 @@ struct RunConfig {
   std::string intra_op_thread_affinities;
   bool disable_spinning = false;
   bool disable_spinning_between_run = false;
+  bool exit_after_session_creation = false;
 };
 
 struct PerformanceTestConfig {
   ModelInfo model_info;
   MachineConfig machine_config;
   RunConfig run_config;
-  std::basic_string<ORTCHAR_T> backend = ORT_TSTR("ort");
 };
 
 }  // namespace perftest
diff --git a/onnxruntime/test/platform/apple/apple_package_test/apple_package_test.xcodeproj/project.pbxproj b/onnxruntime/test/platform/apple/apple_package_test/apple_package_test.xcodeproj/project.pbxproj
index f0582d41734bd..eb7345be3770b 100644
--- a/onnxruntime/test/platform/apple/apple_package_test/apple_package_test.xcodeproj/project.pbxproj
+++ b/onnxruntime/test/platform/apple/apple_package_test/apple_package_test.xcodeproj/project.pbxproj
@@ -49,6 +49,7 @@
 		229E595826586B4A006E41AE /* sigmoid.ort */ = {isa = PBXFileReference; lastKnownFileType = file; path = sigmoid.ort; sourceTree = "<group>"; };
 		22C1D8DE271A79AF002CEE67 /* ios_package_testUITests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = ios_package_testUITests.xctest; sourceTree = BUILT_PRODUCTS_DIR; };
 		22C1D8E9271A79FD002CEE67 /* ios_package_uitest_cpp_api.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = ios_package_uitest_cpp_api.mm; sourceTree = "<group>"; };
+		513C65792B85789400E4EDFD /* ios_package_test.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = ios_package_test.entitlements; sourceTree = "<group>"; };
 		51C316B92B0881450033C70B /* macos_package_test.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = macos_package_test.app; sourceTree = BUILT_PRODUCTS_DIR; };
 		51C316BB2B0881450033C70B /* AppDelegate.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = "<group>"; };
 		51C316BC2B0881450033C70B /* AppDelegate.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = AppDelegate.m; sourceTree = "<group>"; };
@@ -117,6 +118,7 @@
 		229E591E265869BF006E41AE /* ios_package_test */ = {
 			isa = PBXGroup;
 			children = (
+				513C65792B85789400E4EDFD /* ios_package_test.entitlements */,
 				229E591F265869BF006E41AE /* AppDelegate.h */,
 				229E5920265869BF006E41AE /* AppDelegate.m */,
 				229E5928265869BF006E41AE /* Main.storyboard */,
@@ -521,8 +523,11 @@
 			buildSettings = {
 				ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
 				ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
-				CODE_SIGN_STYLE = Automatic;
+				CODE_SIGNING_REQUIRED = NO;
+				CODE_SIGNING_STYLE = Automatic;
+				CODE_SIGN_ENTITLEMENTS = ios_package_test/ios_package_test.entitlements;
 				INFOPLIST_FILE = ios_package_test/Info.plist;
+				IPHONEOS_DEPLOYMENT_TARGET = 14.0;
 				LD_RUNPATH_SEARCH_PATHS = (
 					"$(inherited)",
 					"@executable_path/Frameworks",
@@ -530,9 +535,9 @@
 				PRODUCT_BUNDLE_IDENTIFIER = "ai.onnxruntime.tests.ios-package-test";
 				PRODUCT_NAME = "$(TARGET_NAME)";
 				SUPPORTED_PLATFORMS = "iphoneos iphonesimulator";
-				SUPPORTS_MACCATALYST = NO;
+				SUPPORTS_MACCATALYST = YES;
 				SUPPORTS_MAC_DESIGNED_FOR_IPHONE_IPAD = NO;
-				TARGETED_DEVICE_FAMILY = 1;
+				TARGETED_DEVICE_FAMILY = "1,2";
 			};
 			name = Debug;
 		};
@@ -541,8 +546,11 @@
 			buildSettings = {
 				ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
 				ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
-				CODE_SIGN_STYLE = Automatic;
+				CODE_SIGNING_REQUIRED = NO;
+				CODE_SIGNING_STYLE = Automatic;
+				CODE_SIGN_ENTITLEMENTS = ios_package_test/ios_package_test.entitlements;
 				INFOPLIST_FILE = ios_package_test/Info.plist;
+				IPHONEOS_DEPLOYMENT_TARGET = 14.0;
 				LD_RUNPATH_SEARCH_PATHS = (
 					"$(inherited)",
 					"@executable_path/Frameworks",
@@ -550,9 +558,9 @@
 				PRODUCT_BUNDLE_IDENTIFIER = "ai.onnxruntime.tests.ios-package-test";
 				PRODUCT_NAME = "$(TARGET_NAME)";
 				SUPPORTED_PLATFORMS = "iphoneos iphonesimulator";
-				SUPPORTS_MACCATALYST = NO;
+				SUPPORTS_MACCATALYST = YES;
 				SUPPORTS_MAC_DESIGNED_FOR_IPHONE_IPAD = NO;
-				TARGETED_DEVICE_FAMILY = 1;
+				TARGETED_DEVICE_FAMILY = "1,2";
 			};
 			name = Release;
 		};
@@ -563,7 +571,7 @@
 				CODE_SIGN_STYLE = Automatic;
 				CURRENT_PROJECT_VERSION = 1;
 				GENERATE_INFOPLIST_FILE = YES;
-				IPHONEOS_DEPLOYMENT_TARGET = 13.0;
+				IPHONEOS_DEPLOYMENT_TARGET = 14.0;
 				LD_RUNPATH_SEARCH_PATHS = (
 					"$(inherited)",
 					"@executable_path/Frameworks",
@@ -585,7 +593,7 @@
 				CODE_SIGN_STYLE = Automatic;
 				CURRENT_PROJECT_VERSION = 1;
 				GENERATE_INFOPLIST_FILE = YES;
-				IPHONEOS_DEPLOYMENT_TARGET = 13.0;
+				IPHONEOS_DEPLOYMENT_TARGET = 14.0;
 				LD_RUNPATH_SEARCH_PATHS = (
 					"$(inherited)",
 					"@executable_path/Frameworks",
diff --git a/onnxruntime/test/platform/apple/apple_package_test/ios_package_test/ios_package_test.entitlements b/onnxruntime/test/platform/apple/apple_package_test/ios_package_test/ios_package_test.entitlements
new file mode 100644
index 0000000000000..ee95ab7e582d4
--- /dev/null
+++ b/onnxruntime/test/platform/apple/apple_package_test/ios_package_test/ios_package_test.entitlements
@@ -0,0 +1,10 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
+<plist version="1.0">
+<dict>
+	<key>com.apple.security.app-sandbox</key>
+	<true/>
+	<key>com.apple.security.network.client</key>
+	<true/>
+</dict>
+</plist>
diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/providers/base_tester.cc
index 16cce85f7cb0a..8d84c689cd23e 100644
--- a/onnxruntime/test/providers/base_tester.cc
+++ b/onnxruntime/test/providers/base_tester.cc
@@ -120,6 +120,20 @@ void BaseTester::SetOutputRelErr(const char* name, float v) {
   it->validation_params.relative_error = optional<float>(v);
 }
 
+void BaseTester::SetOutputTolerance(float abs_error, float rel_error) {
+  for (auto& output : output_data_) {
+    if (output.def.Exists()) {
+      if (abs_error >= 0.0f) {
+        output.validation_params.absolute_error = optional<float>(abs_error);
+      }
+
+      if (rel_error >= 0.0f) {
+        output.validation_params.relative_error = optional<float>(rel_error);
+      }
+    }
+  }
+}
+
 std::vector<int64_t> BaseTester::GetDimsForProto(gsl::span<const int64_t> dims) {
   std::vector<int64_t> dims_for_proto{dims.begin(), dims.end()};
   if (add_symbolic_dim_to_tensor_data_ >= 0 &&
@@ -613,6 +627,9 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter,
                          number_of_pre_packed_weights_counter,
                          number_of_shared_pre_packed_weights_counter);
     } else {
+      // synthetic EP name for testing CoreML EP with ML Program
+      constexpr const char* kCoreMLExecutionProviderMLProgram = "CoreMLExecutionProvider_MLProgram";
+
 #ifdef USE_TENSORRT
       // only run trt ep to reduce test time
       static const std::string all_provider_types[] = {
@@ -622,6 +639,9 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter,
       static const std::string all_provider_types[] = {
           kCpuExecutionProvider,
           kCudaExecutionProvider,
+#ifdef ENABLE_CUDA_NHWC_OPS
+          kCudaNHWCExecutionProvider,
+#endif
           kDnnlExecutionProvider,
           kTensorrtExecutionProvider,
           kOpenVINOExecutionProvider,
@@ -631,10 +651,16 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter,
           kNnapiExecutionProvider,
           kRocmExecutionProvider,
           kCoreMLExecutionProvider,
+          kCoreMLExecutionProviderMLProgram,
           kQnnExecutionProvider,
           kSnpeExecutionProvider,
           kXnnpackExecutionProvider,
       };
+
+      // need to special case any synthetic EP names in the exclude list
+      if (ctx_.excluded_provider_types.count(kCoreMLExecutionProvider) > 0) {
+        ctx_.excluded_provider_types.insert(kCoreMLExecutionProviderMLProgram);
+      }
 #endif
 
       bool has_run = false;
@@ -650,6 +676,10 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter,
           execution_provider = DefaultCpuExecutionProvider();
         else if (provider_type == onnxruntime::kCudaExecutionProvider)
           execution_provider = DefaultCudaExecutionProvider();
+#ifdef ENABLE_CUDA_NHWC_OPS
+        else if (provider_type == onnxruntime::kCudaNHWCExecutionProvider)
+          execution_provider = DefaultCudaNHWCExecutionProvider();
+#endif
         else if (provider_type == onnxruntime::kDnnlExecutionProvider)
           execution_provider = DefaultDnnlExecutionProvider();
         else if (provider_type == onnxruntime::kOpenVINOExecutionProvider)
@@ -668,6 +698,8 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter,
           execution_provider = DefaultRocmExecutionProvider();
         else if (provider_type == onnxruntime::kCoreMLExecutionProvider)
           execution_provider = DefaultCoreMLExecutionProvider();
+        else if (provider_type == kCoreMLExecutionProviderMLProgram)
+          execution_provider = DefaultCoreMLExecutionProvider(/*use_mlprogram*/ true);
         else if (provider_type == onnxruntime::kSnpeExecutionProvider)
           execution_provider = DefaultSnpeExecutionProvider();
         else if (provider_type == onnxruntime::kQnnExecutionProvider)
diff --git a/onnxruntime/test/providers/base_tester.h b/onnxruntime/test/providers/base_tester.h
index 5607e58315a12..c276ae494df43 100644
--- a/onnxruntime/test/providers/base_tester.h
+++ b/onnxruntime/test/providers/base_tester.h
@@ -519,9 +519,20 @@ class BaseTester {
     custom_session_registries_.push_back(registry);
   }
 
+  // For floating-point types (double/float/half/bfloat16), the tolerance is similar to numpy.isclose:
+  //   absolute(expected_value - actual_value) <= abs_error + rel_error * absolute(expected_value)
+  // For integer types, the tolerance parameters are ignored except in the following cases:
+  //   For uint8, tolerance is only applied to the NNAPI/XNNPACK/DML providers.
+  //   For int8, only abs_error is used and rel_error is ignored. See checkers.cc for details.
+  // If abs_error or rel_error is not set, a default value is used (see DefaultTolerance for details).
   void SetOutputAbsErr(const char* name, float v);
   void SetOutputRelErr(const char* name, float v);
 
+  // Set absolute and relative tolerances for all existing outputs.
+  // A negative value is ignored.
+  // Note that tolerances are not applied to outputs added after this call.
+  void SetOutputTolerance(float abs_error, float rel_error = -1.0f);
+
   // Number of times to call InferenceSession::Run. The same feeds are used each time.
   // e.g. used to verify the generator ops behave as expected
   void SetNumRunCalls(int n) {
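The tolerance comment added to base_tester.h above, together with the DefaultTolerance/get_tolerance changes in checkers.cc below, apply a numpy.isclose-style rule. A self-contained sketch of that rule, assuming only the formula quoted in the comment and the non-training float defaults from checkers.cc (absolute 1e-5, relative 1e-4):

```cpp
#include <cassert>
#include <cmath>

// Sketch of the documented rule; not the project's checker itself.
static bool IsClose(float expected, float actual, float abs_error, float rel_error) {
  return std::fabs(expected - actual) <= abs_error + rel_error * std::fabs(expected);
}

int main() {
  // For expected == 1.0f the allowed deviation is 1e-5 + 1e-4 * 1.0 = 1.1e-4.
  assert(IsClose(1.0f, 1.0f + 1e-4f, 1e-5f, 1e-4f));
  assert(!IsClose(1.0f, 1.0f + 2e-4f, 1e-5f, 1e-4f));
  return 0;
}
```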
diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc
index 85ccb8f175f62..47c18c478dd9c 100644
--- a/onnxruntime/test/providers/checkers.cc
+++ b/onnxruntime/test/providers/checkers.cc
@@ -14,6 +14,95 @@
 namespace onnxruntime {
 namespace test {
 namespace {
+
+template <typename T>
+struct DefaultTolerance;
+
+template <>
+struct DefaultTolerance<double> {
+  static constexpr float absolute = 1e-5f;
+  static constexpr float relative = 1e-5f;
+
+  // Allows the default absolute tolerance to differ per provider.
+  static float get_absolute(const std::string& /*provider_type*/) {
+    return absolute;
+  }
+};
+
+template <>
+struct DefaultTolerance<float> {
+#if defined(ENABLE_TRAINING)
+  static constexpr float absolute = 1e-3f;
+#else
+  static constexpr float absolute = 1e-5f;
+#endif
+
+  static constexpr float relative = 1e-4f;
+
+  static float get_absolute(const std::string& /*provider_type*/) {
+    return absolute;
+  }
+};
+
+template <>
+struct DefaultTolerance<MLFloat16> {
+#if defined(ENABLE_TRAINING)
+  static constexpr float absolute = 0.005f;
+#else
+  // The thresholds for inference are estimated with a PyTorch script like the following:
+  //    x = torch.rand(1000, 1000)
+  //    absolute = ((x + 1e-6).to(torch.float16) - x).abs().max() * 10
+  //    x[abs(x) < absolute] = absolute
+  //    relative = ((x - x.to(torch.float16)) / x).abs().max() * 2
+  static constexpr float absolute = 0.0025f;
+#endif
+
+  static constexpr float relative = 0.001f;
+
+  static float get_absolute(const std::string& provider_type) {
+    if (provider_type == kDmlExecutionProvider) {
+      return 0.005f;
+    }
+    return absolute;
+  }
+};
+
+template <>
+struct DefaultTolerance<BFloat16> {
+  // The thresholds for inference are estimated with a PyTorch script like the following:
+  //    x = torch.rand(1000, 1000)
+  //    absolute = ((x + 1e-6).to(torch.bfloat16) - x).abs().max() * 10
+  //    x[abs(x) < absolute] = absolute
+  //    relative = ((x - x.to(torch.bfloat16)) / x).abs().max() * 2
+  static constexpr float absolute = 0.02f;
+  static constexpr float relative = 0.01f;
+
+  static float get_absolute(const std::string& /*provider_type*/) {
+    return absolute;
+  }
+};
+
+struct ToleranceParams {
+  float absolute;
+  float relative;
+};
+
+template <typename T>
+ToleranceParams get_tolerance_params(const ValidateOutputParams& params, const std::string& provider_type) {
+  ToleranceParams new_params;
+  new_params.absolute = params.absolute_error.has_value() ? *(params.absolute_error) : DefaultTolerance<T>::get_absolute(provider_type);
+  new_params.relative = params.relative_error.has_value() ? *(params.relative_error) : DefaultTolerance<T>::relative;
+  return new_params;
+}
+
+template <typename T>
+T get_tolerance(const ToleranceParams& params, T expected_value) {
+  static_assert(std::is_floating_point<T>::value, "T must be a floating point type");
+
+  // The formula is similar to numpy.isclose: https://numpy.org/doc/stable/reference/generated/numpy.isclose.html
+  return static_cast<T>(params.absolute) + static_cast<T>(params.relative) * std::abs(expected_value);
+}
+
 template <typename T>
 Tensor copy_sort(const Tensor& src, const AllocatorPtr& allocator) {
   Tensor result(src.DataType(), src.Shape(), allocator);
@@ -67,7 +156,7 @@ struct TensorCheck {
       cur_actual = actual.Data<T>();
     }
 
-    for (int i = 0; i < size; ++i) {
+    for (int64_t i = 0; i < size; ++i) {
       EXPECT_EQ(cur_expected[i], cur_actual[i]) << "i:" << i;
     }
   }
@@ -111,7 +200,7 @@ struct TensorCheck<uint8_t> {
       double threshold = has_abs_err ? *(params.absolute_error)
                                      : 0.0;
 
-      for (int i = 0; i < size; ++i) {
+      for (int64_t i = 0; i < size; ++i) {
         if (has_rel_err) {
           EXPECT_NEAR(cur_expected[i], cur_actual[i],
                       *(params.relative_error) * cur_expected[i])  // expected[i] is unsigned, can't be negative
@@ -121,7 +210,7 @@ struct TensorCheck<uint8_t> {
         }
       }
     } else {
-      for (int i = 0; i < size; ++i) {
+      for (int64_t i = 0; i < size; ++i) {
         EXPECT_EQ(cur_expected[i], cur_actual[i]) << "i:" << i;
       }
     }
@@ -153,15 +242,18 @@ struct TensorCheck<int8_t> {
       cur_actual = actual.template Data<int8_t>();
     }
 
-    const bool has_abs_err = params.absolute_error.has_value();
+    // For int8, an absolute error of less than 1 has the same effect as no tolerance.
+    const bool has_abs_err = params.absolute_error.has_value() && *(params.absolute_error) >= 1.0f;
+
+    // TODO: the relative error is not used for int8 yet.
     if (has_abs_err) {
       double threshold = *(params.absolute_error);
 
-      for (int i = 0; i < size; ++i) {
+      for (int64_t i = 0; i < size; ++i) {
         EXPECT_NEAR(cur_expected[i], cur_actual[i], threshold) << "i:" << i;
       }
     } else {
-      for (int i = 0; i < size; ++i) {
+      for (int64_t i = 0; i < size; ++i) {
         EXPECT_EQ(cur_expected[i], cur_actual[i]) << "i:" << i;
       }
     }
@@ -173,12 +265,9 @@ struct TensorCheck<double> {
   void operator()(const Tensor& expected,
                   const Tensor& actual,
                   const ValidateOutputParams& params,
-                  const std::string& /*provider_type*/) const {
+                  const std::string& provider_type) const {
     auto size = actual.Shape().Size();
 
-    bool has_abs_err = params.absolute_error.has_value();
-    bool has_rel_err = params.relative_error.has_value();
-
     // deal with rare cases in which order of output data from a kernel MAY be
     // undefined
     Tensor expected_sorted, actual_sorted;
@@ -193,12 +282,9 @@ struct TensorCheck<double> {
       cur_actual = actual.Data<double>();
     }
 
-    double threshold = 0.001;
-#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML)
-    threshold = 0.005;
-#endif
+    auto tolerance_params = get_tolerance_params<double>(params, provider_type);
 
-    for (int i = 0; i < size; ++i) {
+    for (int64_t i = 0; i < size; ++i) {
       // NOTE: Check isnan first to work around MSVC linker bug when /LTCG:incremental is specified.
       // If the isinf check is first the isnan check and branch gets omitted
       if (std::isnan(cur_expected[i])) {
@@ -206,53 +292,36 @@ struct TensorCheck<double> {
       } else if (std::isinf(cur_expected[i])) {  // Test infinity for equality
         EXPECT_EQ(cur_expected[i], cur_actual[i]) << "Expected infinity. i:" << i;
       } else {
-        if (!has_abs_err && !has_rel_err) {
-          // the default for existing tests
-          EXPECT_NEAR(cur_expected[i], cur_actual[i], threshold) << "i:" << i;
-        } else {
-          if (has_abs_err) {
-            EXPECT_NEAR(cur_expected[i], cur_actual[i], *(params.absolute_error)) << "i:" << i;
-          }
-          if (has_rel_err) {
-            EXPECT_NEAR(cur_expected[i], cur_actual[i], *(params.relative_error) * std::abs(cur_expected[i]))
-                << "i:" << i;
-          }
-        }
+        double tolerance = get_tolerance<double>(tolerance_params, cur_expected[i]);
+        EXPECT_NEAR(cur_expected[i], cur_actual[i], tolerance) << "i:" << i;
       }
     }
   }
 };
 
-template <typename TypeToCheck>
+template <typename T>
 void InternalNumericalCheck(const Tensor& expected,
                             const Tensor& actual,
                             const ValidateOutputParams& params,
-                            const std::string& /*provider_type*/) {
-  const bool has_abs_err = params.absolute_error.has_value();
-  const bool has_rel_err = params.relative_error.has_value();
-
+                            const std::string& provider_type) {
   // deal with rare cases in which order of output data from a kernel MAY be
   // undefined
   Tensor expected_sorted, actual_sorted;
-  const TypeToCheck* cur_expected;
-  const TypeToCheck* cur_actual;
+  const T* cur_expected;
+  const T* cur_actual;
   auto size = actual.Shape().Size();
   if (params.sort_output) {
-    sort_expected_and_actual_buffers<TypeToCheck>(expected, expected_sorted, actual, actual_sorted);
-    cur_expected = expected_sorted.Data<TypeToCheck>();
-    cur_actual = actual_sorted.Data<TypeToCheck>();
+    sort_expected_and_actual_buffers<T>(expected, expected_sorted, actual, actual_sorted);
+    cur_expected = expected_sorted.Data<T>();
+    cur_actual = actual_sorted.Data<T>();
   } else {
-    cur_expected = expected.Data<TypeToCheck>();
-    cur_actual = actual.Data<TypeToCheck>();
+    cur_expected = expected.Data<T>();
+    cur_actual = actual.Data<T>();
   }
 
-#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML)
-  constexpr float threshold = 0.005f;
-#else
-  constexpr float threshold = 0.0001f;
-#endif
+  auto tolerance_params = get_tolerance_params<T>(params, provider_type);
 
-  for (int i = 0; i < size; ++i) {
+  for (int64_t i = 0; i < size; ++i) {
     // NOTE: Check isnan first to work around MSVC linker bug when /LTCG:incremental is specified.
     // If the isinf check is first the isnan check and branch gets omitted
     if (std::isnan(cur_expected[i])) {
@@ -260,19 +329,8 @@ void InternalNumericalCheck(const Tensor& expected,
     } else if (std::isinf(cur_expected[i])) {  // Test infinity for equality
       EXPECT_EQ(cur_expected[i], cur_actual[i]) << "Expected infinity. i:" << i;
     } else {
-      if (!has_abs_err && !has_rel_err) {
-        // the default for existing tests
-        EXPECT_NEAR(cur_expected[i], cur_actual[i], threshold) << "i:" << i;
-      } else {
-        if (has_abs_err) {
-          EXPECT_NEAR(cur_expected[i], cur_actual[i], *(params.absolute_error))
-              << "i:" << i;
-        }
-        if (has_rel_err) {
-          EXPECT_NEAR(cur_expected[i], cur_actual[i], *(params.relative_error) * std::abs(cur_expected[i]))
-              << "i:" << i;
-        }
-      }
+      T tolerance = get_tolerance<T>(tolerance_params, cur_expected[i]);
+      EXPECT_NEAR(cur_expected[i], cur_actual[i], tolerance) << "i:" << i;
     }
   }
 }
@@ -292,7 +350,7 @@ struct TensorCheck<MLFloat16> {
   void operator()(const Tensor& expected,
                   const Tensor& actual,
                   const ValidateOutputParams& params,
-                  const std::string& /*provider_type*/) const {
+                  const std::string& provider_type) const {
     auto* cur_expected = expected.Data<MLFloat16>();
     auto* cur_actual = actual.Data<MLFloat16>();
     auto size = actual.Shape().Size();
@@ -308,34 +366,16 @@ struct TensorCheck<MLFloat16> {
       sort_expected_and_actual_buffers<float>(f_expected, f_actual);
     }
 
-    const bool has_abs_err = params.absolute_error.has_value();
-    const bool has_rel_err = params.relative_error.has_value();
+    auto tolerance_params = get_tolerance_params<MLFloat16>(params, provider_type);
 
-    float threshold = 0.001f;
-#if defined(USE_TENSORRT) || defined(ENABLE_TRAINING_CORE) || defined(USE_CUDA) || defined(USE_ROCM)
-    threshold = 0.005f;
-#elif defined(USE_DML)
-    threshold = 0.02f;
-#endif
-    for (int i = 0; i < size; ++i) {
+    for (int64_t i = 0; i < size; ++i) {
       if (std::isnan(f_expected[i])) {
         EXPECT_TRUE(std::isnan(f_expected[i])) << "Expected NaN. i:" << i;
       } else if (std::isinf(f_expected[i])) {  // Test infinity for equality
         EXPECT_EQ(f_expected[i], f_actual[i]) << "Expected infinity. i:" << i;
       } else {
-        if (!has_abs_err && !has_rel_err) {
-          // the default for existing tests
-          EXPECT_NEAR(f_expected[i], f_actual[i], threshold) << "i:" << i;
-        } else {
-          if (has_abs_err) {
-            EXPECT_NEAR(f_expected[i], f_actual[i], *(params.absolute_error))
-                << "i:" << i;
-          }
-          if (has_rel_err) {
-            EXPECT_NEAR(f_expected[i], f_actual[i], *(params.relative_error) * std::abs(static_cast<float>(cur_expected[i])))
-                << "i:" << i;
-          }
-        }
+        float tolerance = get_tolerance<float>(tolerance_params, f_expected[i]);
+        EXPECT_NEAR(f_expected[i], f_actual[i], tolerance) << "i:" << i;
       }
     }
   }
@@ -346,7 +386,7 @@ struct TensorCheck<BFloat16> {
   void operator()(const Tensor& expected,
                   const Tensor& actual,
                   const ValidateOutputParams& params,
-                  const std::string& /*provider_type*/) const {
+                  const std::string& provider_type) const {
     auto* cur_expected = expected.Data<BFloat16>();
     auto* cur_actual = actual.Data<BFloat16>();
     auto size = actual.Shape().Size();
@@ -362,32 +402,16 @@ struct TensorCheck<BFloat16> {
       sort_expected_and_actual_buffers<float>(f_expected, f_actual);
     }
 
-    /// XXX: May need to adjust threshold as BFloat is coarse
-    float abs_threshold = 0.0001f;
-    float threshold = 0.001f;
-#if defined(USE_TENSORRT) || defined(ENABLE_TRAINING_CORE) || defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) || defined(USE_DNNL)
-    threshold = 0.05f;  // expect at least 95% close
-#endif
+    auto tolerance_params = get_tolerance_params<BFloat16>(params, provider_type);
 
-    for (int i = 0; i < size; ++i) {
+    for (int64_t i = 0; i < size; ++i) {
       if (std::isnan(f_expected[i])) {
         EXPECT_TRUE(std::isnan(f_expected[i])) << "Expected NaN. i:" << i;
       } else if (std::isinf(f_expected[i])) {  // Test infinity for equality
         EXPECT_EQ(f_expected[i], f_actual[i]) << "Expected infinity. i:" << i;
       } else {
-        // the default for existing tests
-        const float max_value = fmax(fabs(f_expected[i]), fabs(f_actual[i]));
-        if (max_value != 0) {  // max_value = 0 means output and expected are 0s.
-          const float abs_error = fabs(f_expected[i] - f_actual[i]);
-          if (abs_error <= abs_threshold) {
-            // if the absolute error is small enough, then no need to calculate realative error
-            EXPECT_NEAR(0, abs_error, abs_threshold);
-          } else {
-            // default for existing tests.
-            const float rel_error = abs_error / max_value;
-            EXPECT_NEAR(0, rel_error, threshold);
-          }
-        }
+        float tolerance = get_tolerance<float>(tolerance_params, f_expected[i]);
+        EXPECT_NEAR(f_expected[i], f_actual[i], tolerance) << "i:" << i;
       }
     }
   }
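
The repeated per-type threshold #ifdef blocks removed above are folded into the get_tolerance_params / get_tolerance helpers introduced earlier in this patch; their definitions are not visible in these hunks. As a rough, hedged sketch of the usual shape of such a helper (the struct name, field names, and default values below are illustrative assumptions, not the patch's actual code), the per-element bound combines an absolute and a relative term:

    #include <cmath>

    // Illustrative only: a combined absolute/relative tolerance bound.
    struct ToleranceParamsSketch {
      double absolute = 1e-3;  // assumed default; real defaults vary per provider
      double relative = 0.0;
    };

    template <typename T>
    T GetToleranceSketch(const ToleranceParamsSketch& p, T expected_value) {
      // Accept |expected - actual| <= absolute + relative * |expected|.
      return static_cast<T>(p.absolute + p.relative * std::abs(static_cast<double>(expected_value)));
    }

A combined bound of this shape would reproduce the old behavior when only one of the two terms is set, which is why the separate has_abs_err / has_rel_err branches can be deleted.
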
diff --git a/onnxruntime/test/providers/coreml/coreml_basic_test.cc b/onnxruntime/test/providers/coreml/coreml_basic_test.cc
index 7b6f1b9244be9..0f068ba48d3d8 100644
--- a/onnxruntime/test/providers/coreml/coreml_basic_test.cc
+++ b/onnxruntime/test/providers/coreml/coreml_basic_test.cc
@@ -192,5 +192,29 @@ TEST(CoreMLExecutionProviderTest, TestOrtFormatModel) {
 #endif
 }
 
+#if defined(COREML_ENABLE_MLPROGRAM)
+// Names in CoreML cannot start with [0-9] or contain anything but "[a-z][A-Z][0-9]_"
+// Test that we fix invalid names in model inputs, initializers and outputs.
+// This is only enforced for ML Program, so we only do name sanitization when creating an ML Program format model.
+TEST(CoreMLExecutionProviderTest, TestNameSanitization) {
+  OpTester test("Clip", 11);
+
+  std::vector<int64_t> dims{3, 3};
+  test.AddInput<float>("0", dims,
+                       {-1.0f, 0.0f, 1.0f,
+                        -6.0f, 0.0f, 6.0f,
+                        -5.4f, 2.0f, 6.0f});
+  test.AddInput<float>("1.min", {}, {-5}, true);  // add as initializers
+  test.AddInput<float>("2/max", {}, {5}, true);
+  test.AddOutput<float>("3", dims,
+                        {-1.0f, 0.0f, 1.0f,
+                         -5.0f, 0.0f, 5.0f,
+                         -5.0f, 2.0f, 5.0f});
+
+  // TensorRT does not support Clip opset 11 yet.
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
+}
+#endif
+
 }  // namespace test
 }  // namespace onnxruntime
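
The new test relies on the CoreML EP rewriting the invalid names "0", "1.min", "2/max", and "3"; the sanitization routine itself lives in the EP sources and is not part of this diff. Purely as a hedged illustration of the rule stated in the comment (names limited to [A-Za-z0-9_] and not starting with a digit), one plausible sanitizer might look like the sketch below; the function name and the exact replacement character are assumptions, not the EP's actual implementation:

    #include <cctype>
    #include <string>

    // Hypothetical sketch: prefix names that start with a digit and replace
    // any character outside [A-Za-z0-9_] with '_'. The real EP logic may differ.
    std::string SanitizeMLProgramNameSketch(const std::string& name) {
      std::string sanitized;
      sanitized.reserve(name.size() + 1);
      if (!name.empty() && std::isdigit(static_cast<unsigned char>(name.front()))) {
        sanitized.push_back('_');  // e.g. "0" -> "_0"
      }
      for (char c : name) {
        const bool valid = std::isalnum(static_cast<unsigned char>(c)) != 0 || c == '_';
        sanitized.push_back(valid ? c : '_');  // e.g. "2/max" -> "_2_max"
      }
      return sanitized;
    }
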
diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc
index ddb0a6620619c..d2e883331acd4 100644
--- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc
+++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc
@@ -116,13 +116,13 @@ TEST_F(ActivationOpTest, Relu) {
       "Relu",
       input_values_double,
       [](double x) { return std::max(x, 0.0); },
-      {},
+      {}, {},
       /*is_tensorrt_supported=*/false);
   TestActivationOp<int8_t>(
       "Relu",
       input_values_int8,
       [](int8_t x) { return std::max(x, static_cast<int8_t>(0)); },
-      {},
+      {}, {},
       /*is_tensorrt_supported=*/false,
       /*opset_version= */ 14);
 #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
@@ -133,7 +133,7 @@ TEST_F(ActivationOpTest, Relu) {
         if (x.ToFloat() > 0.0f) return x;
         return MLFloat16();
       },
-      {},
+      {}, {},
       /*is_tensorrt_supported=*/false,
       /*opset_version= */ 11);
 #endif  // MLAS_F16VEC_INTRINSICS_SUPPORTED
@@ -402,7 +402,7 @@ TEST_F(ActivationOpTest, Celu) {
       // TODO: Investigate why gcc 4 fails to compile without the explicit cast
       [alpha](float x) { return std::max(0.0f, x) + std::min(0.0f, alpha * (static_cast<float>(exp(x / alpha)) - 1)); },
       // Disable on TensorRT as it seems like it doesn't yet support Celu
-      {{"alpha", alpha}}, false, 12);
+      {{"alpha", alpha}}, {}, false, 12);
 }
 
 TEST_F(ActivationOpTest, LeakyRelu) {
@@ -410,7 +410,7 @@ TEST_F(ActivationOpTest, LeakyRelu) {
   TestActivationOp<float>("LeakyRelu",
                           input_values,
                           [alpha](float x) { return (x >= 0) ? x : alpha * x; },
-                          {{"alpha", alpha}});
+                          {{"alpha", alpha}}, {});
 }
 
 #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
@@ -442,7 +442,7 @@ TEST_F(ActivationOpTest, ThresholdedRelu) {
       "ThresholdedRelu",
       input_values,
       [alpha](float x) { return (x >= alpha) ? x : 0; },
-      {{"alpha", alpha}}, true, 10);
+      {{"alpha", alpha}}, {}, true, 10);
 }
 
 TEST_F(ActivationOpTest, Selu) {
@@ -452,7 +452,7 @@ TEST_F(ActivationOpTest, Selu) {
   TestActivationOp<float>("Selu",
                           input_values,
                           [](float x) { return x <= 0 ? gamma * (alpha * exp(x) - alpha) : gamma * x; },
-                          {{"alpha", alpha}, {"gamma", gamma}});
+                          {{"alpha", alpha}, {"gamma", gamma}}, {});
 }
 
 TEST_F(ActivationOpTest, Selu_Attributes) {
@@ -462,7 +462,7 @@ TEST_F(ActivationOpTest, Selu_Attributes) {
   TestActivationOp<float>("Selu",
                           input_values,
                           [](float x) { return x <= 0 ? gamma * (alpha * exp(x) - alpha) : gamma * x; },
-                          {{"alpha", alpha}, {"gamma", gamma}});
+                          {{"alpha", alpha}, {"gamma", gamma}}, {});
 }
 
 TEST_F(ActivationOpTest, Selu_GH10726) {
@@ -472,7 +472,7 @@ TEST_F(ActivationOpTest, Selu_GH10726) {
   TestActivationOp<float>("Selu",
                           {{1.f, -1.f}},
                           [](float x) { return x <= 0 ? gamma * (alpha * exp(x) - alpha) : gamma * x; },
-                          {{"alpha", alpha}, {"gamma", gamma}});
+                          {{"alpha", alpha}, {"gamma", gamma}}, {});
 }
 
 TEST_F(ActivationOpTest, PRelu) {
@@ -625,7 +625,7 @@ TEST_F(ActivationOpNoInfTest, Softsign) {
 
         return result;
       },
-      {}, false);  // Disable TensorRT because result mismatches
+      {}, {}, false);  // Disable TensorRT because result mismatches
 }
 
 #if defined(ENABLE_TRAINING_OPS)
@@ -695,5 +695,35 @@ TEST(LeakyReluGradInferenceTest, Basic) {
 }
 #endif
 
+// Exclude DNNL from this test because the DNNL Gelu op does not appear to check the domain when selecting a kernel,
+// so it would run its own Gelu implementation, which covers only part of the standard Gelu-20 op.
+// [TODO] Temporarily skip this test for OpenVINO to avoid an exception caused by mishandling of the
+// approximate attribute. Re-enable it once the issue is fixed.
+#if !defined(USE_DNNL) && !defined(USE_QNN) && !defined(USE_OPENVINO)
+TEST_F(ActivationOpTest, ONNX_Gelu) {
+  TestActivationOp<float>(
+      "Gelu",
+      input_values,
+      [](float x) { return 0.5 * x * (1 + erf(x * M_SQRT1_2)); }, {},
+      {{"approximate", "none"}}, true, 20);
+
+  TestActivationOp<float>(
+      "Gelu",
+      input_values,
+      [](float x) { return 0.5 * x * (1 + erf(x * M_SQRT1_2)); },
+      {},
+      {/*default value of approximate attribute is none */}, true, 20);
+
+  TestActivationOp<float>(
+      "Gelu",
+      input_values,
+      [](float x) {
+        return 0.5 * x * (1 + tanh(sqrt(2 / M_PI) * (x + 0.044715 * x * x * x)));
+      },
+      {},
+      {{"approximate", "tanh"}}, true, 20);
+}
+#endif
+
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.h b/onnxruntime/test/providers/cpu/activation/activation_op_test.h
index b5ec1402584fb..9a74d763a13e3 100644
--- a/onnxruntime/test/providers/cpu/activation/activation_op_test.h
+++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.h
@@ -17,13 +17,16 @@ namespace test {
 template <typename T>
 inline void TestActivationOp(const char* szOp, const std::vector<std::vector<T>>& input_vals_vec,
                              std::function<T(T)> expected_func,
-                             const std::unordered_map<std::string, float> attribs = {},
+                             const std::unordered_map<std::string, float> float_attribs = {},
+                             const std::unordered_map<std::string, std::string> string_attribs = {},
                              bool is_tensorrt_supported = true, int opset_version = 7,
                              const char* domain = kOnnxDomain) {
   for (const std::vector<T>& input_vals : input_vals_vec) {
     OpTester test(szOp, opset_version, domain);
 
-    for (auto attr : attribs) test.AddAttribute<float>(attr.first, attr.second);
+    for (auto attr : float_attribs) test.AddAttribute<float>(attr.first, attr.second);
+    for (auto attr : string_attribs) test.AddAttribute(attr.first, attr.second);
+
     std::vector<int64_t> dims{(int64_t)input_vals.size()};
 
     std::vector<T> expected_vals;
@@ -66,6 +69,11 @@ inline void TestActivationOp(const char* szOp, const std::vector<std::vector<T>>
       test.SetOutputRelErr("Y", .000001f);
     }
 #endif
+
+    if (strcmp(szOp, "QuickGelu") == 0) {
+      test.SetOutputTolerance(0.0001f);
+    }
+
     test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded_providers);
   }
 }
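
The extra map in this signature is what drives the mechanical "{}, {}" edits in activation_op_test.cc above: every existing call site now passes an (empty) string-attribute map between the float attributes and the remaining defaulted arguments. Taking the LeakyRelu call from this patch as the example, the change at a call site is just:

    // Before this patch: float attributes only.
    TestActivationOp<float>("LeakyRelu", input_values,
                            [alpha](float x) { return (x >= 0) ? x : alpha * x; },
                            {{"alpha", alpha}});

    // After this patch: float attributes, then string attributes (empty here).
    TestActivationOp<float>("LeakyRelu", input_values,
                            [alpha](float x) { return (x >= 0) ? x : alpha * x; },
                            {{"alpha", alpha}}, {});
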
diff --git a/onnxruntime/test/providers/cpu/generator/random_test.cc b/onnxruntime/test/providers/cpu/generator/random_test.cc
index 16582696a81d4..be049d1cf0ce3 100644
--- a/onnxruntime/test/providers/cpu/generator/random_test.cc
+++ b/onnxruntime/test/providers/cpu/generator/random_test.cc
@@ -36,7 +36,8 @@ TEST(Random, RandomNormal2DDouble) {
 
   // The expected_output is generated using std lib, which is used by CPU kernel only.
   // So we need to exclude other EPs here. Ditto for other places.
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider});
 }
 
 void RunRandomNormalLike3DFloat(bool infer_dtype = false) {
@@ -72,7 +73,8 @@ void RunRandomNormalLike3DFloat(bool infer_dtype = false) {
   test.AddOutput<float>("Y", dims, expected_output);
 
   // TensorRT does not support manual seed overrides and there will be result mismatch
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider});
 }
 
 TEST(Random, RandomNormalLike3DDouble) {
@@ -109,7 +111,8 @@ TEST(Random, RandomUniform1DFloat) {
   test.AddOutput<float>("Y", dims, expected_output);
 
   // TensorRT does not support manual seed overrides and there will be result mismatch
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider});
 }
 
 void RunRandomUniformLikeTest(bool infer_dtype = false) {
@@ -142,7 +145,8 @@ void RunRandomUniformLikeTest(bool infer_dtype = false) {
   test.AddOutput<double>("Y", dims, expected_output);
 
   // TensorRT does not support seed parameter and there will be result mismatch
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider});
 }
 
 TEST(Random, RandomUniformLike2DDouble) {
@@ -380,7 +384,7 @@ void RunRandomNormalGpuTest(const std::vector<int64_t> dims, const float mean, c
     test.AddOutput("Y", dims, fp16_data);
   }
 
-  auto output_verifier = [&](const std::vector<OrtValue>& fetches, const std::string& provider_type) {
+  auto output_verifier = [&](const std::vector<OrtValue>& fetches, const std::string& /*provider_type*/) {
     // Only one output, and mean of output values are near attribute mean.
     ASSERT_EQ(fetches.size(), 1u);
     const auto& output_tensor = fetches[0].Get<Tensor>();
@@ -472,7 +476,7 @@ void RunRandomUniformGpuTest(const std::vector<int64_t> dims, const float low, c
     test.AddOutput("Y", dims, fp16_data);
   }
 
-  auto output_verifier = [&](const std::vector<OrtValue>& fetches, const std::string& provider_type) {
+  auto output_verifier = [&](const std::vector<OrtValue>& fetches, const std::string& /*provider_type*/) {
     // Only one output. Each value in output tensoer is between low and high.
     // Mean of output values are near attribute mean of low and high.
     ASSERT_EQ(fetches.size(), 1u);
diff --git a/onnxruntime/test/providers/cpu/math/clip_test.cc b/onnxruntime/test/providers/cpu/math/clip_test.cc
index efb46e86d04e4..b5d5f84df950a 100644
--- a/onnxruntime/test/providers/cpu/math/clip_test.cc
+++ b/onnxruntime/test/providers/cpu/math/clip_test.cc
@@ -182,7 +182,7 @@ TEST(MathOpTest, Clip) {
   run_test(true);
 }
 
-// Use clip between [0, 6] as Relu6 (for some EPs, such as NNAPI)
+// Use clip between [0, 6] as Relu6 to test the optimized path in some EPs, such as NNAPI and CoreML.
 TEST(MathOpTest, Clip_Relu6) {
   // To test NNAPI EP, we need the min/max to be in initializers
   auto run_test = [](bool min_max_are_initializer) {
@@ -208,6 +208,31 @@ TEST(MathOpTest, Clip_Relu6) {
   run_test(true);
 }
 
+// Use clip between [0, inf] as Relu to test the optimized path in some EPs, such as CoreML.
+TEST(MathOpTest, Clip_Relu) {
+  // To test NNAPI EP, we need the min/max to be in initializers
+  auto run_test = [](bool min_max_are_initializer) {
+    OpTester test("Clip", 11);
+
+    std::vector<int64_t> dims{3, 3};
+    test.AddInput<float>("X", dims,
+                         {-1.0f, 0.0f, 1.0f,
+                          -6.0f, 3.5f, 6.0f,
+                          -5.4f, 2.0f, 8.0f});
+    test.AddInput<float>("min", {}, {0.0f}, min_max_are_initializer);
+    test.AddOutput<float>("Y", dims,
+                          {0.0f, 0.0f, 1.0f,
+                           0.0f, 3.5f, 6.0f,
+                           0.0f, 2.0f, 8.0f});
+
+    // TensorRT does not support Clip opset 11 yet.
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
+  };
+
+  run_test(false);
+  run_test(true);
+}
+
 // Use clip between [-1, 1] as Relu1 (for some EPs, such as NNAPI)
 TEST(MathOpTest, Clip_Relu1) {
   // To test NNAPI EP, we need the min/max to be in initializers
diff --git a/onnxruntime/test/providers/cpu/math/einsum_test.cc b/onnxruntime/test/providers/cpu/math/einsum_test.cc
index 4e968d3de6b8a..423ea3f682f4c 100644
--- a/onnxruntime/test/providers/cpu/math/einsum_test.cc
+++ b/onnxruntime/test/providers/cpu/math/einsum_test.cc
@@ -4,6 +4,7 @@
 #include "gtest/gtest.h"
 #include "test/providers/provider_test_utils.h"
 #include "test/common/cuda_op_test_utils.h"
+#include "test/common/trt_op_test_utils.h"
 #include "core/framework/data_types.h"
 #include "core/util/math.h"
 
@@ -50,7 +51,7 @@ TEST(Einsum, ExplicitEinsumAsTransposeOp_2D_input_With_Broadcasting) {
   test.AddAttribute<std::string>("equation", "...i->i...");
   test.AddInput<float>("x", {2, 2}, {1.f, 2.f, 3.f, 4.f});
   test.AddOutput<float>("y", {2, 2}, {1.f, 3.f, 2.f, 4.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(Einsum, ExplicitEinsumAsBatchedTransposeOp_3D_input) {
@@ -58,7 +59,7 @@ TEST(Einsum, ExplicitEinsumAsBatchedTransposeOp_3D_input) {
   test.AddAttribute<std::string>("equation", "...ji->...ij");
   test.AddInput<float>("x", {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
   test.AddOutput<float>("y", {2, 2, 2}, {1.f, 3.f, 2.f, 4.f, 1.f, 3.f, 2.f, 4.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 // Implicit
@@ -75,7 +76,7 @@ TEST(Einsum, ImplicitEinsumAsBatchedTransposeOp_3D_input) {
   test.AddAttribute<std::string>("equation", "...ji");
   test.AddInput<float>("x", {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
   test.AddOutput<float>("y", {2, 2, 2}, {1.f, 3.f, 2.f, 4.f, 1.f, 3.f, 2.f, 4.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 // Theme: Axis/Axes reduction
@@ -102,7 +103,7 @@ TEST(Einsum, ExplicitEinsumAsBatchedReduceOp_3D_input_0) {
   test.AddAttribute<std::string>("equation", "...ji->...j");
   test.AddInput<float>("x", {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
   test.AddOutput<float>("y", {2, 2}, {3.f, 7.f, 3.f, 7.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(Einsum, ExplicitEinsumAsBatchedReduceOp_3D_input_1) {
@@ -110,7 +111,7 @@ TEST(Einsum, ExplicitEinsumAsBatchedReduceOp_3D_input_1) {
   test.AddAttribute<std::string>("equation", "...ji->...");
   test.AddInput<float>("x", {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
   test.AddOutput<float>("y", {2}, {10.f, 10.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 // Implicit
@@ -144,7 +145,7 @@ TEST(Einsum, ExplicitEinsumAsOuterProductWithTransposeOp_Multi_Input) {
   test.AddInput<float>("y", {2}, {3.f, 4.f});
   test.AddInput<float>("z", {2}, {5.f, 6.f});
   test.AddOutput<float>("o", {2, 2, 2}, {15.f, 18.f, 30.f, 36.f, 20.f, 24.f, 40.f, 48.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 // Implicit
@@ -155,7 +156,7 @@ TEST(Einsum, ImplicitEinsumAsOuterProductOp_2D_input) {
   test.AddInput<float>("y", {2}, {3.f, 4.f});
   test.AddInput<float>("z", {2}, {5.f, 6.f});
   test.AddOutput<float>("o", {2, 2, 2}, {15.f, 18.f, 20.f, 24.f, 30.f, 36.f, 40.f, 48.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(Einsum, ImplicitEinsumAsOuterProductOp_Multi_Input) {
@@ -165,7 +166,7 @@ TEST(Einsum, ImplicitEinsumAsOuterProductOp_Multi_Input) {
   test.AddInput<float>("y", {2}, {3.f, 4.f});
   test.AddInput<float>("z", {2}, {5.f, 6.f});
   test.AddOutput<float>("o", {2, 2, 2}, {15.f, 18.f, 20.f, 24.f, 30.f, 36.f, 40.f, 48.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 // Theme: MatMul
 
@@ -233,7 +234,7 @@ TEST(Einsum, ExplicitEinsumAsMatmul_Multi_Input) {
   test.AddInput<float>("y", {2, 2}, {1.f, 2.f, 3.f, 4.f});
   test.AddInput<float>("z", {2, 2}, {1.f, 2.f, 3.f, 4.f});
   test.AddOutput<float>("o", {2, 2}, {37.f, 81.f, 54.f, 118.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(Einsum, ExplicitEinsumAsBatchedMatmul) {
@@ -251,7 +252,7 @@ TEST(Einsum, ExplicitEinsumAsBatchedMatmulWithBroadcasting_0) {
   test.AddInput<float>("x", {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
   test.AddInput<float>("y", {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
   test.AddOutput<float>("o", {2, 2, 2}, {7.f, 10.f, 15.f, 22.f, 7.f, 10.f, 15.f, 22.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(Einsum, ExplicitEinsumAsBatchedMatmulWithBroadcasting_1) {
@@ -260,7 +261,7 @@ TEST(Einsum, ExplicitEinsumAsBatchedMatmulWithBroadcasting_1) {
   test.AddInput<float>("x", {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
   test.AddInput<float>("y", {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
   test.AddOutput<float>("o", {2, 2, 2}, {14.f, 20.f, 30.f, 44.f, 14.f, 20.f, 30.f, 44.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(Einsum, ExplicitEinsumAsMatmul_OutputTransposed) {
@@ -303,7 +304,7 @@ TEST(Einsum, ImplicitEinsumAsMatmul_Multi_Input) {
   test.AddInput<float>("y", {2, 2}, {1.f, 2.f, 3.f, 4.f});
   test.AddInput<float>("z", {2, 2}, {1.f, 2.f, 3.f, 4.f});
   test.AddOutput<float>("o", {2, 2}, {37.f, 54.f, 81.f, 118.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 TEST(Einsum, ImplicitEinsumAsBatchedMatmul) {
   OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
@@ -320,7 +321,7 @@ TEST(Einsum, ImplicitEinsumAsBatchedMatmulWithBroadcasting_0) {
   test.AddInput<float>("x", {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
   test.AddInput<float>("y", {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
   test.AddOutput<float>("o", {2, 2, 2}, {7.f, 10.f, 15.f, 22.f, 7.f, 10.f, 15.f, 22.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(Einsum, ImplicitEinsumAsMatmul_2) {
@@ -343,7 +344,7 @@ TEST(Einsum, DiagonalWithMatmul) {
   test.AddInput<float>("x", {2, 2, 3}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f});
   test.AddInput<float>("y", {3, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f});
   test.AddOutput<float>("o", {3}, {60.f, 72.f, 84.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 // Theme: Diagonal parsing
@@ -354,7 +355,7 @@ TEST(Einsum, ExplicitEinsumAsDiagonalOp) {
   test.AddAttribute<std::string>("equation", "ii->i");
   test.AddInput<float>("x", {2, 2}, {1.f, 2.f, 3.f, 4.f});
   test.AddOutput<float>("o", {2}, {1.f, 4.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(Einsum, ExplicitEinsumAsDiagonalOp_1) {
@@ -362,7 +363,7 @@ TEST(Einsum, ExplicitEinsumAsDiagonalOp_1) {
   test.AddAttribute<std::string>("equation", "iii->i");
   test.AddInput<float>("x", {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
   test.AddOutput<float>("o", {2}, {1.f, 4.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(Einsum, ExplicitEinsumAsDiagonalOpWithAxisReduced) {
@@ -370,7 +371,7 @@ TEST(Einsum, ExplicitEinsumAsDiagonalOpWithAxisReduced) {
   test.AddAttribute<std::string>("equation", "iji->j");
   test.AddInput<float>("x", {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
   test.AddOutput<float>("o", {2}, {3.f, 7.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(Einsum, ExplicitEinsumAsDiagonalOpWithAxisPreserved) {
@@ -378,7 +379,7 @@ TEST(Einsum, ExplicitEinsumAsDiagonalOpWithAxisPreserved) {
   test.AddAttribute<std::string>("equation", "iji->ij");
   test.AddInput<float>("x", {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
   test.AddOutput<float>("o", {2, 2}, {1.f, 3.f, 2.f, 4.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(Einsum, ExplicitEinsumAsDiagonalOpWithTranspose) {
@@ -386,7 +387,7 @@ TEST(Einsum, ExplicitEinsumAsDiagonalOpWithTranspose) {
   test.AddAttribute<std::string>("equation", "iji->ji");
   test.AddInput<float>("x", {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
   test.AddOutput<float>("o", {2, 2}, {1.f, 2.f, 3.f, 4.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 // ROCm doesn't support double
@@ -396,7 +397,7 @@ TEST(Einsum, ExplicitEinsumAsDiagonalOpWithTranspose_double) {
   test.AddAttribute<std::string>("equation", "iji->ji");
   test.AddInput<double>("x", {2, 2, 2}, {1., 2., 3., 4., 1., 2., 3., 4.});
   test.AddOutput<double>("o", {2, 2}, {1., 2., 3., 4.});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 #endif
 
@@ -405,7 +406,7 @@ TEST(Einsum, ExplicitEinsumAsDiagonalOpWithTranspose_int32) {
   test.AddAttribute<std::string>("equation", "iji->ji");
   test.AddInput<int32_t>("x", {2, 2, 2}, {1, 2, 3, 4, 1, 2, 3, 4});
   test.AddOutput<int32_t>("o", {2, 2}, {1, 2, 3, 4});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(Einsum, ExplicitEinsumAsDiagonalOpWithTranspose_int64) {
@@ -413,14 +414,14 @@ TEST(Einsum, ExplicitEinsumAsDiagonalOpWithTranspose_int64) {
   test.AddAttribute<std::string>("equation", "iji->ji");
   test.AddInput<int64_t>("x", {2, 2, 2}, {1, 2, 3, 4, 1, 2, 3, 4});
   test.AddOutput<int64_t>("o", {2, 2}, {1, 2, 3, 4});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 TEST(Einsum, ExplicitEinsumAsBatchedDiagonalOp) {
   OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
   test.AddAttribute<std::string>("equation", "...ii->...i");
   test.AddInput<float>("x", {3, 2, 2}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
   test.AddOutput<float>("o", {3, 2}, {1.f, 4.f, 1.f, 4.f, 1.f, 4.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(Einsum, ExplicitEinsumAsBatchedDiagonalOp_1) {
@@ -428,7 +429,7 @@ TEST(Einsum, ExplicitEinsumAsBatchedDiagonalOp_1) {
   test.AddAttribute<std::string>("equation", "...iij->...j");
   test.AddInput<float>("x", {2, 2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
   test.AddOutput<float>("o", {2, 2}, {4.f, 6.f, 4.f, 6.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 // Implicit (Implicit diagonal ops will sum up diagonal values)
@@ -442,7 +443,7 @@ TEST(Einsum, ImplicitEinsumAsDiagonalOp) {
   test.AddAttribute<std::string>("equation", "ii");
   test.AddInput<float>("x", {2, 2}, {1.f, 2.f, 3.f, 4.f});
   test.AddOutput<float>("o", {}, {5.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(Einsum, ImplicitEinsumAsDiagonalOp_1) {
@@ -455,7 +456,7 @@ TEST(Einsum, ImplicitEinsumAsDiagonalOp_1) {
   test.AddAttribute<std::string>("equation", "iii");
   test.AddInput<float>("x", {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
   test.AddOutput<float>("o", {}, {5.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(Einsum, ImplicitEinsumAsDiagonalOpWithAxisReduced) {
@@ -463,7 +464,7 @@ TEST(Einsum, ImplicitEinsumAsDiagonalOpWithAxisReduced) {
   test.AddAttribute<std::string>("equation", "iji");
   test.AddInput<float>("x", {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
   test.AddOutput<float>("o", {2}, {3.f, 7.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(Einsum, ImplicitEinsumAsBatchedDiagonalOp) {
@@ -471,7 +472,7 @@ TEST(Einsum, ImplicitEinsumAsBatchedDiagonalOp) {
   test.AddAttribute<std::string>("equation", "...ii");
   test.AddInput<float>("x", {2, 1, 2, 2}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
   test.AddOutput<float>("o", {2, 1}, {5.f, 5.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(Einsum, ImplicitEinsumAsBatchedDiagonalOp_1) {
@@ -479,7 +480,7 @@ TEST(Einsum, ImplicitEinsumAsBatchedDiagonalOp_1) {
   test.AddAttribute<std::string>("equation", "...iij");
   test.AddInput<float>("x", {2, 2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
   test.AddOutput<float>("o", {2, 2}, {4.f, 6.f, 4.f, 6.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 // Theme: Scalar inputs and outputs
@@ -491,7 +492,7 @@ TEST(Einsum, ExplicitEinsumAsElementwiseMulOpWithOneScalar) {
   test.AddInput<float>("x", {}, {10.f});
   test.AddInput<float>("y", {2, 2}, {1.f, 2.f, 3.f, 4.f});
   test.AddOutput<float>("o", {2, 2}, {10.f, 20.f, 30.f, 40.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(Einsum, ExplicitEinsumAsElementwiseMulOpWithTwoScalars_Multi_Input) {
@@ -501,7 +502,7 @@ TEST(Einsum, ExplicitEinsumAsElementwiseMulOpWithTwoScalars_Multi_Input) {
   test.AddInput<float>("y", {2, 2}, {1.f, 2.f, 3.f, 4.f});
   test.AddInput<float>("z", {}, {10.f});
   test.AddOutput<float>("o", {2, 2}, {100.f, 200.f, 300.f, 400.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 TEST(Einsum, ExplicitEinsumAsElementwiseMulOpWithAllScalars) {
   OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
@@ -527,7 +528,7 @@ TEST(Einsum, ImplicitEinsumAsElementwiseMulOpWithOneScalar) {
   test.AddInput<float>("x", {}, {10.f});
   test.AddInput<float>("y", {2, 2}, {1.f, 2.f, 3.f, 4.f});
   test.AddOutput<float>("o", {2, 2}, {10.f, 20.f, 30.f, 40.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(Einsum, ImplicitEinsumAsElementwiseMulOpWithThreeScalars_Multi_Input) {
@@ -538,7 +539,7 @@ TEST(Einsum, ImplicitEinsumAsElementwiseMulOpWithThreeScalars_Multi_Input) {
   test.AddInput<float>("c", {}, {10.f});
   test.AddInput<float>("d", {}, {10.f});
   test.AddOutput<float>("o", {2, 2}, {1000.f, 2000.f, 3000.f, 4000.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 TEST(Einsum, ImplicitEinsumAsElementwiseMulOpWithAllScalars) {
   OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
@@ -568,7 +569,7 @@ TEST(Einsum, ExplicitEinsumAsTensorContractionReshapeFinal) {
   test.AddInput<float>("y", {2, 2}, {1.f, 2.f, -6.f, 2.f});
   test.AddInput<float>("z", {2, 2}, {3.f, 4.f, 5.f, 6.f});
   test.AddOutput<float>("o", {2, 2, 2}, {63.f, -132.f, 63.f, -132.f, 63.f, -132.f, 63.f, -132.f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(Einsum, ExplicitEinsumAsTensorContractionReshapeLeft) {
@@ -720,7 +721,7 @@ TEST(Einsum, ExplicitEinsumAsDiagonalOp_Half) {
   ConvertFloatToMLFloat16(output_f.data(), output.data(), 2);
   test.AddInput<MLFloat16>("x", {2, 2}, input_x);
   test.AddOutput<MLFloat16>("o", {2}, output);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(Einsum, ExplicitEinsumAsElementwiseMulOpWithOneScalar_Half) {
@@ -741,7 +742,7 @@ TEST(Einsum, ExplicitEinsumAsElementwiseMulOpWithOneScalar_Half) {
   test.AddInput<MLFloat16>("x", {}, input_x);
   test.AddInput<MLFloat16>("y", {2, 2}, input_y);
   test.AddOutput<MLFloat16>("o", {2, 2}, output);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(Einsum, ExplicitEinsumAsTensorContraction_Half) {
@@ -2093,7 +2094,7 @@ TEST_P(EinsumTransposeMatMulThreeInputsTest, EinsumTransposeMatMulThreeInputsTes
   std::vector<int64_t> v1(tst.shape.begin(), tst.shape.end());
   std::vector<float> v2(tst.expected.begin(), tst.expected.end());
   test.AddOutput<float>("o", v1, v2);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 INSTANTIATE_TEST_SUITE_P(EinsumTransposeMatMulThreeInputsTests, EinsumTransposeMatMulThreeInputsTest, testing::ValuesIn(case1));
diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
index 5e746ed0c62d4..c02486a2ec26f 100644
--- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
+++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
@@ -5,8 +5,11 @@
 #include "test/providers/provider_test_utils.h"
 #include "test/util/include/default_providers.h"
 #include "test/common/dnnl_op_test_utils.h"
+#include "test/common/cuda_op_test_utils.h"
+#include "test/common/trt_op_test_utils.h"
 #include "core/util/math.h"
 #include <algorithm>
+#include <limits>
 #include <math.h>
 
 namespace onnxruntime {
@@ -786,13 +789,20 @@ TEST(MathOpTest, Sqrt_Float) {
   test.Run();
 }
 
-#if defined(USE_DNNL)
+#if defined(USE_DNNL) || defined(USE_CUDA)
 TEST(MathOpTest, Sqrt_bfloat16) {
 #ifdef USE_DNNL
   if (!DnnlHasBF16Support()) {
     LOGS_DEFAULT(WARNING) << "Hardware does NOT support BF16";
     return;
   }
+#endif
+#ifdef USE_CUDA
+  int min_cuda_architecture = 530;
+  if (!HasCudaEnvironment(min_cuda_architecture)) {
+    LOGS_DEFAULT(WARNING) << "Hardware does NOT support BF16";
+    return;
+  }
 #endif
   OpTester test_bf16("Sqrt", 13);  // only version 13 support bf16 for sqrt
   test_bf16.AddInput<BFloat16>("X", {2, 3},
@@ -804,6 +814,9 @@ TEST(MathOpTest, Sqrt_bfloat16) {
   std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
 #if defined(USE_DNNL)
   execution_providers.push_back(DefaultDnnlExecutionProvider());
+#endif
+#ifdef USE_CUDA
+  execution_providers.push_back(DefaultCudaExecutionProvider());
 #endif
   test_bf16.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
 }
@@ -1359,7 +1372,8 @@ static void TestSumMultipleInputsNoBroadcasting(size_t num_inputs, const TensorS
 
   test.AddOutput<element_type>("sum", dims, expected_output_data);
 
-  test.Run();
+  // TRT EP segmentation fault on A100
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(MathOpTest, SumMultipleInputsNoBroadcasting) {
@@ -1495,6 +1509,34 @@ TEST(MathOpTest, Min_12_Float_2_Input) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});  // TensorRT: Input batch size is inconsistent
 }
 
+TEST(MathOpTest, Min_12_Float_Nan) {
+  OpTester test("Min", 12);
+  test.AddInput<float>("data_2", {3, 3},
+                       {std::numeric_limits<float>::quiet_NaN(),
+                        std::numeric_limits<float>::quiet_NaN(),
+                        std::numeric_limits<float>::quiet_NaN(),
+                        -0.5f, 0.0f, -2.0f,
+                        0.5f, 0.0f, 2.0f});
+  test.AddInput<float>("data_1", {3, 1},
+                       {0.0f, -1.0f, 1.0f});
+  test.AddOutput<float>("min", {3, 3},
+                        {std::numeric_limits<float>::quiet_NaN(),
+                         std::numeric_limits<float>::quiet_NaN(),
+                         std::numeric_limits<float>::quiet_NaN(),
+                         -1.0f, -1.0f, -2.0f,
+                         0.5f, 0.0f, 1.0f});
+  if (nullptr != DefaultCpuExecutionProvider().get()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCpuExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+  if (nullptr != DefaultCudaExecutionProvider().get()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+}
+
 TEST(MathOpTest, Min_12_Double) {
   OpTester test("Min", 12);
   test.AddInput<double>("data_0", {1, 3},
@@ -1512,6 +1554,34 @@ TEST(MathOpTest, Min_12_Double) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  // TensorRT: Input batch size is inconsistent
 }
 
+TEST(MathOpTest, Min_12_Double_Nan) {
+  OpTester test("Min", 12);
+  test.AddInput<double>("data_2", {3, 3},
+                        {std::numeric_limits<double>::quiet_NaN(),
+                         std::numeric_limits<double>::quiet_NaN(),
+                         std::numeric_limits<double>::quiet_NaN(),
+                         -0.5, 0.0, -2.0,
+                         0.5, 0.0, 2.0});
+  test.AddInput<double>("data_1", {3, 1},
+                        {0.0, -1.0, 1.0});
+  test.AddOutput<double>("min", {3, 3},
+                         {std::numeric_limits<double>::quiet_NaN(),
+                          std::numeric_limits<double>::quiet_NaN(),
+                          std::numeric_limits<double>::quiet_NaN(),
+                          -1.0, -1.0, -2.0,
+                          0.5, 0.0, 1.0});
+  if (nullptr != DefaultCpuExecutionProvider().get()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCpuExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+  if (nullptr != DefaultCudaExecutionProvider().get()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+}
+
 TEST(MathOpTest, Min_12_Int32) {
   OpTester test("Min", 12);
   test.AddInput<int32_t>("data_0", {1, 3},
@@ -1618,6 +1688,7 @@ TEST(MathOpTest, Min_12_MLFLoat16_Scalar1) {
                             MakeMLFloat16({-10.f, -10.f, -10.f}));
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  // TensorRT: Input batch size is inconsistent
 }
+
 TEST(MathOpTest, Max_6) {
   OpTester test("Max", 6);
   std::vector<int64_t> dims{3, 3};
@@ -1706,6 +1777,34 @@ TEST(MathOpTest, Max_12_Float) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});  // TensorRT: Input batch size is inconsistent
 }
 
+TEST(MathOpTest, Max_12_Float_Nan) {
+  OpTester test("Max", 12);
+  test.AddInput<float>("data_2", {3, 3},
+                       {std::numeric_limits<float>::quiet_NaN(),
+                        std::numeric_limits<float>::quiet_NaN(),
+                        std::numeric_limits<float>::quiet_NaN(),
+                        -0.5f, 0.0f, -2.0f,
+                        0.5f, 0.0f, 2.0f});
+  test.AddInput<float>("data_1", {3, 1},
+                       {0.0f, -1.0f, 1.0f});
+  test.AddOutput<float>("max", {3, 3},
+                        {std::numeric_limits<float>::quiet_NaN(),
+                         std::numeric_limits<float>::quiet_NaN(),
+                         std::numeric_limits<float>::quiet_NaN(),
+                         -0.5f, 0.0f, -1.0f,
+                         1.0f, 1.0f, 2.0f});
+  if (nullptr != DefaultCpuExecutionProvider().get()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCpuExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+  if (nullptr != DefaultCudaExecutionProvider().get()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+}
+
 TEST(MathOpTest, Max_12_Double) {
   OpTester test("Max", 12);
   test.AddInput<double>("data_0", {1, 3},
@@ -1723,6 +1822,34 @@ TEST(MathOpTest, Max_12_Double) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  // TensorRT: Input batch size is inconsistent
 }
 
+TEST(MathOpTest, Max_12_Double_Nan) {
+  OpTester test("Max", 12);
+  test.AddInput<double>("data_2", {3, 3},
+                        {std::numeric_limits<double>::quiet_NaN(),
+                         std::numeric_limits<double>::quiet_NaN(),
+                         std::numeric_limits<double>::quiet_NaN(),
+                         -0.5, 0.0, -2.0,
+                         0.5, 0.0, 2.0});
+  test.AddInput<double>("data_1", {3, 1},
+                        {0.0, -1.0, 1.0});
+  test.AddOutput<double>("max", {3, 3},
+                         {std::numeric_limits<double>::quiet_NaN(),
+                          std::numeric_limits<double>::quiet_NaN(),
+                          std::numeric_limits<double>::quiet_NaN(),
+                          -0.5, 0.0, -1.0,
+                          1.0, 1.0, 2.0});
+  if (nullptr != DefaultCpuExecutionProvider().get()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCpuExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+  if (nullptr != DefaultCudaExecutionProvider().get()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+}
+
 TEST(MathOpTest, Max_12_Int32) {
   OpTester test("Max", 12);
   test.AddInput<int32_t>("data_0", {1, 3},
@@ -2619,7 +2746,7 @@ TEST(MathOpTest, Mean_8) {
 #endif
 
 template <float (&op)(float value) MATH_NO_EXCEPT>
-void TrigFloatTest(OpTester& test, std::initializer_list<float> input) {
+void TrigFloatTest(OpTester& test, std::initializer_list<float> input, float abs_error = -1.0f) {
   std::vector<int64_t> dims{static_cast<int64_t>(input.size())};
 
   std::vector<float> output;
@@ -2628,6 +2755,11 @@ void TrigFloatTest(OpTester& test, std::initializer_list<float> input) {
 
   test.AddInput<float>("X", dims, input);
   test.AddOutput<float>("Y", dims, output);
+
+  if (abs_error >= 0.0f) {
+    test.SetOutputTolerance(abs_error);
+  }
+
   test.Run();
 }
 
@@ -2697,6 +2829,7 @@ TEST(MathOpTest, CosFloat16) {
     TrigFloat16Test<::cosf>(test, {1.1f, -1.1f, 2.2f, -2.2f});
   }
 }
+
 TEST(MathOpTest, Tan) {
   OpTester test("Tan");
   TrigFloatTest<::tanf>(test, {-100.0f, -50.0f, 0.0f, 50.0f, 100.0f});
@@ -2704,7 +2837,8 @@ TEST(MathOpTest, Tan) {
 
 TEST(MathOpTest, Asin) {
   OpTester test("Asin");
-  TrigFloatTest<::asinf>(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f});
+  float abs_error = DefaultDmlExecutionProvider().get() != nullptr ? 0.0001f : -1.0f;
+  TrigFloatTest<::asinf>(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f}, abs_error);
 }
 
 TEST(MathOpTest, Acos) {
diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc
index bf089e083d67e..1a542fb67418e 100644
--- a/onnxruntime/test/providers/cpu/math/gemm_test.cc
+++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc
@@ -277,28 +277,35 @@ class GemmOpTypedTests : public ::testing::Test {
 // On CPUs without fp16 instructions the tests will output a warning:
 // "registered execution providers CPUExecutionProvider were unable to run the model"
 // , then they will still pass.
-using GemmOpTypedTestsTypes = ::testing::Types<float, double, MLFloat16>;
+using GemmOpTypedTestsTypes = ::testing::Types<float, double>;
 TYPED_TEST_SUITE(GemmOpTypedTests, GemmOpTypedTestsTypes);
 
 TYPED_TEST(GemmOpTypedTests, TestGemmScalarBroadcast) {
-  OpTester test("Gemm");
+  auto run_test = [](bool b_is_initializer, bool c_is_initializer) {
+    OpTester test("Gemm");
 
-  test.AddAttribute("transA", (int64_t)0);
-  test.AddAttribute("transB", (int64_t)0);
-  test.AddAttribute("alpha", 1.0f);
-  test.AddAttribute("beta", 1.0f);
+    test.AddAttribute("transA", (int64_t)0);
+    test.AddAttribute("transB", (int64_t)0);
+    test.AddAttribute("alpha", 1.0f);
+    test.AddAttribute("beta", 1.0f);
 
-  test.AddInput<TypeParam>("A", {2, 4},
-                           {static_cast<TypeParam>(1.0f), static_cast<TypeParam>(2.0f), static_cast<TypeParam>(3.0f), static_cast<TypeParam>(4.0f),
-                            static_cast<TypeParam>(-1.0f), static_cast<TypeParam>(-2.0f), static_cast<TypeParam>(-3.0f), static_cast<TypeParam>(-4.0f)});
-  test.AddInput<TypeParam>("B", {4, 3}, std::vector<TypeParam>(12, static_cast<TypeParam>(1.0f)));
-  test.AddInput<TypeParam>("C", {1}, std::vector<TypeParam>{static_cast<TypeParam>(1.0f)});
-  test.AddOutput<TypeParam>("Y", {2, 3},
-                            {static_cast<TypeParam>(11.0f), static_cast<TypeParam>(11.0f), static_cast<TypeParam>(11.0f),
-                             static_cast<TypeParam>(-9.0f), static_cast<TypeParam>(-9.0f), static_cast<TypeParam>(-9.0f)});
-  test.Config(run_with_tunable_op)
-      .RunWithConfig();
+    test.AddInput<TypeParam>("A", {2, 4},
+                             {static_cast<TypeParam>(1.0f), static_cast<TypeParam>(2.0f), static_cast<TypeParam>(3.0f), static_cast<TypeParam>(4.0f),
+                              static_cast<TypeParam>(-1.0f), static_cast<TypeParam>(-2.0f), static_cast<TypeParam>(-3.0f), static_cast<TypeParam>(-4.0f)});
+    test.AddInput<TypeParam>("B", {4, 3}, std::vector<TypeParam>(12, static_cast<TypeParam>(1.0f)), b_is_initializer);
+    test.AddInput<TypeParam>("C", {1}, std::vector<TypeParam>{static_cast<TypeParam>(1.0f)}, c_is_initializer);
+    test.AddOutput<TypeParam>("Y", {2, 3},
+                              {static_cast<TypeParam>(11.0f), static_cast<TypeParam>(11.0f), static_cast<TypeParam>(11.0f),
+                               static_cast<TypeParam>(-9.0f), static_cast<TypeParam>(-9.0f), static_cast<TypeParam>(-9.0f)});
+    test.Config(run_with_tunable_op)
+        .RunWithConfig();
+  };
+
+  run_test(false, false);
+  // CoreML EP requires weight and bias to be initializers
+  run_test(true, true);
 }
+
 TYPED_TEST(GemmOpTypedTests, TestGemm2DBroadcast_2) {
   OpTester test("Gemm");
 
diff --git a/onnxruntime/test/providers/cpu/math/logsoftmax_test.cc b/onnxruntime/test/providers/cpu/math/logsoftmax_test.cc
index 273503e7bf6af..f057e4a071bd9 100644
--- a/onnxruntime/test/providers/cpu/math/logsoftmax_test.cc
+++ b/onnxruntime/test/providers/cpu/math/logsoftmax_test.cc
@@ -15,7 +15,8 @@ static void RunTest(const std::vector<float>& x_vals,
                     int64_t axis = 1,
                     bool is_tensorrt_supported = true,
                     OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess,
-                    const std::string& error_msg = "") {
+                    const std::string& error_msg = "",
+                    float tolerance = 0.0f) {
   OpTester tester("LogSoftmax", opset);
 
   if (opset < 13) {
@@ -31,6 +32,10 @@ static void RunTest(const std::vector<float>& x_vals,
   tester.AddInput("X", dimensions, x_vals);
   tester.AddOutput("Y", dimensions, expected_vals);
 
+  if (tolerance != 0.0f) {
+    tester.SetOutputAbsErr("Y", tolerance);
+  }
+
   std::unordered_set<std::string> excluded_providers;
   if (!is_tensorrt_supported) {
     excluded_providers.insert(kTensorrtExecutionProvider);
@@ -62,7 +67,7 @@ TEST(LogSoftmaxOperator, LargeNumber) {
                                       -3.4401896f, -2.4401896f, -1.44018972f, -0.44018969f};
   std::vector<int64_t> dimensions = {2, 4};
 
-  RunTest(x_vals, expected_vals, dimensions);
+  RunTest(x_vals, expected_vals, dimensions, 7, 1, true, OpTester::ExpectResult::kExpectSuccess, "", 0.0005f);
 }
 
 // np.random.seed(123)   # Use a seed so we can replicate the input and expected values here and in python
diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc
index 8128c170c5211..aa752ed7308c6 100644
--- a/onnxruntime/test/providers/cpu/model_tests.cc
+++ b/onnxruntime/test/providers/cpu/model_tests.cc
@@ -3,6 +3,12 @@
 
 #include <iostream>
 #include <iterator>
+#include <string>
+#include <codecvt>
+#include <locale>
+#include <filesystem>
+#include <utility>
+#include <unordered_map>
 #include <gtest/gtest.h>
 
 #include "core/session/onnxruntime_c_api.h"
@@ -15,9 +21,6 @@
 #include <core/platform/path_lib.h>
 #include "default_providers.h"
 #include "test/onnx/TestCase.h"
-#include <string>
-#include <codecvt>
-#include <locale>
 
 #ifdef USE_DNNL
 #include "core/providers/dnnl/dnnl_provider_factory.h"
@@ -47,7 +50,6 @@
 #include "test/compare_ortvalue.h"
 #include "test/onnx/heap_buffer.h"
 #include "test/onnx/onnx_model_info.h"
-#include "test/onnx/callback.h"
 #include "test/onnx/testcase_request.h"
 
 extern std::unique_ptr<Ort::Env> ort_env;
@@ -90,27 +92,13 @@ TEST_P(ModelTest, Run) {
   // when cuda or openvino is enabled, set it to a larger value for resolving random MNIST test failure
   if (model_path.find(ORT_TSTR("_MNIST")) > 0) {
     if (provider_name == "cuda" || provider_name == "openvino") {
+      per_sample_tolerance = 2.5e-2;
       relative_per_sample_tolerance = 1e-2;
     }
   }
 
   std::unique_ptr<OnnxModelInfo> model_info = std::make_unique<OnnxModelInfo>(model_path.c_str());
 
-#if defined(__linux__)
-  // ORT enables TF32 in GEMM for A100. TF32 will cause precsion loss and fail this test.
-  if (HasCudaEnvironment(800) && provider_name == "cuda") {
-    per_sample_tolerance = 1e-1;
-    if (model_path.find(ORT_TSTR("SSD")) > 0 ||
-        model_path.find(ORT_TSTR("ssd")) > 0 ||
-        model_path.find(ORT_TSTR("yolov3")) > 0 ||
-        model_path.find(ORT_TSTR("mask_rcnn")) > 0 ||
-        model_path.find(ORT_TSTR("FNS")) > 0) {
-      SkipTest("Skipping SSD test for big tolearance failure or other errors");
-      return;
-    }
-  }
-#endif
-
   if (model_info->HasDomain(ONNX_NAMESPACE::AI_ONNX_TRAINING_DOMAIN) ||
       model_info->HasDomain(ONNX_NAMESPACE::AI_ONNX_PREVIEW_TRAINING_DOMAIN)) {
     SkipTest("it has the training domain. No pipeline should need to run these tests.");
@@ -190,12 +178,14 @@ TEST_P(ModelTest, Run) {
         ASSERT_ORT_STATUS_OK(OrtApis::CreateCUDAProviderOptions(&cuda_options));
         std::unique_ptr<OrtCUDAProviderOptionsV2, decltype(&OrtApis::ReleaseCUDAProviderOptions)> rel_cuda_options(
             cuda_options, &OrtApis::ReleaseCUDAProviderOptions);
-        std::vector<const char*> keys{"device_id"};
 
+        std::vector<const char*> keys{"device_id", "use_tf32"};
         std::vector<const char*> values;
         std::string device_id = Env::Default().GetEnvironmentVar("ONNXRUNTIME_TEST_GPU_DEVICE_ID");
         values.push_back(device_id.empty() ? "0" : device_id.c_str());
-        ASSERT_ORT_STATUS_OK(OrtApis::UpdateCUDAProviderOptions(cuda_options, keys.data(), values.data(), 1));
+        values.push_back("0");
+        ASSERT_ORT_STATUS_OK(OrtApis::UpdateCUDAProviderOptions(cuda_options, keys.data(), values.data(), 2));
+
         ortso.AppendExecutionProvider_CUDA_V2(*cuda_options);
       } else if (provider_name == "rocm") {
         OrtROCMProviderOptions ep_options;
@@ -227,6 +217,14 @@ TEST_P(ModelTest, Run) {
         ASSERT_ORT_STATUS_OK(OrtApis::CreateCUDAProviderOptions(&cuda_options));
         std::unique_ptr<OrtCUDAProviderOptionsV2, decltype(&OrtApis::ReleaseCUDAProviderOptions)> rel_cuda_options(
             cuda_options, &OrtApis::ReleaseCUDAProviderOptions);
+
+        std::vector<const char*> keys{"device_id", "use_tf32"};
+        std::vector<const char*> values;
+        std::string device_id = Env::Default().GetEnvironmentVar("ONNXRUNTIME_TEST_GPU_DEVICE_ID");
+        values.push_back(device_id.empty() ? "0" : device_id.c_str());
+        values.push_back("0");
+        ASSERT_ORT_STATUS_OK(OrtApis::UpdateCUDAProviderOptions(cuda_options, keys.data(), values.data(), 2));
+
         ortso.AppendExecutionProvider_CUDA_V2(*cuda_options);
       } else if (provider_name == "migraphx") {
         OrtMIGraphXProviderOptions ep_options;
@@ -378,46 +376,46 @@ TEST_P(ModelTest, Run) {
 }
 
 using ORT_STRING_VIEW = std::basic_string_view<ORTCHAR_T>;
-static ORT_STRING_VIEW opset7 = ORT_TSTR("opset7");
-static ORT_STRING_VIEW opset8 = ORT_TSTR("opset8");
-static ORT_STRING_VIEW opset9 = ORT_TSTR("opset9");
-static ORT_STRING_VIEW opset10 = ORT_TSTR("opset10");
-static ORT_STRING_VIEW opset11 = ORT_TSTR("opset11");
-static ORT_STRING_VIEW opset12 = ORT_TSTR("opset12");
-static ORT_STRING_VIEW opset13 = ORT_TSTR("opset13");
-static ORT_STRING_VIEW opset14 = ORT_TSTR("opset14");
-static ORT_STRING_VIEW opset15 = ORT_TSTR("opset15");
-static ORT_STRING_VIEW opset16 = ORT_TSTR("opset16");
-static ORT_STRING_VIEW opset17 = ORT_TSTR("opset17");
-static ORT_STRING_VIEW opset18 = ORT_TSTR("opset18");
+static constexpr ORT_STRING_VIEW opset7 = ORT_TSTR("opset7");
+static constexpr ORT_STRING_VIEW opset8 = ORT_TSTR("opset8");
+static constexpr ORT_STRING_VIEW opset9 = ORT_TSTR("opset9");
+static constexpr ORT_STRING_VIEW opset10 = ORT_TSTR("opset10");
+static constexpr ORT_STRING_VIEW opset11 = ORT_TSTR("opset11");
+static constexpr ORT_STRING_VIEW opset12 = ORT_TSTR("opset12");
+static constexpr ORT_STRING_VIEW opset13 = ORT_TSTR("opset13");
+static constexpr ORT_STRING_VIEW opset14 = ORT_TSTR("opset14");
+static constexpr ORT_STRING_VIEW opset15 = ORT_TSTR("opset15");
+static constexpr ORT_STRING_VIEW opset16 = ORT_TSTR("opset16");
+static constexpr ORT_STRING_VIEW opset17 = ORT_TSTR("opset17");
+static constexpr ORT_STRING_VIEW opset18 = ORT_TSTR("opset18");
 // TODO: enable opset19 tests
-// static ORT_STRING_VIEW opset19 = ORT_TSTR("opset19");
+// static constexpr ORT_STRING_VIEW opset19 = ORT_TSTR("opset19");
 
-static ORT_STRING_VIEW provider_name_cpu = ORT_TSTR("cpu");
-static ORT_STRING_VIEW provider_name_tensorrt = ORT_TSTR("tensorrt");
+static constexpr ORT_STRING_VIEW provider_name_cpu = ORT_TSTR("cpu");
+static constexpr ORT_STRING_VIEW provider_name_tensorrt = ORT_TSTR("tensorrt");
 #ifdef USE_MIGRAPHX
-static ORT_STRING_VIEW provider_name_migraphx = ORT_TSTR("migraphx");
+static constexpr ORT_STRING_VIEW provider_name_migraphx = ORT_TSTR("migraphx");
 #endif
-static ORT_STRING_VIEW provider_name_openvino = ORT_TSTR("openvino");
-static ORT_STRING_VIEW provider_name_cuda = ORT_TSTR("cuda");
+static constexpr ORT_STRING_VIEW provider_name_openvino = ORT_TSTR("openvino");
+static constexpr ORT_STRING_VIEW provider_name_cuda = ORT_TSTR("cuda");
 #ifdef USE_ROCM
-static ORT_STRING_VIEW provider_name_rocm = ORT_TSTR("rocm");
+static constexpr ORT_STRING_VIEW provider_name_rocm = ORT_TSTR("rocm");
 #endif
-static ORT_STRING_VIEW provider_name_dnnl = ORT_TSTR("dnnl");
+static constexpr ORT_STRING_VIEW provider_name_dnnl = ORT_TSTR("dnnl");
 // For any non-Android system, NNAPI will only be used for ort model converter
 #if defined(USE_NNAPI) && defined(__ANDROID__)
-static ORT_STRING_VIEW provider_name_nnapi = ORT_TSTR("nnapi");
+static constexpr ORT_STRING_VIEW provider_name_nnapi = ORT_TSTR("nnapi");
 #endif
 #ifdef USE_RKNPU
-static ORT_STRING_VIEW provider_name_rknpu = ORT_TSTR("rknpu");
+static constexpr ORT_STRING_VIEW provider_name_rknpu = ORT_TSTR("rknpu");
 #endif
 #ifdef USE_ACL
-static ORT_STRING_VIEW provider_name_acl = ORT_TSTR("acl");
+static constexpr ORT_STRING_VIEW provider_name_acl = ORT_TSTR("acl");
 #endif
 #ifdef USE_ARMNN
-static ORT_STRING_VIEW provider_name_armnn = ORT_TSTR("armnn");
+static constexpr ORT_STRING_VIEW provider_name_armnn = ORT_TSTR("armnn");
 #endif
-static ORT_STRING_VIEW provider_name_dml = ORT_TSTR("dml");
+static constexpr ORT_STRING_VIEW provider_name_dml = ORT_TSTR("dml");
 
 ::std::vector<::std::basic_string<ORTCHAR_T>> GetParameterStrings() {
   // Map key is provider name(CPU, CUDA, etc). Value is the ONNX node tests' opsets to run.
@@ -615,9 +613,10 @@ ::std::vector<::std::basic_string<ORTCHAR_T>> GetParameterStrings() {
       ORT_TSTR("SSD"),                 // needs to run symbolic shape inference shape first
       ORT_TSTR("size")                 // INVALID_ARGUMENT: Cannot find binding of given name: x
   };
-  std::vector<std::basic_string<ORTCHAR_T>> paths;
+  std::vector<std::filesystem::path> paths;
 
   for (std::pair<ORT_STRING_VIEW, std::vector<ORT_STRING_VIEW>> kvp : provider_names) {
+    const ORT_STRING_VIEW provider_name = kvp.first;
     // Setup ONNX node tests. The test data is preloaded on our CI build machines.
 #if !defined(_WIN32)
     ORT_STRING_VIEW node_test_root_path = ORT_TSTR("/data/onnx");
@@ -625,7 +624,10 @@ ::std::vector<::std::basic_string<ORTCHAR_T>> GetParameterStrings() {
     ORT_STRING_VIEW node_test_root_path = ORT_TSTR("c:\\local\\data\\onnx");
 #endif
     for (auto p : kvp.second) {
-      paths.push_back(ConcatPathComponent(node_test_root_path, p));
+      // The TensorRT EP isn't expected to pass all ONNX node tests, so exclude those here and run only the model tests.
+      if (provider_name != provider_name_tensorrt) {
+        paths.push_back(ConcatPathComponent(node_test_root_path, p));
+      }
     }
 
     // Same as the above, except this one is for large models
@@ -644,7 +646,6 @@ ::std::vector<::std::basic_string<ORTCHAR_T>> GetParameterStrings() {
     }
 #endif
 
-    ORT_STRING_VIEW provider_name = kvp.first;
     std::unordered_set<std::basic_string<ORTCHAR_T>> all_disabled_tests(std::begin(immutable_broken_tests),
                                                                         std::end(immutable_broken_tests));
     if (provider_name == provider_name_cuda) {
@@ -699,45 +700,45 @@ ::std::vector<::std::basic_string<ORTCHAR_T>> GetParameterStrings() {
     all_disabled_tests.insert(ORT_TSTR("fp16_tiny_yolov2"));
 
     while (!paths.empty()) {
-      std::basic_string<ORTCHAR_T> node_data_root_path = paths.back();
+      std::filesystem::path node_data_root_path = paths.back();
       paths.pop_back();
-      std::basic_string<ORTCHAR_T> my_dir_name = GetLastComponent(node_data_root_path);
-      ORT_TRY {
-        LoopDir(node_data_root_path, [&](const ORTCHAR_T* filename, OrtFileType f_type) -> bool {
-          if (filename[0] == ORT_TSTR('.'))
-            return true;
-          if (f_type == OrtFileType::TYPE_DIR) {
-            std::basic_string<PATH_CHAR_TYPE> p = ConcatPathComponent(node_data_root_path, filename);
-            paths.push_back(p);
-            return true;
-          }
-          std::basic_string<PATH_CHAR_TYPE> filename_str = filename;
-          if (!HasExtensionOf(filename_str, ORT_TSTR("onnx")))
-            return true;
-          std::basic_string<PATH_CHAR_TYPE> test_case_name = my_dir_name;
-          if (test_case_name.compare(0, 5, ORT_TSTR("test_")) == 0)
-            test_case_name = test_case_name.substr(5);
-          if (all_disabled_tests.find(test_case_name) != all_disabled_tests.end())
-            return true;
+      if (!std::filesystem::exists(node_data_root_path) || !std::filesystem::is_directory(node_data_root_path)) {
+        continue;
+      }
+      for (auto const& dir_entry : std::filesystem::directory_iterator(node_data_root_path)) {
+        if (dir_entry.is_directory()) {
+          paths.push_back(dir_entry.path());
+          continue;
+        }
+        const std::filesystem::path& path = dir_entry.path();
+        if (!path.has_filename() || path.filename().native().compare(0, 1, ORT_TSTR(".")) == 0) {
+          // Ignore hidden files.
+          continue;
+        }
+        if (path.filename().extension().compare(ORT_TSTR(".onnx")) != 0) {
+          // Ignore files that are not ONNX models.
+          continue;
+        }
+        std::basic_string<PATH_CHAR_TYPE> test_case_name = path.parent_path().filename().native();
+        if (test_case_name.compare(0, 5, ORT_TSTR("test_")) == 0)
+          test_case_name = test_case_name.substr(5);
+        if (all_disabled_tests.find(test_case_name) != all_disabled_tests.end())
+          continue;
 
 #ifdef DISABLE_ML_OPS
-          auto starts_with = [](const std::basic_string<PATH_CHAR_TYPE>& find_in,
-                                const std::basic_string<PATH_CHAR_TYPE>& find_what) {
-            return find_in.compare(0, find_what.size(), find_what) == 0;
-          };
-          if (starts_with(test_case_name, ORT_TSTR("XGBoost_")) || starts_with(test_case_name, ORT_TSTR("coreml_")) ||
-              starts_with(test_case_name, ORT_TSTR("scikit_")) || starts_with(test_case_name, ORT_TSTR("libsvm_"))) {
-            return true;
-          }
+        auto starts_with = [](const std::basic_string<PATH_CHAR_TYPE>& find_in,
+                              const std::basic_string<PATH_CHAR_TYPE>& find_what) {
+          return find_in.compare(0, find_what.size(), find_what) == 0;
+        };
+        if (starts_with(test_case_name, ORT_TSTR("XGBoost_")) || starts_with(test_case_name, ORT_TSTR("coreml_")) ||
+            starts_with(test_case_name, ORT_TSTR("scikit_")) || starts_with(test_case_name, ORT_TSTR("libsvm_"))) {
+          continue;
+        }
 #endif
-          std::basic_ostringstream<PATH_CHAR_TYPE> oss;
-          oss << provider_name << ORT_TSTR("_") << ConcatPathComponent(node_data_root_path, filename_str);
-          v.emplace_back(oss.str());
-          return true;
-        });
+        std::basic_ostringstream<PATH_CHAR_TYPE> oss;
+        oss << provider_name << ORT_TSTR("_") << path.native();
+        v.emplace_back(oss.str());
       }
-      ORT_CATCH(const std::exception&) {
-      }  // ignore non-exist dir
     }
   }
   return v;
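
The rewritten loop above drops the callback-based `LoopDir`/`ORT_TRY` traversal in favor of an explicit work stack of `std::filesystem::path` objects walked with `std::filesystem::directory_iterator`. A stripped-down sketch of that traversal pattern is below; the function name is illustrative and not part of the patch.

```cpp
#include <filesystem>
#include <stack>
#include <string>
#include <vector>

namespace fs = std::filesystem;

// Iteratively walk a directory tree and collect *.onnx model files, mirroring
// the stack-based traversal now used in GetParameterStrings().
std::vector<fs::path> CollectOnnxModels(const fs::path& root) {
  std::vector<fs::path> models;
  std::stack<fs::path> pending;
  pending.push(root);

  while (!pending.empty()) {
    const fs::path dir = pending.top();
    pending.pop();
    if (!fs::exists(dir) || !fs::is_directory(dir)) continue;  // skip missing dirs

    for (const auto& entry : fs::directory_iterator(dir)) {
      if (entry.is_directory()) {
        pending.push(entry.path());  // descend into subdirectories later
        continue;
      }
      const fs::path& p = entry.path();
      const std::string name = p.filename().string();
      if (name.empty() || name.front() == '.') continue;  // skip hidden files
      if (p.extension() != ".onnx") continue;              // keep only ONNX models
      models.push_back(p);
    }
  }
  return models;
}
```
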
diff --git a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc
index ee18cf2cea6cb..d91a1de3faa6e 100644
--- a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc
@@ -75,6 +75,43 @@ TEST(BatchNormTest, PositiveTestCase) {
   input_data_map.insert({"mean", mean});
   input_data_map.insert({"var", var});
 
+  InputShapesMap input_shapes_map;
+  vector<int64_t> input_shape{1, 1, 7, 7};
+  input_shapes_map.insert({"X", input_shape});
+  input_shapes_map.insert({"scale", {1}});
+  input_shapes_map.insert({"B", {1}});
+  input_shapes_map.insert({"mean", {1}});
+  input_shapes_map.insert({"var", {1}});
+
+  auto expected_output = {1.01359f, 0.703983f, 0.641631f, 1.08571f, 0.939167f, 0.762469f, 0.682729f, 0.762401f, 0.787021f,
+                          1.06744f, 0.604378f, 0.957476f, 0.667302f, 0.901764f, 1.07566f, 1.01117f, 0.928324f, 0.897667f,
+                          0.705842f, 0.660885f, 0.977291f, 0.878918f, 0.818345f, 1.06608f, 0.839057f, 1.04796f, 0.621471f,
+                          0.781831f, 0.760527f, 0.835665f, 1.05825f, 0.611442f, 0.781873f, 1.08437f, 0.907454f, 0.926173f,
+                          1.03375f, 0.707961f, 0.968646f, 0.621757f, 0.973095f, 0.700301f, 0.916723f, 0.807602f, 0.692598f,
+                          0.621972f, 0.707334f, 0.63723f, 0.63062f};
+  float epsilon = 1e-05f;
+  TestBatchNorm(input_data_map, input_shapes_map, epsilon, expected_output, input_shape);
+}
+
+TEST(BatchNormTest, PositiveTestCase_5D) {
+  // This input was taken from the SpatialBN_1.pb, SpatialBN_1_input.pb and SpatialBN_1_output.pb files.
+  vector<float> X{0.329876f, -0.287158f, -0.411425f, 0.473621f, 0.18156f, -0.170596f, -0.329516f, -0.170733f, -0.121664f, 0.4372f,
+                  -0.485668f, 0.218049f, -0.360263f, 0.107016f, 0.45358f, 0.325056f, 0.15995f, 0.098852f, -0.283453f, -0.373051f,
+                  0.257542f, 0.0614853f, -0.0592363f, 0.434488f, -0.0179583f, 0.398374f, -0.451602f, -0.132009f, -0.174468f,
+                  -0.0247169f, 0.418897f, -0.47159f, -0.131925f, 0.470943f, 0.118357f, 0.155664f, 0.370062f, -0.279229f, 0.240311f,
+                  -0.451034f, 0.249178f, -0.294496f, 0.13683f, -0.0806475f, -0.309849f, -0.450604f, -0.28048f, -0.420197f, -0.433369f};
+  vector<float> scale{0.589433f};
+  vector<float> B{-0.384622f};
+  vector<float> mean{-2.45673f};
+  vector<float> var{1.37998f};
+
+  InputDataMap input_data_map;
+  input_data_map.insert({"X", X});
+  input_data_map.insert({"scale", scale});
+  input_data_map.insert({"B", B});
+  input_data_map.insert({"mean", mean});
+  input_data_map.insert({"var", var});
+
   InputShapesMap input_shapes_map;
   vector<int64_t> input_shape{1, 1, 7, 7, 1};
   input_shapes_map.insert({"X", input_shape});
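
For reference, the expected values in these positive BatchNorm tests follow the inference-mode formula y = scale * (x - mean) / sqrt(var + epsilon) + B. Plugging the first input element and the scalar scale/B/mean/var shown above into that formula reproduces the first expected value of the new 4D case; a minimal check (standalone snippet, not part of the test file):

```cpp
#include <cmath>
#include <cstdio>

// BatchNormalization inference formula: y = scale * (x - mean) / sqrt(var + eps) + B
int main() {
  const float x = 0.329876f, scale = 0.589433f, B = -0.384622f;
  const float mean = -2.45673f, var = 1.37998f, epsilon = 1e-05f;
  const float y = scale * (x - mean) / std::sqrt(var + epsilon) + B;
  std::printf("%f\n", y);  // ~1.01359, matching expected_output[0] above
  return 0;
}
```
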
@@ -868,18 +905,21 @@ TEST(BatchNormTest, ForwardTrainingTestWithSavedOutputsOpset9) {
   test.AddInput<float>("var", channel_dims, {1.0f, 2.0f});
 
   test.AddOutput<float>("Y", input_output_dims, {0.0131f, 0.5210f, 1.7244f, 0.1387f, -0.2708f, -0.1191f, 1.2089f, -0.0922f, -0.9548f, -1.5203f, 0.9077f, -0.8298f, 0.5796f, -0.4501f, -2.0921f, 1.2358f});
-
   test.AddOutput<float>("running_mean", channel_dims, {-0.1754f, 0.303106f});
   test.AddOutput<float>("running_var", channel_dims, {0.696052f, 1.41316f});
+
   // mean and variance of X across channel dimension
   // With Opset9 we output saved_inv_std instead of saved_var to match CUDA EP
   test.AddOutput<float>("saved_mean", channel_dims, {-0.306f, 0.114562f});
   test.AddOutput<float>("saved_inv_std", channel_dims, {1.2288f, 0.861317f});
 
+  test.SetOutputTolerance(0.0001f);
+
   // exclude CUDA Execution Provider due to flakiness
   // exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm()
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-           {kCudaExecutionProvider, kRocmExecutionProvider,
+           // TODO(mtavenrath) The flakiness of running_mean for CUDA has been fixed, but the delta of running_var is still ~0.1.
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider,
             kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
 }
 
@@ -900,14 +940,15 @@ TEST(BatchNormTest, ForwardTrainingTestOpset14) {
   test.AddInput<float>("var", channel_dims, {1.0f, 2.0f});
 
   test.AddOutput<float>("Y", input_output_dims, {0.0131f, 0.5210f, 1.7244f, 0.1387f, -0.2708f, -0.1191f, 1.2089f, -0.0922f, -0.9548f, -1.5203f, 0.9077f, -0.8298f, 0.5796f, -0.4501f, -2.0921f, 1.2358f});
-
   test.AddOutput<float>("running_mean", channel_dims, {-0.1754f, 0.303106f});
   test.AddOutput<float>("running_var", channel_dims, {0.696052f, 1.41316f});
 
+  test.SetOutputTolerance(0.0001f);
+
   // exclude CUDA Execution Provider due to flakiness
   // exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm()
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-           {kCudaExecutionProvider, kRocmExecutionProvider,
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider,
             kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
 }
 
@@ -932,9 +973,11 @@ TEST(BatchNormTest, ForwardTrainingTestOpset15) {
   test.AddOutput<float>("running_mean", channel_dims, {-0.1754f, 0.303106f});
   test.AddOutput<float>("running_var", channel_dims, {0.696052f, 1.41316f});
 
+  test.SetOutputTolerance(0.0001f);
+
   // Same exclusions as the opset 14 test
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-           {kCudaExecutionProvider, kRocmExecutionProvider,
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider,
             kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
 }
 #endif  // BATCHNORM_INCLUDE_TRAINING_SUPPORT
diff --git a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc
index dede278b7274f..0efa78af2795c 100644
--- a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc
@@ -59,6 +59,8 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes,
   std::unordered_set<std::string> excluded_providers(attributes.excluded_providers);
   // Disable TensorRT because weight as input is not supported
   excluded_providers.insert(kTensorrtExecutionProvider);
+  // Disable CUDA NHWC execution provider as it is currently flaky
+  excluded_providers.insert(kCudaNHWCExecutionProvider);
 
   // QNN SDK 2.10.0 has a bug that breaks support for dynamic bias inputs.
   excluded_providers.insert(kQnnExecutionProvider);
diff --git a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc
index 472f841aa8565..ec93dc249eeb2 100644
--- a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc
@@ -75,7 +75,8 @@ void TestConvTransposeOp(const ConvTransposeOpAttributes& attributes,
                          const vector<int64_t>& expected_output_shape,
                          OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess,
                          const std::string& err_str = "",
-                         const std::unordered_set<std::string>& excluded_provider_types = {kTensorrtExecutionProvider, kQnnExecutionProvider}) {
+                         const std::unordered_set<std::string>& excluded_provider_types =
+                             {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kQnnExecutionProvider}) {
   std::unordered_set<std::string> extra_exclude_openvino_for_initializer_filter = excluded_provider_types;
   extra_exclude_openvino_for_initializer_filter.insert(kOpenVINOExecutionProvider);
   TestConvTransposeOpInitializer(attributes, inputs, input_shapes, expected_output, expected_output_shape,
@@ -409,7 +410,8 @@ TEST(ConvTransposeTest, ConvTranspose_2D_OutputShape_2) {
   vector<int64_t> Y_shape = {1, 1, 1, 14};
   auto expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f};
   TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape,
-                      OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kQnnExecutionProvider});
+                      OpTester::ExpectResult::kExpectSuccess, "",
+                      {kOpenVINOExecutionProvider, kCudaNHWCExecutionProvider, kQnnExecutionProvider});
 }
 
 TEST(ConvTransposeTest, ConvTranspose_2D_OutputShapeWithBatchSize) {
@@ -434,7 +436,8 @@ TEST(ConvTransposeTest, ConvTranspose_2D_OutputShapeWithBatchSize) {
   auto expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f,
                         11.0f, 32.0f, 65.0f, 91.0f, 109.0f, 118.0f, 127.0f, 136.0f, 145.0f, 154.0f, 143.0f, 111.0f, 57.0f, 20.0f};
   TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape,
-                      OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kQnnExecutionProvider});
+                      OpTester::ExpectResult::kExpectSuccess, "",
+                      {kOpenVINOExecutionProvider, kCudaNHWCExecutionProvider, kQnnExecutionProvider});
 }
 
 TEST(ConvTransposeTest, ConvTranspose_InvalidKernelShape) {
@@ -871,7 +874,8 @@ TEST(ConvTransposeTest, DimWithZero) {
 
   TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape,
                       OpTester::ExpectResult::kExpectSuccess, "",
-                      {kTensorrtExecutionProvider, kAclExecutionProvider, kQnnExecutionProvider});
+                      {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider,
+                       kAclExecutionProvider, kQnnExecutionProvider});
 }
 
 TEST(ConvTransposeTest, ConvTranspose_3D) {
@@ -1005,7 +1009,8 @@ TEST(ConvTransposeTest, ConvTranspose_3D) {
 
   TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape,
                       OpTester::ExpectResult::kExpectSuccess, "",
-                      {kTensorrtExecutionProvider, kCudaExecutionProvider, kQnnExecutionProvider});
+                      {kTensorrtExecutionProvider, kCudaExecutionProvider,
+                       kCudaNHWCExecutionProvider, kQnnExecutionProvider});
 }
 
 TEST(ConvTransposeTest, ConvTranspose_1D_AsymmetricPads) {
diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
index 4b194ec18b31b..c8cf183291518 100644
--- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
@@ -57,7 +57,8 @@ TEST(PoolTest, MaxPool) {
 
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  // TensorRT: result differs
+  // TensorRT: result differs
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
 }
 
 // Only CUDA kernel has float 16 support
@@ -115,7 +116,8 @@ TEST(PoolTest, MaxPool_F16) {
 
   test.AddInput<MLFloat16>("X", x_dims, f_X);
   test.AddOutput<MLFloat16>("Y", expected_dims, f_Y);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  // TensorRT: Assertion `!attrs.count("pads")' failed
+  // TensorRT: Assertion `!attrs.count("pads")' failed
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
 }
 #endif
 
@@ -167,7 +169,9 @@ static void MaxPool_8_WithIndexTest(bool has_index, int64_t storage_order = 0) {
     storage_order == 0 ? test.AddOutput<int64_t>("Indices", expected_dims, expected_indices_row)
                        : test.AddOutput<int64_t>("Indices", expected_dims, expected_indices_col);
   }
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDnnlExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider, kArmNNExecutionProvider, kOpenVINOExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kDnnlExecutionProvider, kTensorrtExecutionProvider,
+            kAclExecutionProvider, kArmNNExecutionProvider, kOpenVINOExecutionProvider});
 }
 
 TEST(PoolTest, MaxPool_8_With_Index) {
@@ -181,7 +185,7 @@ TEST(PoolTest, MaxPool_8_With_Index) {
   MaxPool_8_WithIndexTest(true, 1 /*storage_order*/);  // col major
 }
 
-TEST(PoolTest, MaxPool1D) {
+TEST(PoolTest, MaxPool1D_case1) {
   OpTester test("MaxPool");
 
   test.AddAttribute("auto_pad", "");
@@ -199,6 +203,44 @@ TEST(PoolTest, MaxPool1D) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
 }
 
+TEST(PoolTest, MaxPool1D_case2) {
+  OpTester test("MaxPool");
+  // no padding
+  test.AddAttribute("auto_pad", "VALID");
+  test.AddAttribute("strides", std::vector<int64_t>{1});
+  test.AddAttribute("pads", vector<int64_t>{0, 0});
+  test.AddAttribute("kernel_shape", vector<int64_t>{2});
+
+  std::vector<float> x_vals = {1, 2, 3, 4, 5};
+  std::vector<int64_t> x_dims = {1, 1, 5};
+  // The last dim is (5-2+1)/1 = 4
+  std::vector<int64_t> expected_dims = {1, 1, 4};
+  std::vector<float> expected_vals = {2, 3, 4, 5};
+
+  test.AddInput<float>("X", x_dims, x_vals);
+  test.AddOutput<float>("Y", expected_dims, expected_vals);
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
+}
+
+TEST(PoolTest, MaxPool1D_case3) {
+  OpTester test("MaxPool");
+  test.AddAttribute("auto_pad", "");
+  test.AddAttribute("strides", std::vector<int64_t>{1});
+  // Pad one element
+  test.AddAttribute("pads", vector<int64_t>{0, 1});
+  test.AddAttribute("kernel_shape", vector<int64_t>{2});
+
+  std::vector<float> x_vals = {1, 2, 3, 4, 5};
+  std::vector<int64_t> x_dims = {1, 1, 5};
+  // Since the input is padded by one element, the last dim is (5+1-2)/1+1 = 5, larger than in the case above
+  std::vector<int64_t> expected_dims = {1, 1, 5};
+  std::vector<float> expected_vals = {2, 3, 4, 5, 5};
+
+  test.AddInput<float>("X", x_dims, x_vals);
+  test.AddOutput<float>("Y", expected_dims, expected_vals);
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
+}
+
 static void MaxPool1D_8_WithIndexTest(int64_t storage_order) {
   OpTester test("MaxPool", 8);
 
@@ -217,7 +259,8 @@ static void MaxPool1D_8_WithIndexTest(int64_t storage_order) {
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
   test.AddOutput<int64_t>("Indices", expected_dims, expected_indices);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kTensorrtExecutionProvider, kAclExecutionProvider});
 }
 
 TEST(PoolTest, MaxPool1D_8_With_Index) {
@@ -243,7 +286,8 @@ static void MaxPool1D_12_WithIndexTest_int8(int64_t storage_order) {
   test.AddInput<int8_t>("X", x_dims, x_vals);
   test.AddOutput<int8_t>("Y", expected_dims, expected_vals);
   test.AddOutput<int64_t>("Indices", expected_dims, expected_indices);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kTensorrtExecutionProvider, kAclExecutionProvider});
 }
 
 static void MaxPool1D_12_WithIndexTest_uint8(int64_t storage_order) {
@@ -264,7 +308,8 @@ static void MaxPool1D_12_WithIndexTest_uint8(int64_t storage_order) {
   test.AddInput<uint8_t>("X", x_dims, x_vals);
   test.AddOutput<uint8_t>("Y", expected_dims, expected_vals);
   test.AddOutput<int64_t>("Indices", expected_dims, expected_indices);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kTensorrtExecutionProvider, kAclExecutionProvider});
 }
 
 TEST(PoolTest, MaxPool1D_12_With_Index_8bits) {
@@ -304,7 +349,7 @@ TEST(PoolTest, MaxPool2D_uint8) {
 #if defined(OPENVINO_CONFIG_GPU_FP32) || defined(OPENVINO_CONFIG_GPU_FP16)
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider});
 #else
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {});
 #endif
 }
 
@@ -416,7 +461,7 @@ TEST(PoolTest, MaxPool_10_DilationPadding_1d) {
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-           {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider});
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider});
 }
 
 TEST(PoolTest, MaxPool_10_Dilation_2d) {
@@ -500,7 +545,7 @@ TEST(PoolTest, MaxPool_10_DilationPadding_2d) {
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-           {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider});
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider});
 }
 
 TEST(PoolTest, MaxPool_10_Dilation_Ceil0_2d) {
@@ -528,7 +573,8 @@ TEST(PoolTest, MaxPool_10_Dilation_Ceil0_2d) {
 
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kTensorrtExecutionProvider, kAclExecutionProvider});
 }
 
 TEST(PoolTest, MaxPool_12_Dilation_Ceil0_2d_int8) {
@@ -556,7 +602,8 @@ TEST(PoolTest, MaxPool_12_Dilation_Ceil0_2d_int8) {
 
   test.AddInput<int8_t>("X", x_dims, x_vals);
   test.AddOutput<int8_t>("Y", expected_dims, expected_vals);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kTensorrtExecutionProvider, kAclExecutionProvider});
 }
 
 TEST(PoolTest, MaxPool_10_Dilation_Ceil1_2d) {
@@ -585,7 +632,8 @@ TEST(PoolTest, MaxPool_10_Dilation_Ceil1_2d) {
 
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kTensorrtExecutionProvider, kAclExecutionProvider});
 }
 
 TEST(PoolTest, MaxPool_10_DilationPadding_3d) {
@@ -621,7 +669,7 @@ TEST(PoolTest, MaxPool_10_DilationPadding_3d) {
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-           {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider});
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider});
 }
 
 TEST(PoolTest, GlobalMaxPool) {
@@ -697,7 +745,7 @@ TEST(PoolTest, GlobalMaxPool) {
 
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {});
 }
 
 TEST(PoolTest, GlobalMaxPool3D) {
@@ -878,6 +926,7 @@ TEST(PoolTest, AveragePool_IncludePadPixel) {
 
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
+  test.SetOutputTolerance(0.0001f);
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
 }
 
@@ -920,7 +969,8 @@ TEST(PoolTest, AveragePool_10_ceil1_2d) {
 
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kTensorrtExecutionProvider, kAclExecutionProvider});
 }
 
 TEST(PoolTest, AveragePool_19_dilation_2d) {
@@ -944,7 +994,9 @@ TEST(PoolTest, AveragePool_19_dilation_2d) {
 
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider, kOpenVINOExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider,
+            kTensorrtExecutionProvider, kAclExecutionProvider, kOpenVINOExecutionProvider});
 }
 
 TEST(PoolTest, GlobalAveragePool) {
@@ -1020,7 +1072,7 @@ TEST(PoolTest, GlobalAveragePool) {
 
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {});
 }
 
 TEST(PoolTest, GlobalAveragePool_Large_128) {
@@ -1033,7 +1085,7 @@ TEST(PoolTest, GlobalAveragePool_Large_128) {
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals,
                         /*sort_output=*/false, /*rel_error=*/1e-3f, /*abs_error=*/1e-2f);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {});
 }
 
 TEST(PoolTest, GlobalAveragePool_Large_256) {
@@ -1046,7 +1098,7 @@ TEST(PoolTest, GlobalAveragePool_Large_256) {
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals,
                         /*sort_output=*/false, /*rel_error=*/1e-3f, /*abs_error=*/1e-2f);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {});
 }
 
 TEST(PoolTest, LpPool) {
@@ -1353,7 +1405,7 @@ TEST(PoolTest, LpPool) {
 
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kCudaNHWCExecutionProvider});
 }
 
 // test data generated with lp_pool_test_generator.py
@@ -1385,7 +1437,8 @@ TEST(PoolTest, LpPool1d) {
 
       // https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_network_definition.html#a94f434942252e6d98ac17705c06ce060
       // TensorRT does not support 1d pooling
-      test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
+      test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+               {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider});
       y_count++;
     }
 }
@@ -1417,7 +1470,7 @@ TEST(PoolTest, LpPool2d) {
       test.AddAttribute("kernel_shape", kernel_sizes[kernel_size_count]);
 
       test.AddOutput<float>("Y", y_sizes[y_count], ys[y_count]);
-      test.Run();
+      test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kCudaNHWCExecutionProvider});
       y_count++;
     }
 }
@@ -1435,7 +1488,8 @@ TEST(PoolTest, LpPoolCeilMode) {
 
   // https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_network_definition.html#a94f434942252e6d98ac17705c06ce060
   // TensorRT does not support 1d pooling
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider});
 }
 
 TEST(PoolTest, GlobalLpPool) {
@@ -1690,7 +1744,7 @@ TEST(PoolTest, GlobalLpPool) {
 
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kCudaNHWCExecutionProvider});
 }
 
 TEST(PoolTest, MaxPoolDimWithZeroForN) {
@@ -1707,7 +1761,8 @@ TEST(PoolTest, MaxPoolDimWithZeroForN) {
 
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kTensorrtExecutionProvider, kQnnExecutionProvider});
 }
 
 }  // namespace test
diff --git a/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc b/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc
index 2f97f6e71e92b..58a616717316e 100644
--- a/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc
+++ b/onnxruntime/test/providers/cpu/object_detection/roialign_test.cc
@@ -4,6 +4,7 @@
 #include "gtest/gtest.h"
 #include "test/providers/provider_test_utils.h"
 #include "test/util/include/default_providers.h"
+#include "test/common/trt_op_test_utils.h"
 
 namespace onnxruntime {
 namespace test {
@@ -463,6 +464,7 @@ static void BasicTest() {
                                            0.3661f,
                                            0.2349f,
                                        });
+  test.SetOutputTolerance(0.0001f);
   test.Run();
 }
 
@@ -689,6 +691,7 @@ TEST(RoiAlignTest, MaxModePositive) {
                                           });*/
   test.Run();
 }
+
 TEST(RoiAlignTest, AvgModeNegativeInvalidMode) {
   // TODO: Unskip when fixed #41968513
   if (DefaultDmlExecutionProvider().get() != nullptr) {
@@ -713,7 +716,8 @@ TEST(RoiAlignTest, AvgModeNegativeInvalidMode) {
   test.AddInput<int64_t>("batch_indices", {5}, {0, 0, 0, 0, 0});
   test.AddOutput<float>("Y", {5, 3, 3, 4}, {2.95833f, 3.20833f, 3.45833f, 3.70833f, 4.625f, 4.875f, 5.125f, 5.375f, 6.29167f, 6.54167f, 6.79167f, 7.04167f, 27.9583f, 28.2083f, 28.4583f, 28.7083f, 29.625f, 29.875f, 30.125f, 30.375f, 31.2917f, 31.5417f, 31.7917f, 32.0417f, 52.9583f, 53.2083f, 53.4583f, 53.7083f, 54.625f, 54.875f, 55.125f, 55.375f, 56.2917f, 56.5417f, 56.7917f, 57.0417f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 7.39583f, 7.39583f, 7.42708f, 7.64583f, 9.0625f, 9.0625f, 9.09375f, 9.3125f, 10.7292f, 10.7292f, 10.7604f, 10.9792f, 32.3958f, 32.3958f, 32.4271f, 32.6458f, 34.0625f, 34.0625f, 34.0938f, 34.3125f, 35.7292f, 35.7292f, 35.7604f, 35.9792f, 57.3958f, 57.3958f, 57.4271f, 57.6458f, 59.0625f, 59.0625f, 59.0938f, 59.3125f, 60.7292f, 60.7292f, 60.7604f, 60.9792f, 4.27083f, 4.52083f, 4.77083f, 5.02083f, 5.9375f, 6.1875f, 6.4375f, 6.6875f, 7.60417f, 7.85417f, 8.10417f, 8.35417f, 29.2708f, 29.5208f, 29.7708f, 30.0208f, 30.9375f, 31.1875f, 31.4375f, 31.6875f, 32.6042f, 32.8542f, 33.1042f, 33.3542f, 54.2708f, 54.5208f, 54.7708f, 55.0208f, 55.9375f, 56.1875f, 56.4375f, 56.6875f, 57.6042f, 57.8542f, 58.1042f, 58.3542f, 6.77083f, 6.77083f, 6.77083f, 6.80208f, 8.4375f, 8.4375f, 8.4375f, 8.46875f, 10.1042f, 10.1042f, 10.1042f, 10.1354f, 31.7708f, 31.7708f, 31.7708f, 31.8021f, 33.4375f, 33.4375f, 33.4375f, 33.4688f, 35.1042f, 35.1042f, 35.1042f, 35.1354f, 56.7708f, 56.7708f, 56.7708f, 56.8021f, 58.4375f, 58.4375f, 58.4375f, 58.4688f, 60.1042f, 60.1042f, 60.1042f, 60.1354f});
 
-  test.Run(OpTester::ExpectResult::kExpectFailure, "Invalid mode");
+  // TRT EP segmentation fault on A100
+  test.Run(OpTester::ExpectResult::kExpectFailure, "Invalid mode", ExcludeTrtOnA100());
 }
 
 TEST(RoiAlignTest, AvgModeNegativeSamplingRatio) {
@@ -738,7 +742,8 @@ TEST(RoiAlignTest, AvgModeNegativeSamplingRatio) {
   test.AddInput<int64_t>("batch_indices", {5}, {0, 0, 0, 0, 0});
   test.AddOutput<float>("Y", {5, 3, 3, 4}, {2.95833f, 3.20833f, 3.45833f, 3.70833f, 4.625f, 4.875f, 5.125f, 5.375f, 6.29167f, 6.54167f, 6.79167f, 7.04167f, 27.9583f, 28.2083f, 28.4583f, 28.7083f, 29.625f, 29.875f, 30.125f, 30.375f, 31.2917f, 31.5417f, 31.7917f, 32.0417f, 52.9583f, 53.2083f, 53.4583f, 53.7083f, 54.625f, 54.875f, 55.125f, 55.375f, 56.2917f, 56.5417f, 56.7917f, 57.0417f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 25.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 7.39583f, 7.39583f, 7.42708f, 7.64583f, 9.0625f, 9.0625f, 9.09375f, 9.3125f, 10.7292f, 10.7292f, 10.7604f, 10.9792f, 32.3958f, 32.3958f, 32.4271f, 32.6458f, 34.0625f, 34.0625f, 34.0938f, 34.3125f, 35.7292f, 35.7292f, 35.7604f, 35.9792f, 57.3958f, 57.3958f, 57.4271f, 57.6458f, 59.0625f, 59.0625f, 59.0938f, 59.3125f, 60.7292f, 60.7292f, 60.7604f, 60.9792f, 4.27083f, 4.52083f, 4.77083f, 5.02083f, 5.9375f, 6.1875f, 6.4375f, 6.6875f, 7.60417f, 7.85417f, 8.10417f, 8.35417f, 29.2708f, 29.5208f, 29.7708f, 30.0208f, 30.9375f, 31.1875f, 31.4375f, 31.6875f, 32.6042f, 32.8542f, 33.1042f, 33.3542f, 54.2708f, 54.5208f, 54.7708f, 55.0208f, 55.9375f, 56.1875f, 56.4375f, 56.6875f, 57.6042f, 57.8542f, 58.1042f, 58.3542f, 6.77083f, 6.77083f, 6.77083f, 6.80208f, 8.4375f, 8.4375f, 8.4375f, 8.46875f, 10.1042f, 10.1042f, 10.1042f, 10.1354f, 31.7708f, 31.7708f, 31.7708f, 31.8021f, 33.4375f, 33.4375f, 33.4375f, 33.4688f, 35.1042f, 35.1042f, 35.1042f, 35.1354f, 56.7708f, 56.7708f, 56.7708f, 56.8021f, 58.4375f, 58.4375f, 58.4375f, 58.4688f, 60.1042f, 60.1042f, 60.1042f, 60.1354f});
 
-  test.Run(OpTester::ExpectResult::kExpectFailure, "Sampling ratio should be >=0");
+  // TRT EP segmentation fault on A100
+  test.Run(OpTester::ExpectResult::kExpectFailure, "Sampling ratio should be >=0", ExcludeTrtOnA100());
 }
 
 TEST(RoiAlignTest, AvgModeNegativeInvalidNumRoiDims) {
diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc
index b0e0a0dd0d564..2902995df1e71 100644
--- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc
+++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc
@@ -3541,6 +3541,7 @@ TEST(ReductionOpTest, ReduceDimWithZero1) {
                {
                    kCoreMLExecutionProvider,
                    kCudaExecutionProvider,
+                   kCudaNHWCExecutionProvider,
                    kDnnlExecutionProvider,
                    kMIGraphXExecutionProvider,
                    kOpenVINOExecutionProvider,
@@ -3591,6 +3592,7 @@ TEST(ReductionOpTest, ReduceDimWithZero2) {
                {
                    kCoreMLExecutionProvider,
                    kCudaExecutionProvider,
+                   kCudaNHWCExecutionProvider,
                    kDnnlExecutionProvider,
                    kMIGraphXExecutionProvider,
                    kOpenVINOExecutionProvider,
@@ -5779,6 +5781,7 @@ void test_empty_set(const std::string& op, int opset, bool axes_as_input, float
       {
           kCoreMLExecutionProvider,
           kCudaExecutionProvider,
+          kCudaNHWCExecutionProvider,
           kDmlExecutionProvider,
           kDnnlExecutionProvider,
           kMIGraphXExecutionProvider,
diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_test_cases_generator.py b/onnxruntime/test/providers/cpu/reduction/reduction_test_cases_generator.py
index 727351cae84ac..568a4649f3977 100644
--- a/onnxruntime/test/providers/cpu/reduction/reduction_test_cases_generator.py
+++ b/onnxruntime/test/providers/cpu/reduction/reduction_test_cases_generator.py
@@ -59,7 +59,7 @@ def PrintResult(op, axes, keepdims, res):  # noqa: N802
 
     print(" // expected values")
     print("{", end="")
-    for i in range(0, res.size):
+    for i in range(res.size):
         print("%5.6ff," % res.item(i))
 
     print("})},")
@@ -128,7 +128,7 @@ def PrintReenableOptimizations():  # noqa: N802
     print("ReductionTestCases testcases = {")
     print("// input_data")
     print("{")
-    for i in range(0, input_data.size):
+    for i in range(input_data.size):
         print(
             "%5.6ff," % input_data.item(i),
         )
diff --git a/onnxruntime/test/providers/cpu/rnn/GRU.py b/onnxruntime/test/providers/cpu/rnn/GRU.py
index 144acaf14db61..f141710cf31ef 100644
--- a/onnxruntime/test/providers/cpu/rnn/GRU.py
+++ b/onnxruntime/test/providers/cpu/rnn/GRU.py
@@ -84,7 +84,7 @@ def run(self):
             hidden_size = f_output.shape[3]
 
             output = np.empty((0, 2, batch_size, hidden_size), np.float32)
-            for x in range(0, seq_length):
+            for x in range(seq_length):
                 output = np.append(output, f_output[x])
                 output = np.append(output, r_output_orig_input_order[x])
 
diff --git a/onnxruntime/test/providers/cpu/rnn/LSTM.py b/onnxruntime/test/providers/cpu/rnn/LSTM.py
index 116ec3671bf01..49e28a93385a4 100644
--- a/onnxruntime/test/providers/cpu/rnn/LSTM.py
+++ b/onnxruntime/test/providers/cpu/rnn/LSTM.py
@@ -124,7 +124,7 @@ def run(self):
             output = np.empty((0, 2, batch_size, hidden_size), np.float32)
             # Y_h = np.empty((0, 2, batch_size, hidden_size), np.float32)
             # Y_c = np.empty((0, 2, hidden_size, hidden_size), np.float32)
-            for x in range(0, seq_length):
+            for x in range(seq_length):
                 output = np.append(output, f_output[x])
                 output = np.append(output, r_output_orig_input_order[x])
 
diff --git a/onnxruntime/test/providers/cpu/rnn/deep_cpu_lstm_op_test.cc b/onnxruntime/test/providers/cpu/rnn/deep_cpu_lstm_op_test.cc
index 7e81fc80ddf85..e73a1b492cc05 100644
--- a/onnxruntime/test/providers/cpu/rnn/deep_cpu_lstm_op_test.cc
+++ b/onnxruntime/test/providers/cpu/rnn/deep_cpu_lstm_op_test.cc
@@ -143,6 +143,8 @@ static void RunLstmTest(const std::vector<float>& X_data,
     test.AddOptionalOutputEdge<float>();
   }
 
+  test.SetOutputTolerance(0.0001f);
+
   // TensorRT failed on LSTM tests
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
 }
diff --git a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc
index b9875b9553a55..38734ab9f668f 100644
--- a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc
+++ b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc
@@ -120,15 +120,11 @@ TEST(RNNTest, RNN_bidirectional_bias_initial_zigged_batch) {
   test.AddOutput<float>("Y_h", Y_h_dims, Y_h_data);
 
   // TensorRT failed on RNN tests
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
 }
 
 // Doesn't work with CUDA 11.4 on Windows. Need investigation.
-#if defined(USE_CUDA) && defined(_WIN32)
-TEST(RNNTest, DISABLED_RNN_bidirectional_zigged_batch) {
-#else
 TEST(RNNTest, RNN_bidirectional_zigged_batch) {
-#endif
   OpTester test("RNN");
   int64_t num_directions = 2, input_size = 2, hidden_size = 3, seq_length = 5;
 
@@ -275,15 +271,11 @@ TEST(RNNTest, RNN_reverse_direction_zigged_batch) {
   std::vector<float> Y_h_data({0.87014002F, 0.09402763F, -0.54269236F, 0.64809889F, -0.19472955F, -0.24271242F});
   test.AddOutput<float>("Y_h", Y_h_dims, Y_h_data);
 
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
 }
 
 // Doesn't work with CUDA 11.4 on Windows. Need investigation.
-#if defined(USE_CUDA) && defined(_WIN32)
-TEST(RNNTest, DISABLED_RNN_forward_direction_zigged_batch) {
-#else
 TEST(RNNTest, RNN_forward_direction_zigged_batch) {
-#endif
   OpTester test("RNN");
   int64_t num_directions = 1, input_size = 2, hidden_size = 3, seq_length = 5;
 
@@ -357,12 +349,7 @@ TEST(RNNTest, RNN_forward_direction_zigged_batch) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
 }
 
-// Doesn't work with CUDA 11.4 on Windows. Need investigation.
-#if defined(USE_CUDA) && defined(_WIN32)
-TEST(RNNTest, DISABLED_RNN_bidirectional_0) {
-#else
 TEST(RNNTest, RNN_bidirectional_0) {
-#endif
   OpTester test("RNN");
   int64_t num_directions = 2, input_size = 2, hidden_size = 3, batch_size = 1, seq_length = 5;
 
@@ -424,12 +411,7 @@ TEST(RNNTest, RNN_bidirectional_0) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
 }
 
-// Doesn't work with CUDA 11.4 on Windows. Need investigation.
-#if defined(USE_CUDA) && defined(_WIN32)
-TEST(RNNTest, DISABLED_RNN_bidirectional_1) {
-#else
 TEST(RNNTest, RNN_bidirectional_1) {
-#endif
   OpTester test("RNN");
   int64_t num_directions = 2, input_size = 2, hidden_size = 2, batch_size = 1, seq_length = 1;
 
@@ -597,7 +579,7 @@ TEST(RNNTest, DISABLED_RNN_default_attributes_and_forward_direction) {
   }
 }
 
-TEST(RNNTest, DISABLED_RNN_reverse_direction) {
+TEST(RNNTest, RNN_reverse_direction) {
   int64_t num_directions = 1, input_size = 2, hidden_size = 3, batch_size = 1, seq_length = 5;
 
   // In case of useDefault, attributes, inputs or outputs are not set.
@@ -762,7 +744,9 @@ TEST(RNNTest, RNN_invalid_sequence_lens) {
     test.AddOutput<float>("Y_h", Y_h_dims, Y_h_data);
 
     // the CUDA RNN version allows the invalid sequence lengths, so disable testing on CUDA and TensorRT
-    test.Run(OpTester::ExpectResult::kExpectFailure, error_msg, {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
+    test.Run(OpTester::ExpectResult::kExpectFailure, error_msg,
+             {kCudaExecutionProvider, kCudaNHWCExecutionProvider,
+              kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
   };
 
   // should batch batch_size to be valid
@@ -860,7 +844,8 @@ TEST(RNNTest, RNN_bidirectional_with_sequence_lens) {
 
   test.AddOutput<float>("Y_h", Y_h_dims, Y_h_data);
 
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
 }
 
 TEST(RNNTest, RNN_with_invalid_activation_load_failure) {
diff --git a/onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc b/onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc
index 60e75811e4333..c2d64b8e5ee4a 100644
--- a/onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc
+++ b/onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc
@@ -442,6 +442,19 @@ TEST(SequenceOpsTest, SplitToSequence_PositiveAxisScalarSplit) {
   test.Run();
 }
 
+TEST(SequenceOpsTest, SplitToSequence_StringSplit) {
+  OpTester test("SplitToSequence", 11);
+  test.AddInput<std::string>("input", {3}, std::vector<std::string>({"Test string", "Another string", "A third and much longer string"}));
+  int64_t axis = 0;
+  test.AddAttribute("axis", axis);
+  SeqTensors<std::string> output;
+  output.AddTensor({1}, {"Test string"});
+  output.AddTensor({1}, {"Another string"});
+  output.AddTensor({1}, {"A third and much longer string"});
+  test.AddSeqOutput("S2", output);
+  test.Run();
+}
+
 TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0UnevenSplitFloat) {
   OpTester test("SplitToSequence", 11);
   test.AddInput<float>("input", {5, 2}, GetConsecutiveVector<float>(1.f, 10));
diff --git a/onnxruntime/test/providers/cpu/tensor/affine_grid_test.cc b/onnxruntime/test/providers/cpu/tensor/affine_grid_test.cc
index e37e784f28930..1ffe6c73d4fa4 100644
--- a/onnxruntime/test/providers/cpu/tensor/affine_grid_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/affine_grid_test.cc
@@ -13,6 +13,7 @@ TEST(AffineGridTest, 2d) {
   test.AddInput<int64_t>("size", {4}, {1, 1, 2, 3});
   test.AddOutput<float>("grid", {1, 2, 3, 2},
                         {-0.6667f, -0.5000f, 0.0000f, -0.5000f, 0.6667f, -0.5000f, -0.6667f, 0.5000f, 0.0000f, 0.5000f, 0.6667f, 0.5000f});
+  test.SetOutputTolerance(0.0001f);
   test.Run();
 }
 
@@ -24,6 +25,7 @@ TEST(AffineGridTest, test_2d_0) {
   test.AddInput<float>("theta", {1, 2, 3}, {1.477212f, -0.173648f, 0.300000f, 0.173648f, 0.492404f, -0.500000f});
   test.AddInput<int64_t>("size", {4}, {1, 1, 3, 2});
   test.AddOutput<float>("grid", {1, 3, 2, 2}, {-0.3228f, -0.9151f, 1.1544f, -0.7414f, -0.4386f, -0.5868f, 1.0386f, -0.4132f, -0.5544f, -0.2586f, 0.9228f, -0.0849f});
+  test.SetOutputTolerance(0.0001f);
   test.Run();
 }
 
@@ -33,6 +35,7 @@ TEST(AffineGridTest, test_2d_1) {
   test.AddInput<float>("theta", {2, 2, 3}, {1.477212f, -0.173648f, 0.300000f, 0.173648f, 0.492404f, -0.500000f, 1.477212f, -0.173648f, 0.300000f, 0.173648f, 0.492404f, -0.500000f});
   test.AddInput<int64_t>("size", {4}, {2, 10, 2, 3});
   test.AddOutput<float>("grid", {2, 2, 3, 2}, {-0.5980f, -0.8620f, 0.3868f, -0.7462f, 1.3716f, -0.6304f, -0.7716f, -0.3696f, 0.2132f, -0.2538f, 1.1980f, -0.1380f, -0.5980f, -0.8620f, 0.3868f, -0.7462f, 1.3716f, -0.6304f, -0.7716f, -0.3696f, 0.2132f, -0.2538f, 1.1980f, -0.1380f});
+  test.SetOutputTolerance(0.0001f);
   test.Run();
 }
 
@@ -42,6 +45,7 @@ TEST(AffineGridTest, test_2d_2) {
   test.AddInput<float>("theta", {1, 2, 3}, {1.500000f, -0.866025f, -0.500000f, 0.866025f, 2.750000f, -0.500000f});
   test.AddInput<int64_t>("size", {4}, {1, 1, 3, 2});
   test.AddOutput<float>("grid", {1, 3, 2, 2}, {-0.6726f, -2.7663f, 0.8274f, -1.9003f, -1.2500f, -0.9330f, 0.2500f, -0.0670f, -1.8274f, 0.9003f, -0.3274f, 1.7663f});
+  test.SetOutputTolerance(0.0001f);
   test.Run();
 }
 
@@ -51,6 +55,7 @@ TEST(AffineGridTest, test_2d_3) {
   test.AddInput<float>("theta", {2, 2, 3}, {1.500000f, -0.866025f, -0.500000f, 0.866025f, 2.750000f, -0.500000f, 1.500000f, -0.866025f, -0.500000f, 0.866025f, 2.750000f, -0.500000f});
   test.AddInput<int64_t>("size", {4}, {2, 10, 2, 3});
   test.AddOutput<float>("grid", {2, 2, 3, 2}, {-1.0670f, -2.4524f, -0.0670f, -1.8750f, 0.9330f, -1.2976f, -1.9330f, 0.2976f, -0.9330f, 0.8750f, 0.0670f, 1.4524f, -1.0670f, -2.4524f, -0.0670f, -1.8750f, 0.9330f, -1.2976f, -1.9330f, 0.2976f, -0.9330f, 0.8750f, 0.0670f, 1.4524f});
+  test.SetOutputTolerance(0.0001f);
   test.Run();
 }
 
@@ -60,6 +65,7 @@ TEST(AffineGridTest, test_2d_4) {
   test.AddInput<float>("theta", {1, 2, 3}, {1.477212f, -0.173648f, 0.300000f, 0.173648f, 0.492404f, -0.500000f});
   test.AddInput<int64_t>("size", {4}, {1, 1, 3, 2});
   test.AddOutput<float>("grid", {1, 3, 2, 2}, {-1.0036f, -1.1661f, 1.9509f, -0.8188f, -1.1772f, -0.6736f, 1.7772f, -0.3264f, -1.3509f, -0.1812f, 1.6036f, 0.1661f});
+  test.SetOutputTolerance(0.0001f);
   test.Run();
 }
 
@@ -69,6 +75,7 @@ TEST(AffineGridTest, test_2d_5) {
   test.AddInput<float>("theta", {2, 2, 3}, {1.477212f, -0.173648f, 0.300000f, 0.173648f, 0.492404f, -0.500000f, 1.477212f, -0.173648f, 0.300000f, 0.173648f, 0.492404f, -0.500000f});
   test.AddInput<int64_t>("size", {4}, {2, 10, 2, 3});
   test.AddOutput<float>("grid", {2, 2, 3, 2}, {-1.0036f, -1.1661f, 0.4736f, -0.9924f, 1.9509f, -0.8188f, -1.3509f, -0.1812f, 0.1264f, -0.0076f, 1.6036f, 0.1661f, -1.0036f, -1.1661f, 0.4736f, -0.9924f, 1.9509f, -0.8188f, -1.3509f, -0.1812f, 0.1264f, -0.0076f, 1.6036f, 0.1661f});
+  test.SetOutputTolerance(0.0001f);
   test.Run();
 }
 
@@ -78,6 +85,7 @@ TEST(AffineGridTest, test_2d_6) {
   test.AddInput<float>("theta", {1, 2, 3}, {1.500000f, -0.866025f, -0.500000f, 0.866025f, 2.750000f, -0.500000f});
   test.AddInput<int64_t>("size", {4}, {1, 1, 3, 2});
   test.AddOutput<float>("grid", {1, 3, 2, 2}, {-1.1340f, -4.1160f, 1.8660f, -2.3840f, -2.0000f, -1.3660f, 1.0000f, 0.3660f, -2.8660f, 1.3840f, 0.1340f, 3.1160f});
+  test.SetOutputTolerance(0.0001f);
   test.Run();
 }
 
@@ -87,6 +95,7 @@ TEST(AffineGridTest, test_2d_7) {
   test.AddInput<float>("theta", {2, 2, 3}, {1.500000f, -0.866025f, -0.500000f, 0.866025f, 2.750000f, -0.500000f, 1.500000f, -0.866025f, -0.500000f, 0.866025f, 2.750000f, -0.500000f});
   test.AddInput<int64_t>("size", {4}, {2, 10, 2, 3});
   test.AddOutput<float>("grid", {2, 2, 3, 2}, {-1.1340f, -4.1160f, 0.3660f, -3.2500f, 1.8660f, -2.3840f, -2.8660f, 1.3840f, -1.3660f, 2.2500f, 0.1340f, 3.1160f, -1.1340f, -4.1160f, 0.3660f, -3.2500f, 1.8660f, -2.3840f, -2.8660f, 1.3840f, -1.3660f, 2.2500f, 0.1340f, 3.1160f});
+  test.SetOutputTolerance(0.0001f);
   test.Run();
 }
 
@@ -96,6 +105,7 @@ TEST(AffineGridTest, test_3d_0) {
   test.AddInput<float>("theta", {1, 3, 4}, {1.409539f, 0.000000f, 0.513030f, 0.300000f, 0.118782f, 1.969615f, -0.326352f, -0.500000f, -0.168412f, 0.086824f, 0.462708f, 1.800000f});
   test.AddInput<int64_t>("size", {5}, {1, 1, 3, 2, 2});
   test.AddOutput<float>("grid", {1, 3, 2, 2, 3}, {-0.7468f, -1.3266f, 1.5323f, 0.6627f, -1.2078f, 1.3639f, -0.7468f, 0.6430f, 1.6191f, 0.6627f, 0.7618f, 1.4507f, -0.4048f, -1.5442f, 1.8408f, 1.0048f, -1.4254f, 1.6724f, -0.4048f, 0.4254f, 1.9276f, 1.0048f, 0.5442f, 1.7592f, -0.0627f, -1.7618f, 2.1493f, 1.3468f, -1.6430f, 1.9809f, -0.0627f, 0.2078f, 2.2361f, 1.3468f, 0.3266f, 2.0677f});
+  test.SetOutputTolerance(0.0001f);
   test.Run();
 }
 
@@ -105,6 +115,7 @@ TEST(AffineGridTest, test_3d_1) {
   test.AddInput<float>("theta", {2, 3, 4}, {1.409539f, 0.000000f, 0.513030f, 0.300000f, 0.118782f, 1.969615f, -0.326352f, -0.500000f, -0.168412f, 0.086824f, 0.462708f, 1.800000f, 1.409539f, 0.000000f, 0.513030f, 0.300000f, 0.118782f, 1.969615f, -0.326352f, -0.500000f, -0.168412f, 0.086824f, 0.462708f, 1.800000f});
   test.AddInput<int64_t>("size", {5}, {2, 10, 2, 2, 3});
   test.AddOutput<float>("grid", {2, 2, 2, 3, 3}, {-0.8962f, -1.4008f, 1.6375f, 0.0435f, -1.3216f, 1.5252f, 0.9832f, -1.2424f, 1.4130f, -0.8962f, 0.5688f, 1.7243f, 0.0435f, 0.6480f, 1.6121f, 0.9832f, 0.7272f, 1.4998f, -0.3832f, -1.7272f, 2.1002f, 0.5565f, -1.6480f, 1.9879f, 1.4962f, -1.5688f, 1.8757f, -0.3832f, 0.2424f, 2.1870f, 0.5565f, 0.3216f, 2.0748f, 1.4962f, 0.4008f, 1.9625f, -0.8962f, -1.4008f, 1.6375f, 0.0435f, -1.3216f, 1.5252f, 0.9832f, -1.2424f, 1.4130f, -0.8962f, 0.5688f, 1.7243f, 0.0435f, 0.6480f, 1.6121f, 0.9832f, 0.7272f, 1.4998f, -0.3832f, -1.7272f, 2.1002f, 0.5565f, -1.6480f, 1.9879f, 1.4962f, -1.5688f, 1.8757f, -0.3832f, 0.2424f, 2.1870f, 0.5565f, 0.3216f, 2.0748f, 1.4962f, 0.4008f, 1.9625f});
+  test.SetOutputTolerance(0.0001f);
   test.Run();
 }
 
@@ -114,6 +125,7 @@ TEST(AffineGridTest, test_3d_2) {
   test.AddInput<float>("theta", {1, 3, 4}, {0.259808f, 0.000000f, -0.150000f, -0.500000f, -1.299038f, 1.500000f, -2.250000f, -0.500000f, 1.375000f, 4.763140f, 2.381570f, 0.300000f});
   test.AddInput<int64_t>("size", {5}, {1, 1, 3, 2, 2});
   test.AddOutput<float>("grid", {1, 3, 2, 2, 3}, {-0.5299f, 0.8995f, -4.3568f, -0.2701f, -0.3995f, -2.9818f, -0.5299f, 2.3995f, 0.4064f, -0.2701f, 1.1005f, 1.7814f, -0.6299f, -0.6005f, -2.7691f, -0.3701f, -1.8995f, -1.3941f, -0.6299f, 0.8995f, 1.9941f, -0.3701f, -0.3995f, 3.3691f, -0.7299f, -2.1005f, -1.1814f, -0.4701f, -3.3995f, 0.1936f, -0.7299f, -0.6005f, 3.5818f, -0.4701f, -1.8995f, 4.9568f});
+  test.SetOutputTolerance(0.0001f);
   test.Run();
 }
 
@@ -123,6 +135,7 @@ TEST(AffineGridTest, test_3d_3) {
   test.AddInput<float>("theta", {2, 3, 4}, {0.259808f, 0.000000f, -0.150000f, -0.500000f, -1.299038f, 1.500000f, -2.250000f, -0.500000f, 1.375000f, 4.763140f, 2.381570f, 0.300000f, 0.259808f, 0.000000f, -0.150000f, -0.500000f, -1.299038f, 1.500000f, -2.250000f, -0.500000f, 1.375000f, 4.763140f, 2.381570f, 0.300000f});
   test.AddInput<int64_t>("size", {5}, {2, 10, 2, 2, 3});
   test.AddOutput<float>("grid", {2, 2, 2, 3, 3}, {-0.5982f, 0.7410f, -4.1890f, -0.4250f, -0.1250f, -3.2724f, -0.2518f, -0.9910f, -2.3557f, -0.5982f, 2.2410f, 0.5741f, -0.4250f, 1.3750f, 1.4908f, -0.2518f, 0.5090f, 2.4075f, -0.7482f, -1.5090f, -1.8075f, -0.5750f, -2.3750f, -0.8908f, -0.4018f, -3.2410f, 0.0259f, -0.7482f, -0.0090f, 2.9557f, -0.5750f, -0.8750f, 3.8724f, -0.4018f, -1.7410f, 4.7890f, -0.5982f, 0.7410f, -4.1890f, -0.4250f, -0.1250f, -3.2724f, -0.2518f, -0.9910f, -2.3557f, -0.5982f, 2.2410f, 0.5741f, -0.4250f, 1.3750f, 1.4908f, -0.2518f, 0.5090f, 2.4075f, -0.7482f, -1.5090f, -1.8075f, -0.5750f, -2.3750f, -0.8908f, -0.4018f, -3.2410f, 0.0259f, -0.7482f, -0.0090f, 2.9557f, -0.5750f, -0.8750f, 3.8724f, -0.4018f, -1.7410f, 4.7890f});
+  test.SetOutputTolerance(0.0001f);
   test.Run();
 }
 
@@ -132,6 +145,7 @@ TEST(AffineGridTest, test_3d_4) {
   test.AddInput<float>("theta", {1, 3, 4}, {1.409539f, 0.000000f, 0.513030f, 0.300000f, 0.118782f, 1.969615f, -0.326352f, -0.500000f, -0.168412f, 0.086824f, 0.462708f, 1.800000f});
   test.AddInput<int64_t>("size", {5}, {1, 1, 3, 2, 2});
   test.AddOutput<float>("grid", {1, 3, 2, 2, 3}, {-1.6226f, -2.2620f, 1.4189f, 1.1965f, -2.0245f, 1.0821f, -1.6226f, 1.6772f, 1.5925f, 1.1965f, 1.9147f, 1.2557f, -1.1095f, -2.5884f, 1.8816f, 1.7095f, -2.3508f, 1.5448f, -1.1095f, 1.3508f, 2.0552f, 1.7095f, 1.5884f, 1.7184f, -0.5965f, -2.9147f, 2.3443f, 2.2226f, -2.6772f, 2.0075f, -0.5965f, 1.0245f, 2.5179f, 2.2226f, 1.2620f, 2.1811f});
+  test.SetOutputTolerance(0.0001f);
   test.Run();
 }
 
@@ -141,6 +155,7 @@ TEST(AffineGridTest, test_3d_5) {
   test.AddInput<float>("theta", {2, 3, 4}, {1.409539f, 0.000000f, 0.513030f, 0.300000f, 0.118782f, 1.969615f, -0.326352f, -0.500000f, -0.168412f, 0.086824f, 0.462708f, 1.800000f, 1.409539f, 0.000000f, 0.513030f, 0.300000f, 0.118782f, 1.969615f, -0.326352f, -0.500000f, -0.168412f, 0.086824f, 0.462708f, 1.800000f});
   test.AddInput<int64_t>("size", {5}, {2, 10, 2, 2, 3});
   test.AddOutput<float>("grid", {2, 2, 2, 3, 3}, {-1.6226f, -2.2620f, 1.4189f, -0.2130f, -2.1433f, 1.2505f, 1.1965f, -2.0245f, 1.0821f, -1.6226f, 1.6772f, 1.5925f, -0.2130f, 1.7960f, 1.4241f, 1.1965f, 1.9147f, 1.2557f, -0.5965f, -2.9147f, 2.3443f, 0.8130f, -2.7960f, 2.1759f, 2.2226f, -2.6772f, 2.0075f, -0.5965f, 1.0245f, 2.5179f, 0.8130f, 1.1433f, 2.3495f, 2.2226f, 1.2620f, 2.1811f, -1.6226f, -2.2620f, 1.4189f, -0.2130f, -2.1433f, 1.2505f, 1.1965f, -2.0245f, 1.0821f, -1.6226f, 1.6772f, 1.5925f, -0.2130f, 1.7960f, 1.4241f, 1.1965f, 1.9147f, 1.2557f, -0.5965f, -2.9147f, 2.3443f, 0.8130f, -2.7960f, 2.1759f, 2.2226f, -2.6772f, 2.0075f, -0.5965f, 1.0245f, 2.5179f, 0.8130f, 1.1433f, 2.3495f, 2.2226f, 1.2620f, 2.1811f});
+  test.SetOutputTolerance(0.0001f);
   test.Run();
 }
 
@@ -150,6 +165,7 @@ TEST(AffineGridTest, test_3d_6) {
   test.AddInput<float>("theta", {1, 3, 4}, {0.259808f, 0.000000f, -0.150000f, -0.500000f, -1.299038f, 1.500000f, -2.250000f, -0.500000f, 1.375000f, 4.763140f, 2.381570f, 0.300000f});
   test.AddInput<int64_t>("size", {5}, {1, 1, 3, 2, 2});
   test.AddOutput<float>("grid", {1, 3, 2, 2, 3}, {-0.6098f, 1.5490f, -8.2197f, -0.0902f, -1.0490f, -5.4697f, -0.6098f, 4.5490f, 1.3066f, -0.0902f, 1.9510f, 4.0566f, -0.7598f, -0.7010f, -5.8381f, -0.2402f, -3.2990f, -3.0881f, -0.7598f, 2.2990f, 3.6881f, -0.2402f, -0.2990f, 6.4381f, -0.9098f, -2.9510f, -3.4566f, -0.3902f, -5.5490f, -0.7066f, -0.9098f, 0.0490f, 6.0697f, -0.3902f, -2.5490f, 8.8197f});
+  test.SetOutputTolerance(0.0001f);
   test.Run();
 }
 
@@ -159,6 +175,7 @@ TEST(AffineGridTest, test_3d_7) {
   test.AddInput<float>("theta", {2, 3, 4}, {0.259808f, 0.000000f, -0.150000f, -0.500000f, -1.299038f, 1.500000f, -2.250000f, -0.500000f, 1.375000f, 4.763140f, 2.381570f, 0.300000f, 0.259808f, 0.000000f, -0.150000f, -0.500000f, -1.299038f, 1.500000f, -2.250000f, -0.500000f, 1.375000f, 4.763140f, 2.381570f, 0.300000f});
   test.AddInput<int64_t>("size", {5}, {2, 10, 2, 2, 3});
   test.AddOutput<float>("grid", {2, 2, 2, 3, 3}, {-0.6098f, 1.5490f, -8.2197f, -0.3500f, 0.2500f, -6.8447f, -0.0902f, -1.0490f, -5.4697f, -0.6098f, 4.5490f, 1.3066f, -0.3500f, 3.2500f, 2.6816f, -0.0902f, 1.9510f, 4.0566f, -0.9098f, -2.9510f, -3.4566f, -0.6500f, -4.2500f, -2.0816f, -0.3902f, -5.5490f, -0.7066f, -0.9098f, 0.0490f, 6.0697f, -0.6500f, -1.2500f, 7.4447f, -0.3902f, -2.5490f, 8.8197f, -0.6098f, 1.5490f, -8.2197f, -0.3500f, 0.2500f, -6.8447f, -0.0902f, -1.0490f, -5.4697f, -0.6098f, 4.5490f, 1.3066f, -0.3500f, 3.2500f, 2.6816f, -0.0902f, 1.9510f, 4.0566f, -0.9098f, -2.9510f, -3.4566f, -0.6500f, -4.2500f, -2.0816f, -0.3902f, -5.5490f, -0.7066f, -0.9098f, 0.0490f, 6.0697f, -0.6500f, -1.2500f, 7.4447f, -0.3902f, -2.5490f, 8.8197f});
+  test.SetOutputTolerance(0.0001f);
   test.Run();
 }
 }  // namespace test
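
Every AffineGrid hunk above adds test.SetOutputTolerance(0.0001f) because the expected grids are printed to four decimal places, so the default comparison threshold is presumably too tight for that rounded data. Below is a minimal sketch of the pattern, assuming only the OpTester-style API visible in these hunks (AddInput, AddOutput, SetOutputTolerance, Run); the opset, shapes, and values are hand-picked for illustration and are not taken from the test generator.

// Sketch only: identity 2-D AffineGrid, hand-computed expected grid.
TEST(AffineGridTest, tolerance_pattern_sketch) {
  OpTester test("AffineGrid", 20);  // assumption: AffineGrid as an opset-20 op
  test.AddInput<float>("theta", {1, 2, 3}, {1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f});  // identity transform
  test.AddInput<int64_t>("size", {4}, {1, 1, 2, 2});
  // With the identity transform and align_corners left at its default, the grid
  // holds the normalized pixel-centre coordinates, worked out by hand here.
  test.AddOutput<float>("grid", {1, 2, 2, 2},
                        {-0.5f, -0.5f, 0.5f, -0.5f, -0.5f, 0.5f, 0.5f, 0.5f});
  // Same call the generated tests above rely on: compare element-wise with an
  // absolute tolerance of 1e-4 instead of the framework default.
  test.SetOutputTolerance(0.0001f);
  test.Run();
}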
diff --git a/onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc b/onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc
index 8a8bc5560c084..b4bd3fca7b712 100644
--- a/onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc
@@ -383,7 +383,7 @@ TEST(GatherElementsOpTest, IndicesOutOfBounds) {
   // skip OpenVINO, which does not throw an error message but ensures no out-of-bounds access
   // skip TensorRT because it doesn't support out of bounds indices
   test.Run(OpTester::ExpectResult::kExpectFailure, "",
-           {kCudaExecutionProvider, kRocmExecutionProvider, kOpenVINOExecutionProvider,
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kOpenVINOExecutionProvider,
             kTensorrtExecutionProvider, kDmlExecutionProvider});
 }
 
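
Throughout this patch, exclusion lists that already name kCudaExecutionProvider also gain kCudaNHWCExecutionProvider. The NHWC CUDA provider is registered under its own name (and only when ENABLE_CUDA_NHWC_OPS is built, as the grid_sample helper further down shows), so without the extra entry it would presumably still run the very kernels these tests skip. A short fragment sketching the pattern, assuming the same OpTester::Run overload used in the hunk above and an OpTester named test:

// Fragment (assumes an OpTester named `test` as in the hunk above): both CUDA
// provider names must be listed, otherwise the NHWC build still executes the op.
std::unordered_set<std::string> excluded_providers = {
    kCudaExecutionProvider,
    kCudaNHWCExecutionProvider,  // separate provider name used by the NHWC CUDA build
    kRocmExecutionProvider,
    kOpenVINOExecutionProvider,
    kTensorrtExecutionProvider,
    kDmlExecutionProvider,
};
test.Run(OpTester::ExpectResult::kExpectFailure, "", excluded_providers);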
diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc
index 0f097622abff0..5c89d6ea7bd75 100644
--- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc
@@ -6,6 +6,33 @@
 
 namespace onnxruntime {
 namespace test {
+
+std::vector<std::unique_ptr<IExecutionProvider>> GetExecutionProviders(int opset_version) {
+  ORT_UNUSED_PARAMETER(opset_version);
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+
+  execution_providers.emplace_back(DefaultCpuExecutionProvider());
+#ifdef USE_CUDA
+  if (opset_version < 20) {
+    execution_providers.emplace_back(DefaultCudaExecutionProvider());
+#ifdef ENABLE_CUDA_NHWC_OPS
+    execution_providers.push_back(DefaultCudaNHWCExecutionProvider());
+#endif
+  }
+
+#endif
+  return execution_providers;
+}
+
+template <typename T>
+void RunTests(T& test, std::vector<std::unique_ptr<IExecutionProvider>>&& execution_providers) {
+  for (size_t idx = 0; idx < execution_providers.size(); ++idx) {
+    test.ConfigEp(std::move(execution_providers[idx])).RunWithConfig();
+  }
+  execution_providers.clear();
+}
+
 // DO NOT edit following tests. They are generated by:
 // onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py
 TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_align_corners) {
@@ -25,8 +52,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(16));
 }
 
 TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners) {
@@ -46,8 +72,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(16));
 }
 
 TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_align_corners) {
@@ -67,8 +92,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(16));
 }
 
 TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_no_align_corners) {
@@ -88,8 +112,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_no_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(16));
 }
 
 TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_align_corners) {
@@ -109,8 +132,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(16));
 }
 
 TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_corners) {
@@ -130,8 +152,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_corners)
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(16));
 }
 
 TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) {
@@ -151,8 +172,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(16));
 }
 
 TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) {
@@ -172,8 +192,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(16));
 }
 
 TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) {
@@ -193,8 +212,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(16));
 }
 
 TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) {
@@ -214,8 +232,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(16));
 }
 
 TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corners) {
@@ -235,8 +252,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(16));
 }
 
 TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_corners) {
@@ -256,8 +272,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_corners
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(16));
 }
 
 TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) {
@@ -277,8 +292,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(16));
 }
 
 TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) {
@@ -298,8 +312,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(16));
 }
 
 TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) {
@@ -319,8 +332,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(16));
 }
 
 TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corners) {
@@ -340,8 +352,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(16));
 }
 
 TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corners) {
@@ -361,8 +372,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(16));
 }
 
 TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_corners) {
@@ -382,8 +392,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_corners)
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(16));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) {
@@ -403,8 +412,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) {
@@ -424,8 +432,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners) {
@@ -445,8 +452,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners) {
@@ -466,8 +472,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_align_corners) {
@@ -487,8 +492,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_align_corners) {
@@ -508,8 +512,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_no_align_corners) {
@@ -529,8 +532,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_no_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_no_align_corners) {
@@ -550,8 +552,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_no_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_align_corners) {
@@ -571,8 +572,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_align_corners) {
@@ -592,8 +592,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_corners) {
@@ -613,8 +612,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_corners)
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_corners) {
@@ -634,8 +632,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_corners)
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) {
@@ -655,8 +652,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners) {
@@ -676,8 +672,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) {
@@ -697,8 +692,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corners) {
@@ -718,8 +712,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) {
@@ -739,8 +732,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) {
@@ -760,8 +752,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) {
@@ -781,8 +772,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_no_align_corners) {
@@ -802,8 +792,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_no_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corners) {
@@ -823,8 +812,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_align_corners) {
@@ -844,8 +832,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_corners) {
@@ -865,8 +852,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_corners
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_no_align_corners) {
@@ -886,8 +872,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_no_align_corners
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) {
@@ -907,8 +892,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners) {
@@ -928,8 +912,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_align_corners) {
@@ -949,8 +932,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corners) {
@@ -970,8 +952,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corners) {
@@ -991,8 +972,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corners) {
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
 
 TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_no_align_corners) {
@@ -1012,8 +992,8 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_no_align_corners)
   test.AddAttribute("padding_mode", padding_mode);
   test.AddAttribute("align_corners", align_corners);
   test.AddOutput<float>("Y", Y_shape, Y_data);
-  test.ConfigEp(DefaultCpuExecutionProvider())
-      .RunWithConfig();
+  RunTests(test, GetExecutionProviders(20));
 }
+
 }  // namespace test
 }  // namespace onnxruntime
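
The GridSample tests previously ran only on the CPU provider; the new GetExecutionProviders/RunTests helpers make every generated test run once per available provider, with the CUDA providers added only for opsets below 20 (so the opset-20 cases stay CPU-only here). A sketch of how one generated test drives that loop, using the same ConfigEp/RunWithConfig-based helpers as the hunks above; the tensor values are a hand-worked bilinear example, not output of the generator.

// Sketch of a generated GridSample test body: a 2x2 input sampled once at its
// centre with bilinear interpolation (values chosen by hand for illustration).
TEST(GridsampleTest, provider_loop_sketch) {
  OpTester test("GridSample", 16);
  std::string mode = "bilinear";
  std::string padding_mode = "zeros";
  int64_t align_corners = 0;
  test.AddInput<float>("X", {1, 1, 2, 2}, {1.0f, 2.0f, 3.0f, 4.0f});
  test.AddInput<float>("Grid", {1, 1, 1, 2}, {0.0f, 0.0f});  // normalized centre
  test.AddAttribute("mode", mode);
  test.AddAttribute("padding_mode", padding_mode);
  test.AddAttribute("align_corners", align_corners);
  test.AddOutput<float>("Y", {1, 1, 1, 1}, {2.5f});  // bilinear centre = mean of the 4 pixels
  // Runs once on CPU, once on CUDA and, when ENABLE_CUDA_NHWC_OPS is set, once
  // on the NHWC CUDA provider; GetExecutionProviders(20) would return CPU only.
  RunTests(test, GetExecutionProviders(16));
}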
diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py
index e4d58e79243ef..c60e55617774f 100644
--- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py
+++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py
@@ -76,6 +76,6 @@
                     print('test.AddAttribute("padding_mode", padding_mode);')
                     print('test.AddAttribute("align_corners", align_corners);')
                     print('test.AddOutput<float>("Y", Y_shape, Y_data);')
-                    print("test.Run();")
+                    print(f"RunTests(test, GetExecutionProviders({opset_version}));")
                     print("}")
                     print("\n")
diff --git a/onnxruntime/test/providers/cpu/tensor/isinf_test.cc b/onnxruntime/test/providers/cpu/tensor/isinf_test.cc
index 2e583c5d2547b..bd97306142f18 100644
--- a/onnxruntime/test/providers/cpu/tensor/isinf_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/isinf_test.cc
@@ -99,6 +99,48 @@ TEST(IsInfTest, test_isinf_negative_double20) {
   run_is_inf_test(20, 0, 1, input, output);
 }
 
+TEST(IsInfTest, test_isinf_mlfloat16) {
+  std::initializer_list<MLFloat16> input = {MLFloat16{-1.7f}, MLFloat16::NaN, MLFloat16::Infinity, 3.6_fp16,
+                                            MLFloat16::NegativeInfinity, MLFloat16::Infinity};
+  std::initializer_list<bool> output = {false, false, true, false, true, true};
+  run_is_inf_test(20, 1, 1, input, output);
+}
+
+TEST(IsInfTest, test_isinf_positive_mlfloat16) {
+  std::initializer_list<MLFloat16> input = {MLFloat16{-1.7f}, MLFloat16::NaN, MLFloat16::Infinity, 3.6_fp16,
+                                            MLFloat16::NegativeInfinity, MLFloat16::Infinity};
+  std::initializer_list<bool> output = {false, false, true, false, false, true};
+  run_is_inf_test(20, 1, 0, input, output);
+}
+
+TEST(IsInfTest, test_isinf_negative_mlfloat16) {
+  std::initializer_list<MLFloat16> input = {MLFloat16{-1.7f}, MLFloat16::NaN, MLFloat16::Infinity, 3.6_fp16,
+                                            MLFloat16::NegativeInfinity, MLFloat16::Infinity};
+  std::initializer_list<bool> output = {false, false, false, false, true, false};
+  run_is_inf_test(20, 0, 1, input, output);
+}
+
+TEST(IsInfTest, test_isinf_bfloat16) {
+  std::initializer_list<BFloat16> input = {BFloat16{-1.7f}, BFloat16::NaN, BFloat16::Infinity, 3.6_bfp16,
+                                           BFloat16::NegativeInfinity, BFloat16::Infinity};
+  std::initializer_list<bool> output = {false, false, true, false, true, true};
+  run_is_inf_test(20, 1, 1, input, output);
+}
+
+TEST(IsInfTest, test_isinf_positive_bfloat16) {
+  std::initializer_list<BFloat16> input = {BFloat16{-1.7f}, BFloat16::NaN, BFloat16::Infinity, 3.6_bfp16,
+                                           BFloat16::NegativeInfinity, BFloat16::Infinity};
+  std::initializer_list<bool> output = {false, false, true, false, false, true};
+  run_is_inf_test(20, 1, 0, input, output);
+}
+
+TEST(IsInfTest, test_isinf_negative_bfloat16) {
+  std::initializer_list<BFloat16> input = {BFloat16{-1.7f}, BFloat16::NaN, BFloat16::Infinity, 3.6_bfp16,
+                                           BFloat16::NegativeInfinity, BFloat16::Infinity};
+  std::initializer_list<bool> output = {false, false, false, false, true, false};
+  run_is_inf_test(20, 0, 1, input, output);
+}
+
 #if !defined(DISABLE_FLOAT8_TYPES)
 TEST(IsInfTest, test_Float8E4M3FN) {
   std::initializer_list<Float8E4M3FN> input = {
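
The new MLFloat16/BFloat16 cases reuse run_is_inf_test(opset, detect_positive, detect_negative, input, output), apparently in that argument order judging by the positive/negative variants above, and the expected booleans follow mechanically from the two flags. Below is a reference sketch of that mapping, using plain float for brevity (the half-precision constants above convert to the same +/-inf and NaN values); it mirrors the IsInf attribute semantics and is not the operator's implementation.

// Reference sketch of the flag-to-output mapping exercised by the cases above.
#include <cmath>
#include <vector>

std::vector<bool> ReferenceIsInf(const std::vector<float>& x,
                                 bool detect_positive, bool detect_negative) {
  std::vector<bool> y;
  y.reserve(x.size());
  for (float v : x) {
    const bool is_inf = std::isinf(v);  // NaN is never reported as infinity
    y.push_back((detect_positive && is_inf && v > 0.0f) ||
                (detect_negative && is_inf && v < 0.0f));
  }
  return y;
}

// With input {-1.7, NaN, +inf, 3.6, -inf, +inf}, as in the tests above:
//   detect_positive=1, detect_negative=1 -> {0, 0, 1, 0, 1, 1}
//   detect_positive=1, detect_negative=0 -> {0, 0, 1, 0, 0, 1}
//   detect_positive=0, detect_negative=1 -> {0, 0, 0, 0, 1, 0}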
diff --git a/onnxruntime/test/providers/cpu/tensor/isnan_test.cc b/onnxruntime/test/providers/cpu/tensor/isnan_test.cc
index 0f1e5c07cdd9b..3cf99fde2cce7 100644
--- a/onnxruntime/test/providers/cpu/tensor/isnan_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/isnan_test.cc
@@ -38,9 +38,23 @@ TEST(IsNaNOpTest, IsNaNFloat16_9) {
   run_is_nan_test(9, dims, input, output);
 }
 
+TEST(IsNaNOpTest, IsNaNFloat16_13) {
+  std::vector<int64_t> dims{2, 2};
+  std::initializer_list<MLFloat16> input = {MLFloat16::One, MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN};
+  std::initializer_list<bool> output = {false, true, false, true};
+  run_is_nan_test(13, dims, input, output);
+}
+
 TEST(IsNaNOpTest, IsNaNFloat16_20) {
   std::vector<int64_t> dims{2, 2};
-  std::initializer_list<MLFloat16> input = {MLFloat16(1.0f), MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN};
+  std::initializer_list<MLFloat16> input = {MLFloat16::One, MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN};
+  std::initializer_list<bool> output = {false, true, false, true};
+  run_is_nan_test(20, dims, input, output);
+}
+
+TEST(IsNaNOpTest, IsNaNBFloat16_20) {
+  std::vector<int64_t> dims{2, 2};
+  std::initializer_list<BFloat16> input = {BFloat16::One, BFloat16::NaN, BFloat16(2.0f), BFloat16::NaN};
   std::initializer_list<bool> output = {false, true, false, true};
   run_is_nan_test(20, dims, input, output);
 }
diff --git a/onnxruntime/test/providers/cpu/tensor/mean_variance_normalization_test.cc b/onnxruntime/test/providers/cpu/tensor/mean_variance_normalization_test.cc
index b6720ae2a9a7d..8dcb15cbc6926 100644
--- a/onnxruntime/test/providers/cpu/tensor/mean_variance_normalization_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/mean_variance_normalization_test.cc
@@ -5,6 +5,7 @@
 
 #include "test/common/tensor_op_test_utils.h"
 #include "test/providers/provider_test_utils.h"
+#include "test/util/include/default_providers.h"
 
 namespace onnxruntime::test {
 
@@ -155,6 +156,10 @@ TEST(MeanVarianceNormalizationTest, AxesSubsets5D) {
     test.AddInput<float>("input", shape, X.data(), X.size());
     test.AddOutput<float>("output", shape, Y.data(), Y.size());
 
+    if (DefaultDmlExecutionProvider().get() != nullptr) {
+      test.SetOutputTolerance(0.001f);
+    }
+
     test.Run();
   };
 
diff --git a/onnxruntime/test/providers/cpu/tensor/onehot_op_test.cc b/onnxruntime/test/providers/cpu/tensor/onehot_op_test.cc
index a2ffbdcc0bdf1..55c247e4c2fea 100644
--- a/onnxruntime/test/providers/cpu/tensor/onehot_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/onehot_op_test.cc
@@ -3,6 +3,7 @@
 
 #include "gtest/gtest.h"
 #include "test/providers/provider_test_utils.h"
+#include "test/common/trt_op_test_utils.h"
 
 using namespace std;
 
@@ -36,7 +37,8 @@ TEST(OneHotOpTest, DefaultAxis_float_float_float /*indices, output, depth*/) {
                          0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
                          0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
                          0., 0., 0., 0., 0., 0., 1., 0., 0., 0.});
-  test.Run();
+  // TRT EP segmentation fault in A100
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(OneHotOpTest, DefaultAxis_int64_int32_float /*indices, output, depth*/) {
@@ -51,7 +53,7 @@ TEST(OneHotOpTest, DefaultAxis_int64_int32_float /*indices, output, depth*/) {
                            0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
                            0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
                            0, 0, 0, 0, 0, 0, 1, 0, 0, 0});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(OneHotOpTest, DefaultAxis_int64_float_int64 /*indices, output, depth*/) {
@@ -81,7 +83,7 @@ TEST(OneHotOpTest, DefaultAxis_int32_float_float /*indices, output, depth*/) {
                          0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
                          0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
                          0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(OneHotOpTest, DefaultAxis_int32_float_int32 /*indices, output, depth*/) {
@@ -231,7 +233,7 @@ TEST(OneHotOpTest, DefaultAxis_float_float_float_NonZeroOffValue /*indices, outp
                          2., 2., 3., 2., 2., 2., 2., 2., 2., 2.,
                          2., 2., 2., 2., 3., 2., 2., 2., 2., 2.,
                          2., 2., 2., 2., 2., 2., 3., 2., 2., 2.});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(OneHotOpTest, DefaultAxis_int64_int32_float_NonZeroOffValue /*indices, output, depth*/) {
@@ -246,7 +248,7 @@ TEST(OneHotOpTest, DefaultAxis_int64_int32_float_NonZeroOffValue /*indices, outp
                            2, 2, 3, 2, 2, 2, 2, 2, 2, 2,
                            2, 2, 2, 2, 3, 2, 2, 2, 2, 2,
                            2, 2, 2, 2, 2, 2, 3, 2, 2, 2});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(OneHotOpTest, DefaultAxis_int64_float_int64_NonZeroOffValue /*indices, output, depth*/) {
@@ -276,7 +278,7 @@ TEST(OneHotOpTest, DefaultAxis_int32_float_float_NonZeroOffValue /*indices, outp
                          2.0f, 2.0f, 3.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f,
                          2.0f, 2.0f, 2.0f, 2.0f, 3.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f,
                          2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 3.0f, 2.0f, 2.0f, 2.0f});
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(OneHotOpTest, DefaultAxis_int32_float_int32_NonZeroOffValue /*indices, output, depth*/) {
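
Several OneHot and Resize tests switch from a plain test.Run() to passing ExcludeTrtOnA100(), optionally seeded with providers that were already excluded, because the TensorRT EP segfaults on A100 for these cases. The real helper lives in test/common/trt_op_test_utils.h and is not part of this diff; the sketch below only illustrates the shape such a helper plausibly has. ExcludeTrtOnA100Sketch and RunningOnA100 are hypothetical names, not the actual implementation.

// Hypothetical sketch (assumption only; the real ExcludeTrtOnA100 may differ):
// extend an existing exclusion set with the TensorRT EP when on A100 hardware.
#include <string>
#include <unordered_set>

bool RunningOnA100();  // hypothetical hardware probe, declared for illustration only

inline std::unordered_set<std::string> ExcludeTrtOnA100Sketch(
    std::unordered_set<std::string> excluded = {}) {
  if (RunningOnA100()) {
    excluded.insert(kTensorrtExecutionProvider);
  }
  return excluded;
}

// Usage matching the hunks: keep the existing QNN exclusion, add TRT on A100.
// std::unordered_set<std::string> excluded_providers({kQnnExecutionProvider});
// test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100(excluded_providers));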
diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc
index 10f02349a24d5..496f2213e9d32 100644
--- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc
@@ -5,13 +5,16 @@
 #include "gtest/gtest.h"
 #include "test/providers/provider_test_utils.h"
 #include "test/util/include/default_providers.h"
+#include "test/common/trt_op_test_utils.h"
 
 namespace onnxruntime {
 namespace test {
+
 TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_tf_crop_and_resize) {
   // TODO: Unskip when fixed #41968513
   if (DefaultDmlExecutionProvider().get() != nullptr) {
-    GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 0.20000028610229492, which exceeds threshold";
+    GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] "
+                 << "is 0.20000028610229492, which exceeds threshold";
   }
 
   OpTester test("Resize", 13);
@@ -32,7 +35,8 @@ TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_tf_crop_and_resize) {
 
   test.AddInput<float>("X", {H, W}, X);
   test.AddInput<float>("roi", {4}, roi);
-  test.AddInput<float>("", {0}, scales);  // opset13 requires either 'sizes' or 'scales' must be provided, but not both of them
+  // opset 13 requires that either 'sizes' or 'scales' be provided, but not both of them
+  test.AddInput<float>("", {0}, scales);
   test.AddInput<int64_t>("sizes", {2}, sizes);
 
   std::vector<float> Y = {7.600004f, 7.9f, 8.2f,
@@ -100,7 +104,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extr
   // TensorRT: results mismatch
   // ROCm: results mismatch
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-           {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider});
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider});
 }
 
 TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extrapolation_uint8) {
@@ -130,7 +134,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extr
   test.AddOutput<uint8_t>("Y", {N, static_cast<int64_t>(H * scales[1]), static_cast<int64_t>(W * scales[2]), C}, Y);
   // CUDA: result mismatch due to not implementing NHWC support
   // ROCm: results mismatch
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider});
 }
 
 TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extrapolation_int8) {
@@ -188,7 +193,9 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_e
   // CUDA: result mismatch due to not implementing NHWC support
   // ROCm: results mismatch
   // DML: results mismatch
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider});
+  test.Run(
+      OpTester::ExpectResult::kExpectSuccess, "",
+      {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider});
 }
 
 TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_extrapolation_int8) {
@@ -238,7 +245,10 @@ TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear) {
   std::vector<float> Y = {2.66666651f, 4.3333331f};
 
   test.AddOutput<float>("Y", {N, C, static_cast<int64_t>(H * scales[2]), static_cast<int64_t>(W * scales[3])}, Y);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider});  // QNN: result diff
+  // QNN: result diff
+  // TRT: Segmentation fault in A100
+  std::unordered_set<std::string> excluded_providers({kQnnExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100(excluded_providers));
 }
 
 TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear) {
@@ -262,8 +272,9 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear) {
   test.AddOutput<float>("Y", {N, static_cast<int64_t>(H * scales[1]), static_cast<int64_t>(W * scales[2]), C}, Y);
   // CUDA: result mismatch due to not implementing NHWC support
   // ROCm: results mismatch
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-           {kCudaExecutionProvider, kRocmExecutionProvider});
+  // TRT: Segmentation fault in A100
+  std::unordered_set<std::string> excluded_providers({kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100(excluded_providers));
 }
 
 TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_uint8) {
@@ -287,7 +298,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_uint8) {
   test.AddOutput<uint8_t>("Y", {N, static_cast<int64_t>(H * scales[1]), static_cast<int64_t>(W * scales[2]), C}, Y);
   // CUDA: result mismatch due to not implementing NHWC support
   // ROCm: results mismatch
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider});
 }
 
 TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_int8) {
@@ -309,7 +321,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_int8) {
   std::vector<int8_t> Y = {0, 0};
 
   test.AddOutput<int8_t>("Y", {N, static_cast<int64_t>(H * scales[1]), static_cast<int64_t>(W * scales[2]), C}, Y);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 // Since NNAPI(TFLite) only using the scale calculate using the input/output size
@@ -317,7 +329,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_int8) {
 // The output size is [1,1,2,4].*[1,1,0.6,0.6]=[1,1,1,2]
 // NNAPI will recalculate the scales as the output size divided by input size
 // scales = [1,1,1,2]./[1,1,2,4] = [1,1,0.5,0.5]
-// See, https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/internal/reference/reference_ops.h
+// See:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/internal/reference/reference_ops.h
 // So the result of the above example will be different than CPU EP
 // Add the following 2 tests to test with scales valid to NNAPI
 TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear1) {
@@ -341,7 +353,7 @@ TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear1) {
     std::vector<float> Y = {3.5f, 5.5f};
 
     test.AddOutput<float>("Y", {N, C, static_cast<int64_t>(H * scales[2]), static_cast<int64_t>(W * scales[3])}, Y);
-    test.Run();
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
   };
 
   run_test(false);
@@ -399,7 +411,7 @@ TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear_align_corners) {
     std::vector<float> Y = {1.0f, 4.0f};
 
     test.AddOutput<float>("Y", {N, C, static_cast<int64_t>(H * scales[2]), static_cast<int64_t>(W * scales[3])}, Y);
-    test.Run();
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
   };
 
   run_test(false);
@@ -435,7 +447,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_align_corners_uin
     test.AddOutput<uint8_t>("Y", {N, static_cast<int64_t>(H * scales[1]), static_cast<int64_t>(W * scales[2]), C}, Y);
     // CUDA: result mismatch due to not implementing NHWC support
     // ROCm: results mismatch
-    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider});
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+             {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider});
   };
 
   run_test(false);
@@ -475,7 +488,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_align_corners_int
 TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_2DBilinear_pytorch_half_pixel) {
   // TODO: Unskip when fixed #41968513
   if (DefaultDmlExecutionProvider().get() != nullptr) {
-    GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 1.5000001192092896, which exceeds threshold";
+    GTEST_SKIP() << "Skipping because of the following error: "
+                 << " The difference between expected[i] and output[i] is 1.5000001192092896, which exceeds threshold";
   }
 
   OpTester test("Resize", 13);
@@ -533,7 +547,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixe
   // CUDA: result mismatch due to not implementing NHWC support
   // ROCm: results mismatch
   // DML: results mismatch
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider});
 }
 
 TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixel_int8) {
@@ -566,8 +581,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixe
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDmlExecutionProvider});
 }
 
-TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_4DBilinear_asymmetric) {
-  // To test NNAPI EP, we need the sclaes/sizes to be in initializers
+TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_4DBilinear_asymmetric_scales) {
+  // To test CoreML/NNAPI EP, we need the scales/sizes to be in initializers
   auto run_test = [](bool scales_in_initializer) {
     OpTester test("Resize", 13);
     std::vector<float> roi{};
@@ -599,7 +614,7 @@ TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_4DBilinear_asymmetric) {
         7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 11.0f, 11.0f, 11.0f};
 
     test.AddOutput<float>("Y", {N, C, static_cast<int64_t>(H * scales[2]), static_cast<int64_t>(W * scales[3])}, Y);
-    test.Run();
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
   };
 
   run_test(false);
@@ -644,7 +659,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearUpSampleTest_4DBilinear_asymmetric_uint8) {
                             Y, false, .0f, 1.0f);
     // CUDA: result mismatch due to not implementing NHWC support
     // ROCm: results mismatch
-    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider});
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+             {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider});
   };
 
   run_test(false);
@@ -715,13 +731,14 @@ TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_2DBilinear_align_corners) {
       4.0f, 4.5714290f, 5.142857f, 5.714286f, 6.285714f, 6.8571430f, 7.428571f, 8.0f};
 
   test.AddOutput<float>("Y", {static_cast<int64_t>(H * scales[0]), static_cast<int64_t>(W * scales[1])}, Y);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_3DTrilinear_pytorch_half_pixel) {
   // TODO: Unskip when fixed #41968513
   if (DefaultDmlExecutionProvider().get() != nullptr) {
-    GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 1.5000001192092896, which exceeds threshold";
+    GTEST_SKIP() << "Skipping because of the following error: "
+                 << "The difference between expected[i] and output[i] is 1.5000001192092896, which exceeds threshold";
   }
 
   OpTester test("Resize", 13);
@@ -808,7 +825,7 @@ TEST(ResizeOpTest, ResizeOpLinearScalesNoOpTest) {
                             7.0f, 11.0f};
 
     test.AddOutput<float>("Y", {N, C, H, W}, Y);
-    test.Run();
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
   };
 
   run_test(false);
@@ -834,7 +851,7 @@ TEST(ResizeOpTest, ResizeOpNearestDownSampleTest) {
   std::vector<float> Y = {1.0f, 3.0f};
 
   test.AddOutput<float>("Y", {N, C, static_cast<int64_t>(H * scales[2]), static_cast<int64_t>(W * scales[3])}, Y);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(ResizeOpTest, ResizeOpNearestDownSampleTest_Opset12) {
@@ -856,7 +873,7 @@ TEST(ResizeOpTest, ResizeOpNearestDownSampleTest_Opset12) {
   std::vector<float> Y = {1.0f, 3.0f};
 
   test.AddOutput<float>("Y", {N, C, static_cast<int64_t>(H * scales[2]), static_cast<int64_t>(W * scales[3])}, Y);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(ResizeOpTest, ResizeOpNearestDownSampleTest_WithSizes) {
@@ -909,7 +926,7 @@ TEST(ResizeOpTest, ResizeOpNearestDownSampleTest_tf_half_pixel) {
                           14.0f, 16.0f};
 
   test.AddOutput<float>("Y", {N, C, sizes[2], sizes[3]}, Y);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(ResizeOpTest, ResizeOpNearestDownSampleTest_tf_crop_and_resize_with_extrapolation) {
@@ -989,7 +1006,7 @@ TEST(ResizeOpTest, ResizeOpNearestUpSampleTest) {
                           3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f};
 
   test.AddOutput<float>("Y", {N, C, static_cast<int64_t>(H * scales[2]), static_cast<int64_t>(W * scales[3])}, Y);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(ResizeOpTest, ResizeOpNearestUpSampleTest_WithSizes_CeilMode) {
@@ -1082,13 +1099,14 @@ TEST(ResizeOpTest, ResizeOpNearestUpSample_Floor_Align_Corners) {
                           13.0f, 13.0f, 13.0f, 14.0f, 14.0f, 15.0f, 15.0f, 16.0f};
 
   test.AddOutput<float>("Y", {N, C, static_cast<int64_t>(H * scales[2]), static_cast<int64_t>(W * scales[3])}, Y);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(ResizeOpTest, ResizeOpNearest_OneToOneMappingBetweenInputAndOutputDataDims) {
   // TODO: Unskip when fixed #41968513
   if (DefaultDmlExecutionProvider().get() != nullptr) {
-    GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 3, which exceeds threshold";
+    GTEST_SKIP() << "Skipping because of the following error: "
+                 << "The difference between expected[i] and output[i] is 3, which exceeds threshold";
   }
 
   OpTester test("Resize", 12);  // tf_half_pixel_for_nn is deprecated since opset 13
@@ -1185,7 +1203,7 @@ TEST(ResizeOpTest, ResizeOpNearestUpSample_Nearest2xOptimization_Scales) {
                             3.0f, 3.0f, 4.0f, 4.0f};
 
     test.AddOutput<float>("Y", {N, C, static_cast<int64_t>(H * scales[2]), static_cast<int64_t>(W * scales[3])}, Y);
-    test.Run();
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
   };
 
   run_test(false);
@@ -1250,7 +1268,7 @@ TEST(ResizeOpTest, ResizeOpCubicDownSampleTest) {
                           11.9165f, 13.2266f, 14.5278f};
 
   test.AddOutput<float>("Y", {N, C, static_cast<int64_t>(H * scales[2]), static_cast<int64_t>(W * scales[3])}, Y);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(ResizeOpTest, ResizeOpCubicDownSampleTest_exclude_outside) {
@@ -1280,7 +1298,7 @@ TEST(ResizeOpTest, ResizeOpCubicDownSampleTest_exclude_outside) {
                           11.949f, 13.2503f, 14.5942f};
 
   test.AddOutput<float>("Y", {static_cast<int64_t>(H * scales[0]), static_cast<int64_t>(W * scales[1])}, Y);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(ResizeOpTest, ResizeOpCubicDownSampleTest_coeff) {
@@ -1307,7 +1325,7 @@ TEST(ResizeOpTest, ResizeOpCubicDownSampleTest_coeff) {
                           11.8701f, 13.168f, 14.4912f};
 
   test.AddOutput<float>("Y", {N, C, static_cast<int64_t>(H * scales[2]), static_cast<int64_t>(W * scales[3])}, Y);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(ResizeOpTest, ResizeOpCubicDownSampleTest_with_roi) {
@@ -1361,7 +1379,7 @@ TEST(ResizeOpTest, ResizeOpCubicDownSampleTest_asymmetric) {
                           11.375f, 12.6719f, 13.9688f};
 
   test.AddOutput<float>("Y", {N, C, static_cast<int64_t>(H * scales[2]), static_cast<int64_t>(W * scales[3])}, Y);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(ResizeOpTest, ResizeOpCubicUpSampleTest) {
@@ -1393,7 +1411,7 @@ TEST(ResizeOpTest, ResizeOpCubicUpSampleTest) {
                           13.375f, 13.7813f, 14.375f, 14.875f, 15.375f, 15.9688f, 16.375f, 16.4688f};
 
   test.AddOutput<float>("Y", {N, C, static_cast<int64_t>(H * scales[2]), static_cast<int64_t>(W * scales[3])}, Y);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(ResizeOpTest, ResizeOpCubicUpSampleTest_MultiChannel) {
@@ -1474,13 +1492,14 @@ TEST(ResizeOpTest, ResizeOpCubicUpSampleTest_tf_half_pixel_for_nn) {
                           13.332f, 13.8086f, 14.4375f, 14.8438f, 15.4727f, 15.9492f, 16.2461f, 16.1758f};
 
   test.AddOutput<float>("Y", {N, C, static_cast<int64_t>(H * scales[2]), static_cast<int64_t>(W * scales[3])}, Y);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear_Ver10) {
   // TODO: Unskip when fixed #41968513
   if (DefaultDmlExecutionProvider().get() != nullptr) {
-    GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 1.6666665077209473, which exceeds threshold";
+    GTEST_SKIP() << "Skipping because of the following error: "
+                 << "The difference between expected[i] and output[i] is 1.6666665077209473, which exceeds threshold";
   }
 
   OpTester test("Resize", 10);
@@ -1499,13 +1518,17 @@ TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear_Ver10) {
   std::vector<float> Y = {1.0f, 2.66666651f};
 
   test.AddOutput<float>("Y", {N, C, static_cast<int64_t>(H * scales[2]), static_cast<int64_t>(W * scales[3])}, Y);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider});  // QNN: result diff
+  // QNN: result diff
+  // TRT: segmentation fault in A100
+  std::unordered_set<std::string> excluded_providers({kQnnExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100(excluded_providers));
 }
 
 TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_2DBilinear_Ver10) {
   // TODO: Unskip when fixed #41968513
   if (DefaultDmlExecutionProvider().get() != nullptr) {
-    GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 1.6666665077209473, which exceeds threshold";
+    GTEST_SKIP() << "Skipping because of the following error: "
+                 << "The difference between expected[i] and output[i] is 1.6666665077209473, which exceeds threshold ";
   }
 
   OpTester test("Resize", 10);
@@ -1524,13 +1547,14 @@ TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_2DBilinear_Ver10) {
   std::vector<float> Y = {1.0f, 2.66666651f};
 
   test.AddOutput<float>("Y", {static_cast<int64_t>(H * scales[0]), static_cast<int64_t>(W * scales[1])}, Y);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_4DBilinear_Ver10) {
   // TODO: Unskip when fixed #41968513
   if (DefaultDmlExecutionProvider().get() != nullptr) {
-    GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 0.5, which exceeds threshold";
+    GTEST_SKIP() << "Skipping because of the following error: "
+                 << "The difference between expected[i] and output[i] is 0.5, which exceeds threshold";
   }
 
   OpTester test("Resize", 10);
@@ -1559,13 +1583,17 @@ TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_4DBilinear_Ver10) {
       7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 11.0f, 11.0f, 11.0f};
 
   test.AddOutput<float>("Y", {N, C, static_cast<int64_t>(H * scales[2]), static_cast<int64_t>(W * scales[3])}, Y);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider});  // QNN: result diff
+  // QNN: result diff
+  // TRT: segmentation fault in A100
+  std::unordered_set<std::string> excluded_providers({kQnnExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100(excluded_providers));
 }
 
 TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_2DBilinear_Ver10) {
   // TODO: Unskip when fixed #41968513
   if (DefaultDmlExecutionProvider().get() != nullptr) {
-    GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 0.5, which exceeds threshold";
+    GTEST_SKIP() << "Skipping because of the following error: "
+                 << "The difference between expected[i] and output[i] is 0.5, which exceeds threshold";
   }
 
   OpTester test("Resize", 10);
@@ -1586,7 +1614,7 @@ TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_2DBilinear_Ver10) {
       4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 8.0f, 8.0f, 8.0f};
 
   test.AddOutput<float>("Y", {static_cast<int64_t>(H * scales[0]), static_cast<int64_t>(W * scales[1])}, Y);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(ResizeOpTest, ResizeOpLinearScalesNoOpTest_Ver10) {
@@ -1611,7 +1639,7 @@ TEST(ResizeOpTest, ResizeOpLinearScalesNoOpTest_Ver10) {
                           7.0f, 11.0f};
 
   test.AddOutput<float>("Y", {N, C, H, W}, Y);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(ResizeOpTest, ResizeOpNearestDownSampleTest_Ver10) {
@@ -1631,7 +1659,7 @@ TEST(ResizeOpTest, ResizeOpNearestDownSampleTest_Ver10) {
   std::vector<float> Y = {1.0f, 3.0f};
 
   test.AddOutput<float>("Y", {N, C, static_cast<int64_t>(H * scales[2]), static_cast<int64_t>(W * scales[3])}, Y);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(ResizeOpTest, ResizeOpNearestUpSampleTest_Ver10) {
@@ -1652,10 +1680,10 @@ TEST(ResizeOpTest, ResizeOpNearestUpSampleTest_Ver10) {
                           3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f};
 
   test.AddOutput<float>("Y", {N, C, static_cast<int64_t>(H * scales[2]), static_cast<int64_t>(W * scales[3])}, Y);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
-TEST(UpsampleOpTest, ResizeOpNearestNoScaleTest_Ver10) {
+TEST(ResizeOpTest, ResizeOpNearestNoScaleTest_Ver10) {
   OpTester test("Resize", 10);
   std::vector<float> scales{1.0f, 1.0f, 1.0f, 1.0f};
 
@@ -1670,13 +1698,14 @@ TEST(UpsampleOpTest, ResizeOpNearestNoScaleTest_Ver10) {
   std::vector<float> Y = {1.0f, 2.0f, 3.0f, 4.0f};
 
   test.AddOutput<float>("Y", {N, C, H, W}, Y);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(ResizeOpTest, ResizeOp_MissingRoiAndMissingScalesOptionalInputs) {
   // TODO: Unskip when fixed #41968513
   if (DefaultDmlExecutionProvider().get() != nullptr) {
-    GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1876): The parameter is incorrect.";
+    GTEST_SKIP() << "Skipping because of the following error: "
+                 << "MLOperatorAuthorImpl.cpp(1876): The parameter is incorrect.";
   }
 
   OpTester test("Resize", 13);
@@ -1720,7 +1749,7 @@ void ResizeOpTypeCheck_Ver_10() {
                       3, 3, 3, 4, 4, 4};
 
   test.AddOutput<T>("Y", {N, C, static_cast<int64_t>(H * scales[2]), static_cast<int64_t>(W * scales[3])}, Y);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(ResizeOpTest, ResizeOpTypeCheck_Ver_10) {
@@ -1751,7 +1780,7 @@ void ResizeOpTypeCheck_Ver_11_13_18(int opset_version) {
                       3, 3, 3, 4, 4, 4};
 
   test.AddOutput<T>("Y", {N, C, static_cast<int64_t>(H * scales[2]), static_cast<int64_t>(W * scales[3])}, Y);
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(ResizeOpTest, ResizeOpTypeCheck_Ver11) {
@@ -1827,7 +1856,8 @@ template <typename T, typename T1 = int64_t>
 void TestAntialiasing(std::map<std::string, std::string> attributes,
                       std::vector<int64_t> input_shape,
                       std::vector<T> input_data,
-                      std::vector<T1> output_shape_or_scale, std::vector<T> output_data) {
+                      std::vector<T1> output_shape_or_scale, std::vector<T> output_data,
+                      gsl::span<std::string_view> excluded_ep = {}) {
   auto parse_attr = [](const std::string& str, auto typed_v) {
     using Tdata = decltype(typed_v);
     std::vector<Tdata> vect;
@@ -1891,13 +1921,24 @@ void TestAntialiasing(std::map<std::string, std::string> attributes,
   }
 
   test.AddOutput<T>("Y", output_shape, output_data);
-  // TensorRT 8.5 supports operators up to Opset 17. Temporarily exclude TensorRT EP due to accurarcy issue.
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
+
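+  // Merge the caller-provided EP exclusions with the EPs that are always excluded below.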
+  std::unordered_set<std::string> excluded_eps;
+  std::transform(excluded_ep.begin(), excluded_ep.end(),
+                 std::inserter(excluded_eps, excluded_eps.end()), [](std::string_view ep) {
+                   return std::string(ep);
+                 });
+  // TensorRT 8.5 supports operators up to Opset 17. Temporarily exclude TensorRT EP due to accuracy issue.
+  excluded_eps.insert(kTensorrtExecutionProvider);
+  // Test is flaky on kCudaNHWCExecutionProvider
+  excluded_eps.insert(kCudaNHWCExecutionProvider);
+
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded_eps);
 }
 
 TEST(ResizeOpTest, Antialias_Bilinear_No_ExcludeOutside) {
   if (DefaultDmlExecutionProvider().get() != nullptr) {
-    GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly different and doesn't match in all cases.";
+    GTEST_SKIP() << "Skipping because dml implementation of antialias "
+                 << "is slightly different and doesn't match in all cases.";
   }
   std::vector<float> X(16);
   std::iota(X.begin(), X.end(), 1.f);
@@ -1939,7 +1980,8 @@ TEST(ResizeOpTest, Antialias_Bilinear_dtype) {
     std::vector<int8_t> Y = {1, 3, 4,
                              6, 8, 9,
                              11, 13, 14};
-    TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 1, 4, 4}, X, {1, 1, 3, 3}, Y);
+    InlinedVector<std::string_view> excluded_eps = {kCudaExecutionProvider};
+    TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 1, 4, 4}, X, {1, 1, 3, 3}, Y, excluded_eps);
   }
   {
     std::vector<int32_t> X(16);
@@ -1982,17 +2024,21 @@ TEST(ResizeOpTest, Antialias_NhwcBilinear) {
                           33.5f, 73.5f, 113.5f,
                           35.074074f, 75.07407f, 115.07407f,
                           36.590908f, 76.59091f, 116.59091f};
-  TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 5, 8, 3}, X, {1, 4, 5, 3}, Y);
+
+  // Nhwc is not supported by the CUDA Resize implementation
+  InlinedVector<std::string_view> excluded_eps = {kCudaExecutionProvider, kRocmExecutionProvider};
+  TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 5, 8, 3}, X, {1, 4, 5, 3}, Y, excluded_eps);
 }
 
 TEST(ResizeOpTest, Antialias_NhwcBilinear_dtype) {
+  InlinedVector<std::string_view> excluded_eps = {kCudaExecutionProvider, kRocmExecutionProvider};
   {
     std::vector<uint8_t> X(16);
     std::iota(X.begin(), X.end(), uint8_t(0));
     std::vector<uint8_t> Y = {1, 3, 4,
                               6, 8, 9,
                               11, 13, 14};
-    TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y);
+    TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y, excluded_eps);
   }
   {
     std::vector<int8_t> X(16);
@@ -2000,7 +2046,7 @@ TEST(ResizeOpTest, Antialias_NhwcBilinear_dtype) {
     std::vector<int8_t> Y = {1, 3, 4,
                              6, 8, 9,
                              11, 13, 14};
-    TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y);
+    TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y, excluded_eps);
   }
   {
     std::vector<int32_t> X(16);
@@ -2008,13 +2054,14 @@ TEST(ResizeOpTest, Antialias_NhwcBilinear_dtype) {
     std::vector<int32_t> Y = {1, 3, 4,
                               6, 8, 9,
                               11, 13, 14};
-    TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y);
+    TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y, excluded_eps);
   }
 }
 
 TEST(ResizeOpTest, Antialias_Trilinear_No_ExcludeOutside) {
   if (DefaultDmlExecutionProvider().get() != nullptr) {
-    GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly different and doesn't match in all cases.";
+    GTEST_SKIP() << "Skipping because dml implementation of "
+                 << "antialias is slightly different and doesn't match in all cases.";
   }
   std::vector<float> X(16 * 4);
   std::iota(X.begin(), X.end(), 0.f);
@@ -2038,13 +2085,17 @@ TEST(ResizeOpTest, Antialias_Trilinear_ExcludeOutside) {
 
 TEST(ResizeOpTest, Antialias_Trilinear_Scale_Is_11s_and_1s1) {
   if (DefaultDmlExecutionProvider().get() != nullptr) {
-    GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly different and doesn't match in all cases.";
+    GTEST_SKIP() << "Skipping because dml implementation of antialias"
+                 << " is slightly different and doesn't match in all cases.";
   }
+
+  InlinedVector<std::string_view> excluded_eps = {kCudaExecutionProvider};
   std::vector<float> X(16 * 4 * 4);
   std::iota(X.begin(), X.end(), 0.f);
   {
     std::vector<float> Y = X;
-    TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 4, 4}, Y);
+    TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 4, 4}, Y,
+                     excluded_eps);
   }
   {
     std::vector<float> Y = {0.625f, 2.375f, 4.625f, 6.375f, 8.625f, 10.375f, 12.625f,
@@ -2066,7 +2117,8 @@ TEST(ResizeOpTest, Antialias_Trilinear_Scale_Is_11s_and_1s1) {
                             224.625f, 226.375f, 228.625f, 230.375f, 232.625f, 234.375f, 236.625f,
                             238.375f, 240.625f, 242.375f, 244.625f, 246.375f, 248.625f, 250.375f,
                             252.625f, 254.375f};
-    TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "0"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 4, 2}, Y);
+    TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "0"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 4, 2}, Y,
+                     excluded_eps);
   }
   {
     std::vector<float> Y = {2.5f, 3.5f, 4.5f, 5.5f, 9.5f, 10.5f, 11.5f, 12.5f, 18.5f,
@@ -2084,7 +2136,8 @@ TEST(ResizeOpTest, Antialias_Trilinear_Scale_Is_11s_and_1s1) {
                             217.5f, 218.5f, 219.5f, 220.5f, 226.5f, 227.5f, 228.5f, 229.5f, 233.5f,
                             234.5f, 235.5f, 236.5f, 242.5f, 243.5f, 244.5f, 245.5f, 249.5f, 250.5f,
                             251.5f, 252.5f};
-    TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "0"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 2, 4}, Y);
+    TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "0"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 2, 4}, Y,
+                     excluded_eps);
   }
 }
 
@@ -2124,12 +2177,15 @@ TEST(ResizeOpTest, Antialias_NHWCBicubic_ExcludeOutside) {
       19.576872f, 43.57687f, 21.126253f, 45.126255f, 22.606192f,
       46.606194f, 19.878183f, 43.87818f, 21.358122f, 45.35812f,
       22.907503f, 46.907505f, 24.387442f, 48.387444f};
-  TestAntialiasing({{"mode", "cubic"}, {"exclude_outside", "0"}}, {1, 4, 6, 2}, X, {1, 8, 4, 2}, Y);
+
+  InlinedVector<std::string_view> excluded_eps = {kCudaExecutionProvider, kRocmExecutionProvider};
+  TestAntialiasing({{"mode", "cubic"}, {"exclude_outside", "0"}}, {1, 4, 6, 2}, X, {1, 8, 4, 2}, Y, excluded_eps);
 }
 
 TEST(ResizeOpTest, Antialias_Linear_AlignCorners) {
   if (DefaultDmlExecutionProvider().get() != nullptr) {
-    GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly different and doesn't match in all cases.";
+    GTEST_SKIP() << "Skipping because dml implementation of antialias"
+                 << "is slightly different and doesn't match in all cases.";
   }
   std::vector<float> X(256);
   std::iota(X.begin(), X.end(), 0.0f);
@@ -2145,9 +2201,40 @@ TEST(ResizeOpTest, Antialias_Linear_AlignCorners) {
       187.08333f, 195.91667f, 198.41667f, 205.91667f, 208.41667f,
       217.25f, 219.75f, 227.25f, 229.75f, 238.58333f,
       241.08333f, 248.58333f, 251.08333f};
+  InlinedVector<std::string_view> excluded_eps = {kCudaExecutionProvider, kRocmExecutionProvider};
   TestAntialiasing(
       {{"mode", "linear"}, {"exclude_outside", "0"}, {"coordinate_transformation_mode", "align_corners"}},
-      {4, 1, 4, 4, 4}, X, {4, 1, 3, 2, 2}, Y);
+      {4, 1, 4, 4, 4}, X, {4, 1, 3, 2, 2}, Y, excluded_eps);
+}
+
+TEST(ResizeOpTest, Antialias_Linear_AlignCorners_3D) {
+  if (DefaultDmlExecutionProvider().get() != nullptr) {
+    GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly "
+                 << "different and doesn't match in all cases.";
+  }
+  std::vector<float> X(256);
+  std::iota(X.begin(), X.end(), 0.0f);
+  std::vector<float> Y{
+      1.25f, 3.75f, 11.25f, 13.75f,
+      17.25f, 19.75f, 27.25f, 29.75f,
+      33.25f, 35.75f, 43.25f, 45.75f,
+      49.25f, 51.75f, 59.25f, 61.75f,
+      65.25f, 67.75f, 75.25f, 77.75f,
+      81.25f, 83.75f, 91.25f, 93.75f,
+      97.25f, 99.75f, 107.25f, 109.75f,
+      113.25f, 115.75f, 123.25f, 125.75f,
+      129.25f, 131.75f, 139.25f, 141.75f,
+      145.25f, 147.75f, 155.25f, 157.75f,
+      161.25f, 163.75f, 171.25f, 173.75f,
+      177.25f, 179.75f, 187.25f, 189.75f,
+      193.25f, 195.75f, 203.25f, 205.75f,
+      209.25f, 211.75f, 219.25f, 221.75f,
+      225.25f, 227.75f, 235.25f, 237.75f,
+      241.25f, 243.75f, 251.25f, 253.75f};
+
+  TestAntialiasing(
+      {{"mode", "linear"}, {"exclude_outside", "0"}, {"coordinate_transformation_mode", "align_corners"}},
+      {16, 4, 4}, X, {16, 2, 2}, Y);
 }
 
 TEST(ResizeOpTest, Antialias_Bicubic_ExcludeOutside) {
@@ -2166,19 +2253,23 @@ TEST(ResizeOpTest, Antialias_Bicubic_Dtype) {
     std::vector<uint8_t> X(36);
     std::iota(X.begin(), X.end(), uint8_t(0));
     std::vector<uint8_t> Y = {4, 6, 7, 16, 18, 19, 28, 30, 31};
-    TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6}, X, {1, 1, 3, 3}, Y);
+    TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6},
+                     X, {1, 1, 3, 3}, Y);
   }
   {
     std::vector<int8_t> X(36);
     std::iota(X.begin(), X.end(), int8_t(0));
     std::vector<int8_t> Y = {4, 6, 7, 16, 18, 19, 28, 30, 31};
-    TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6}, X, {1, 1, 3, 3}, Y);
+    InlinedVector<std::string_view> excluded_eps = {kCudaExecutionProvider};
+    TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6},
+                     X, {1, 1, 3, 3}, Y, excluded_eps);
   }
   {
     std::vector<int32_t> X(36);
     std::iota(X.begin(), X.end(), 0);
     std::vector<int32_t> Y = {4, 6, 7, 16, 18, 19, 28, 30, 31};
-    TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6}, X, {1, 1, 3, 3}, Y);
+    TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6},
+                     X, {1, 1, 3, 3}, Y);
   }
 }
 
@@ -2189,8 +2280,10 @@ TEST(ResizeOpTest, Antialias_Axes_and_Scale) {
   std::vector<float> Y = {6.3f, 7.5f, 8.7f, 11.1f, 12.3f, 13.5f, 15.9f, 17.1f, 18.3f, 25.5f, 26.7f,
                           27.9f, 30.3f, 31.5f, 32.7f, 35.1f, 36.3f, 37.5f, 44.7f, 45.9f, 47.1f, 49.5f,
                           50.7f, 51.9f, 54.3f, 55.5f, 56.7f};
-  TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}}, {1, 1, 4, 4, 4}, X,
-                   std::vector<float>{3 / 4.0f, 3 / 4.0f, 3 / 4.0f}, Y);
+  TestAntialiasing(
+      {{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}},
+      {1, 1, 4, 4, 4}, X,
+      std::vector<float>{3 / 4.0f, 3 / 4.0f, 3 / 4.0f}, Y);
 }
 
 TEST(ResizeOpTest, Antialias_Axes_and_Size) {
@@ -2199,8 +2292,10 @@ TEST(ResizeOpTest, Antialias_Axes_and_Size) {
   std::vector<float> Y = {6.3f, 7.5f, 8.7f, 11.1f, 12.3f, 13.5f, 15.9f, 17.1f, 18.3f, 25.5f, 26.7f,
                           27.9f, 30.3f, 31.5f, 32.7f, 35.1f, 36.3f, 37.5f, 44.7f, 45.9f, 47.1f, 49.5f,
                           50.7f, 51.9f, 54.3f, 55.5f, 56.7f};
-  TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}}, {1, 1, 4, 4, 4}, X,
-                   {3, 3, 3}, Y);
+  TestAntialiasing(
+      {{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}},
+      {1, 1, 4, 4, 4}, X,
+      {3, 3, 3}, Y);
 }
 
 TEST(ResizeOpTest, Antialias_Axes_and_PolicyNoLarger) {
@@ -2209,9 +2304,13 @@ TEST(ResizeOpTest, Antialias_Axes_and_PolicyNoLarger) {
   std::vector<float> Y = {6.3f, 7.5f, 8.7f, 11.1f, 12.3f, 13.5f, 15.9f, 17.1f, 18.3f, 25.5f, 26.7f,
                           27.9f, 30.3f, 31.5f, 32.7f, 35.1f, 36.3f, 37.5f, 44.7f, 45.9f, 47.1f, 49.5f,
                           50.7f, 51.9f, 54.3f, 55.5f, 56.7f};
-  TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}, {"policy", "not_larger"}},
-                   {1, 1, 4, 4, 4}, X,
-                   {3, 4, 5}, Y);
+  // clang-format off
+  TestAntialiasing(
+      {{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"},
+       {"policy", "not_larger"}},
+      {1, 1, 4, 4, 4}, X,
+      {3, 4, 5}, Y);
+  // clang-format on
 }
 
 TEST(ResizeOpTest, Antialias_Axes_and_PolicyNoSmaller) {
@@ -2220,9 +2319,13 @@ TEST(ResizeOpTest, Antialias_Axes_and_PolicyNoSmaller) {
   std::vector<float> Y = {6.3f, 7.5f, 8.7f, 11.1f, 12.3f, 13.5f, 15.9f, 17.1f, 18.3f, 25.5f, 26.7f,
                           27.9f, 30.3f, 31.5f, 32.7f, 35.1f, 36.3f, 37.5f, 44.7f, 45.9f, 47.1f, 49.5f,
                           50.7f, 51.9f, 54.3f, 55.5f, 56.7f};
-  TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}, {"policy", "not_smaller"}},
-                   {1, 1, 4, 4, 4}, X,
-                   {1, 2, 3}, Y);
+  // clang-format off
+  TestAntialiasing(
+      {{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"},
+       {"policy", "not_smaller"}},
+      {1, 1, 4, 4, 4}, X,
+      {1, 2, 3}, Y);
+  // clang-format on
 }
 
 TEST(ResizeOpTest, Antialias_Use_Extrapolation) {
diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc
index 30e27bb15fa57..b1dfec7951338 100644
--- a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc
@@ -268,7 +268,7 @@ static void scatter_invalid_index(const char* op_name, int op_version) {
   test.AddOutput<float>("y", {4, 2, 1}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 5.0f, 0.0f});
   test.Run(OpTester::ExpectResult::kExpectFailure,
            "indices element out of data bounds, idx=4 must be within the inclusive range [-4,3]",
-           {kCudaExecutionProvider, kTensorrtExecutionProvider});
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider});
 }
 
 TEST(Scatter, InvalidIndex) {
@@ -291,9 +291,10 @@ static void scatter_bool_with_axis_tests(const char* op_name, int op_version) {
   test.AddOutput<bool>("y", {1, 5}, {false, true, false, false, false});
 #if defined(OPENVINO_CONFIG_GPU_FP32) || defined(OPENVINO_CONFIG_GPU_FP16)
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-           {kOpenVINOExecutionProvider});  // OpenVINO: Disabled due to failure for GPU
+           {kCudaNHWCExecutionProvider, kOpenVINOExecutionProvider});  // OpenVINO: Disabled due to failure for GPU
 #else
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kCudaNHWCExecutionProvider});
 #endif
 }
 
diff --git a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc
index 63b92cfc187bd..5222380d9ca56 100644
--- a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc
@@ -108,6 +108,53 @@ TEST(TensorOpTest, SpaceToDepthTest_2) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider});
 }
 
+TEST(TensorOpTest, SpaceToDepthTest_3) {
+  // Test swizzling with H_output > 1
+  OpTester test("SpaceToDepth");
+  constexpr int64_t blocksize = 2;
+  test.AddAttribute("blocksize", blocksize);
+  constexpr int64_t N = 1, C = 2, H = 4, W = 8;
+
+  const std::vector<float> X = {
+      0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f,
+      1.0f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f,
+
+      2.0f, 2.1f, 2.2f, 2.3f, 2.4f, 2.5f, 2.6f, 2.7f,
+      3.0f, 3.1f, 3.2f, 3.3f, 3.4f, 3.5f, 3.6f, 3.7f,
+
+      4.0f, 4.1f, 4.2f, 4.3f, 4.4f, 4.5f, 4.6f, 4.7f,
+      5.0f, 5.1f, 5.2f, 5.3f, 5.4f, 5.5f, 5.6f, 5.7f,
+      6.0f, 6.1f, 6.2f, 6.3f, 6.4f, 6.5f, 6.6f, 6.7f,
+      7.0f, 7.1f, 7.2f, 7.3f, 7.4f, 7.5f, 7.6f, 7.7f};
+
+  test.AddInput<float>("input", {N, C, H, W}, X);
+
+  const std::vector<float> result = {
+      0.0f, 0.2f, 0.4f, 0.6f,
+      2.0f, 2.2f, 2.4f, 2.6f,
+      4.0f, 4.2f, 4.4f, 4.6f,
+      6.0f, 6.2f, 6.4f, 6.6f,
+
+      0.1f, 0.3f, 0.5f, 0.7f,
+      2.1f, 2.3f, 2.5f, 2.7f,
+      4.1f, 4.3f, 4.5f, 4.7f,
+      6.1f, 6.3f, 6.5f, 6.7f,
+
+      1.0f, 1.2f, 1.4f, 1.6f,
+      3.0f, 3.2f, 3.4f, 3.6f,
+      5.0f, 5.2f, 5.4f, 5.6f,
+      7.0f, 7.2f, 7.4f, 7.6f,
+
+      1.1f, 1.3f, 1.5f, 1.7f,
+      3.1f, 3.3f, 3.5f, 3.7f,
+      5.1f, 5.3f, 5.5f, 5.7f,
+      7.1f, 7.3f, 7.5f, 7.7f};
+
+  test.AddOutput<float>("output", {N, C * blocksize * blocksize, H / blocksize, W / blocksize}, result);
+
+  test.Run();
+}
+
 TEST(TensorOpTest, DepthToSpaceTest_1) {
   OpTester test("DepthToSpace", 7);  // create an opset 7 model
   constexpr int64_t blocksize = 2;
diff --git a/onnxruntime/test/providers/cpu/tensor/upsample_op_test.cc b/onnxruntime/test/providers/cpu/tensor/upsample_op_test.cc
index 72cb84d50f078..3ac8053aef95e 100644
--- a/onnxruntime/test/providers/cpu/tensor/upsample_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/upsample_op_test.cc
@@ -4,6 +4,7 @@
 #include "gtest/gtest.h"
 #include "test/providers/provider_test_utils.h"
 #include "test/util/include/default_providers.h"
+#include "test/common/trt_op_test_utils.h"
 
 namespace onnxruntime {
 namespace test {
@@ -692,7 +693,7 @@ TEST(UpsampleOpTest, NhwcUpsampleOp4D1CBilinearTest) {
   // TensorRT: results mismatch
   // ROCm: results mismatch
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-           {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider});
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider});
 }
 
 TEST(UpsampleOpTest, NhwcUpsampleOp4DBilinearTest) {
@@ -766,7 +767,7 @@ TEST(UpsampleOpTest, NhwcUpsampleOp4DBilinearTest) {
   // TensorRT: results mismatch
   // ROCm: results mismatch
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-           {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider});
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider});
 }
 
 TEST(UpsampleOpTest, UpsampleOp2DBilinearTest) {
@@ -886,7 +887,7 @@ TEST(UpsampleOpTest, NhwcUpsampleOp4DBilinearTest_int32) {
   // TensorRT: results mismatch
   // ROCm: results mismatch
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-           {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider});
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider});
 }
 
 TEST(UpsampleOpTest, UpsampleOpNearestTest_1D) {
@@ -939,7 +940,9 @@ TEST(UpsampleOpTest, UpsampleOpNearest2XTest_opset9) {
       7, 7, 9, 9};
 
   test.AddOutput<int32_t>("Y", {N, C, (int64_t)(H * scales[2]), (int64_t)(W * scales[3])}, Y);
-  test.Run();
+
+  // TRT: segmentation fault in A100
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
 }
 
 TEST(UpsampleOpTest, NhwcUpsampleOpNearest2XTest_opset9) {
diff --git a/onnxruntime/test/providers/cuda/nhwc/conv_test.cc b/onnxruntime/test/providers/cuda/nhwc/conv_test.cc
index 13d4546d669e3..b6a760f7041ad 100644
--- a/onnxruntime/test/providers/cuda/nhwc/conv_test.cc
+++ b/onnxruntime/test/providers/cuda/nhwc/conv_test.cc
@@ -9,8 +9,8 @@ namespace test {
 
 template <typename T>
 struct ConvOp {
-  const std::vector<int64_t> input_dims;
-  const std::vector<int64_t> kernel_shape;
+  std::vector<int64_t> input_dims;
+  std::vector<int64_t> kernel_shape;
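+  // Non-const members let tests default-construct the op description and assign
+  // fields individually instead of using designated initializers.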
   int64_t channels;
   int64_t group = 1;
   bool bias = false;
@@ -52,20 +52,31 @@ struct ConvOp {
 };
 
 TYPED_TEST(CudaNhwcTypedTest, ConvNhwcBias) {
-  auto op = ConvOp<TypeParam>{.input_dims = {1, 16, 64, 64}, .kernel_shape = {3, 3}, .channels = 16, .bias = true};
+  auto op = ConvOp<TypeParam>{};
+  op.input_dims = {1, 16, 64, 64};
+  op.kernel_shape = {3, 3};
+  op.channels = 16;
+  op.bias = true;
 
   MAKE_PROVIDERS_EPS_TYPE(TypeParam)
 }
 
 TYPED_TEST(CudaNhwcTypedTest, ConvNhwcGroupNoBias) {
-  auto op = ConvOp<TypeParam>{.input_dims = {1, 16, 64, 64}, .kernel_shape = {3, 3}, .channels = 16, .group = 4};
+  auto op = ConvOp<TypeParam>{};
+  op.input_dims = {1, 16, 64, 64};
+  op.kernel_shape = {3, 3};
+  op.channels = 16;
+  op.group = 4;
 
   MAKE_PROVIDERS_EPS_TYPE(TypeParam)
 }
 
 TYPED_TEST(CudaNhwcTypedTest, ConvNhwcPadding) {
-  auto op =
-      ConvOp<TypeParam>{.input_dims = {2, 4, 64, 64}, .kernel_shape = {3, 3}, .channels = 4, .padding = {4, 4, 4, 4}};
+  auto op = ConvOp<TypeParam>{};
+  op.input_dims = {2, 4, 64, 64};
+  op.kernel_shape = {3, 3};
+  op.channels = 4;
+  op.padding = {4, 4, 4, 4};
 
   MAKE_PROVIDERS_EPS_TYPE(TypeParam)
 }
diff --git a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc
index 6514feadf0ff7..786b2cb4cedc4 100644
--- a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc
+++ b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc
@@ -9,8 +9,8 @@ namespace test {
 
 template <typename T>
 struct ConvTransposeOp {
-  const std::vector<int64_t> input_dims;
-  const std::vector<int64_t> kernel_shape;
+  std::vector<int64_t> input_dims;
+  std::vector<int64_t> kernel_shape;
   int64_t channels;
   int64_t group = 1;
   bool bias = false;
@@ -60,15 +60,21 @@ struct ConvTransposeOp {
 };
 
 TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcGroupNoBias) {
-  auto op =
-      ConvTransposeOp<TypeParam>{.input_dims = {8, 8, 32, 32}, .kernel_shape = {3, 3}, .channels = 16, .group = 4};
+  auto op = ConvTransposeOp<TypeParam>{};
+  op.input_dims = {8, 8, 32, 32};
+  op.kernel_shape = {3, 3};
+  op.channels = 16;
+  op.group = 4;
 
   MAKE_PROVIDERS_EPS_TYPE(TypeParam)
 }
 
 TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcBias) {
-  auto op =
-      ConvTransposeOp<TypeParam>{.input_dims = {1, 8, 80, 80}, .kernel_shape = {5, 5}, .channels = 16, .bias = true};
+  auto op = ConvTransposeOp<TypeParam>{};
+  op.input_dims = {1, 8, 80, 80};
+  op.kernel_shape = {5, 5};
+  op.channels = 16;
+  op.bias = true;
 
   if (HasCudaEnvironment(800)) {
     MAKE_PROVIDERS_EPS(1e-2)
@@ -78,21 +84,23 @@ TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcBias) {
 }
 
 TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcPad) {
-  auto op = ConvTransposeOp<TypeParam>{.input_dims = {1, 16, 8, 8},
-                                       .kernel_shape = {3, 3},
-                                       .channels = 32,
-                                       .padding = {2, 2, 2, 2},
-                                       .output_padding = {}};
+  auto op = ConvTransposeOp<TypeParam>{};
+  op.input_dims = {1, 16, 8, 8};
+  op.kernel_shape = {3, 3};
+  op.channels = 32;
+  op.padding = {2, 2, 2, 2};
+  op.output_padding = {};
 
   MAKE_PROVIDERS_EPS_TYPE(TypeParam)
 }
 
 TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcOutPad) {
-  auto op = ConvTransposeOp<TypeParam>{.input_dims = {1, 32, 8, 8},
-                                       .kernel_shape = {3, 3},
-                                       .channels = 32,
-                                       .strides = {2, 2},
-                                       .output_padding = {1, 1, 1, 1}};
+  auto op = ConvTransposeOp<TypeParam>{};
+  op.input_dims = {1, 32, 8, 8};
+  op.kernel_shape = {3, 3};
+  op.channels = 32;
+  op.strides = {2, 2};
+  op.output_padding = {1, 1, 1, 1};
 
   MAKE_PROVIDERS_EPS_TYPE(TypeParam)
 }
diff --git a/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h b/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h
index 2c942bb790096..82b6a286409cd 100644
--- a/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h
+++ b/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h
@@ -16,11 +16,13 @@
 
 #define MAKE_PROVIDERS_EPS(eps)                                           \
   std::vector<std::shared_ptr<IExecutionProvider>> execution_providers;   \
-  OrtCUDAProviderOptionsV2 nhwc = {.prefer_nhwc = true};                  \
+  OrtCUDAProviderOptionsV2 nhwc{};                                        \
+  nhwc.prefer_nhwc = true;                                                \
   execution_providers.push_back(CudaExecutionProviderWithOptions(&nhwc)); \
                                                                           \
   double error_tolerance = eps;                                           \
-  OrtCUDAProviderOptionsV2 nchw = {.prefer_nhwc = false};                 \
+  OrtCUDAProviderOptionsV2 nchw{};                                        \
+  nchw.prefer_nhwc = false;                                               \
   auto source_ep = CudaExecutionProviderWithOptions(&nchw);               \
   auto test = op.get_test();                                              \
   test->CompareEPs(std::move(source_ep), execution_providers, error_tolerance);
diff --git a/onnxruntime/test/providers/cuda/nhwc/norm_test.cc b/onnxruntime/test/providers/cuda/nhwc/norm_test.cc
index 52da8ba557c2d..40f69e3bd5b4f 100644
--- a/onnxruntime/test/providers/cuda/nhwc/norm_test.cc
+++ b/onnxruntime/test/providers/cuda/nhwc/norm_test.cc
@@ -9,7 +9,7 @@ namespace test {
 
 template <typename T>
 struct BatchNormOp {
-  const std::vector<int64_t> input_dims;
+  std::vector<int64_t> input_dims;
 
   std::unique_ptr<CompareOpTester> get_test() {
     // create rand inputs
@@ -40,9 +40,8 @@ struct BatchNormOp {
 };
 
 TYPED_TEST(CudaNhwcTypedTest, BatchNormNhwc) {
-  auto op = BatchNormOp<TypeParam>{
-      .input_dims = {4, 16, 64, 64},
-  };
+  auto op = BatchNormOp<TypeParam>{};
+  op.input_dims = {4, 16, 64, 64};
 
   MAKE_PROVIDERS()
 }
diff --git a/onnxruntime/test/providers/cuda/nhwc/pool_test.cc b/onnxruntime/test/providers/cuda/nhwc/pool_test.cc
index e0d59901da80c..426170b9588f1 100644
--- a/onnxruntime/test/providers/cuda/nhwc/pool_test.cc
+++ b/onnxruntime/test/providers/cuda/nhwc/pool_test.cc
@@ -9,9 +9,9 @@ namespace test {
 
 template <typename T>
 struct PoolOp {
-  const std::string pooling_type;
-  const std::vector<int64_t> input_dims;
-  const std::vector<int64_t> kernel_shape;
+  std::string pooling_type;
+  std::vector<int64_t> input_dims;
+  std::vector<int64_t> kernel_shape;
   int64_t channels;
   int64_t group = 1;
   std::vector<int64_t> strides = {1, 1};
@@ -41,22 +41,21 @@ struct PoolOp {
 };
 
 TYPED_TEST(CudaNhwcTypedTest, AveragePoolNhwc) {
-  auto op = PoolOp<TypeParam>{
-      .pooling_type = "AveragePool",
-      .input_dims = {1, 16, 64, 64},
-      .kernel_shape = {3, 3},
-      .channels = 16,
-  };
+  auto op = PoolOp<TypeParam>{};
+  op.pooling_type = "AveragePool";
+  op.input_dims = {1, 16, 64, 64};
+  op.kernel_shape = {3, 3};
+  op.channels = 16;
+
   MAKE_PROVIDERS()
 }
 
 TYPED_TEST(CudaNhwcTypedTest, MaxPoolNhwc) {
-  auto op = PoolOp<TypeParam>{
-      .pooling_type = "MaxPool",
-      .input_dims = {1, 16, 64, 64},
-      .kernel_shape = {3, 3},
-      .channels = 16,
-  };
+  auto op = PoolOp<TypeParam>{};
+  op.pooling_type = "MaxPool";
+  op.input_dims = {1, 16, 64, 64};
+  op.kernel_shape = {3, 3};
+  op.channels = 16;
   MAKE_PROVIDERS()
 }
 
@@ -72,21 +71,24 @@ TYPED_TEST(CudaNhwcTypedTest, GlobalMaxPoolNhwc) {
   test->AddOutput<TypeParam>("Y", output_dims, output_data);
 
   std::vector<std::shared_ptr<IExecutionProvider>> execution_providers;
-  OrtCUDAProviderOptionsV2 nhwc = {.prefer_nhwc = true};
+  OrtCUDAProviderOptionsV2 nhwc{};
+  nhwc.prefer_nhwc = true;
   execution_providers.push_back(CudaExecutionProviderWithOptions(&nhwc));
 
   double error_tolerance = 1e-3;
-  OrtCUDAProviderOptionsV2 nchw = {.prefer_nhwc = false};
+  OrtCUDAProviderOptionsV2 nchw{};
+  nchw.prefer_nhwc = false;
   auto source_ep = CudaExecutionProviderWithOptions(&nchw);
   test->CompareEPs(std::move(source_ep), execution_providers, error_tolerance);
 }
 
 TYPED_TEST(CudaNhwcTypedTest, AveragePoolNhwcPad) {
-  auto op = PoolOp<TypeParam>{.pooling_type = "AveragePool",
-                              .input_dims = {1, 16, 64, 64},
-                              .kernel_shape = {3, 3},
-                              .channels = 16,
-                              .padding = {2, 2, 2, 2}};
+  auto op = PoolOp<TypeParam>{};
+  op.pooling_type = "AveragePool";
+  op.input_dims = {1, 16, 64, 64};
+  op.kernel_shape = {3, 3};
+  op.channels = 16;
+  op.padding = {2, 2, 2, 2};
 
   MAKE_PROVIDERS()
 }
diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h
new file mode 100644
index 0000000000000..bbe370675fc48
--- /dev/null
+++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h
@@ -0,0 +1,188 @@
+/**
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License.
+ *
+ * Module Name:
+ *    blkq4_fp16_gemm_sm80.h
+ *
+ * Abstract:
+ *   Bridge between the gtest code and the GEMM kernel implementation.
+ *   The GEMM kernel requires CUTLASS header files, which cause strange
+ *   compilation errors when combined with the RE2 header files required
+ *   by gtest.
+ */
+
+#pragma once
+
+#include <random>
+
+#include "core/util/matrix_layout.h"
+#include "core/common/common.h"
+#include "core/mickey/blk_q4/f16_prepack_sm80.h"
+#include "test/cuda_host/blkq4_fp16_quant_sm80.h"
+
+namespace onnxruntime {
+namespace cuda {
+namespace test {
+
+Status sm80_supported();
+
+/**
+ * @brief Generate a set of quantized weights, scales and offsets
+ *        and dequantized weights for testing quantization and
+ *        dequantization. All outputs are column major layout.
+ *
+ * @tparam ElementT The type of the dequantized weights.
+ * @tparam block_size The block size of the quantization.
+ * @tparam col_blocking Whether to use column blocking (all elements of
+ *                      a block comes from a single column) or row blocking
+ * @tparam has_offsets Whether to generate offsets.
+ *
+ * @param[in]  rows The number of rows of the weight matrix.
+ * @param[in]  columns The number of columns of the weight matrix.
+ * @param[out] dequants The dequantized weights, column major layout.
+ * @param[out] q_weights The quantized weights, column major layout.
+ * @param[out] q_scales The scales, column major layout.
+ * @param[out] q_zp The zero points, column major layout.
+ */
+template <typename ElementT, int block_size, bool col_blocking, bool has_offsets>
+inline void blkq4_weights_gen(
+    int rows, int columns,
+    std::vector<ElementT>& dequants,
+    std::vector<uint8_t>& q_weights,
+    std::vector<ElementT>& q_scales,
+    std::vector<uint8_t>& q_zp) {
+  using Base = onnxruntime::cuda::BlockwiseQuantization<
+      ElementT,
+      block_size,
+      4,
+      col_blocking>;
+
+  using QuantBlocking = typename Base::QuantBlocking;
+  using ElementW = typename Base::ElementW;
+  using LayoutWPack = typename Base::LayoutWPack;
+  using ElementQOffset = typename Base::ElementQOffset;
+
+  static_assert(std::is_same<ElementW, uint8_t>::value);
+  static_assert(std::is_same<ElementQOffset, uint8_t>::value);
+  static_assert(std::is_same<LayoutWPack, ColumnMajorLayout>::value);
+
+  unsigned int seed = 28571;  // Fixed seed so the generated test data is reproducible
+  std::seed_seq seq{seed};
+  std::mt19937 gen(seq);
+  std::uniform_int_distribution<uint32_t> dis(0, 8192);
+
+  const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns);
+  const auto meta_shape = Base::get_quant_meta_shape(rows, columns);
+  const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]);
+
+  //
+  // For testing quantization and dequantization, it is not straightforward
+  // to avoid flaky tests due to rounding errors. The way we
+  // try to achieve this is to:
+  // 1. Generate a set of quantized weights, scales and offsets
+  // 2. Dequantize the weights
+  // 3. Quantize the dequantized weights
+  // 4. Compare the dequantized-and-then-quantized weights with
+  //    the original quantized weights
+  //
+  // Random filling of the initial values is key to getting this right.
+  // For weights, we must ensure each block gets a full range of
+  // values, i.e. must contain 0 and 15. And for scales, they must
+  // all be positive.
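+  //
+  // (Steps 2-4 are exercised by callers such as testPrepack in
+  // blkq4_fp16_gemm_sm80_test.cc, which re-quantizes the dequantized weights
+  // with MlasQuantizeBlockwise and compares the result with these outputs.)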
+  //
+
+  q_weights.resize(q_weight_shape.product());
+  MatrixRef<ElementW, ColumnMajorLayout, true> tensor_q_weight(
+      q_weights, make_Position(rows / 2, columns));
+  int v = 7;
+  for (int c = 0; c < tensor_q_weight.shape()[1]; c++) {
+    for (int r = 0; r < tensor_q_weight.shape()[0]; ++r) {
+      uint8_t v0 = static_cast<uint8_t>(v);
+      v = (v + 5) % 16;
+      if (v == 11 || v == 7 || v == 3) {
+        // making the cycle 13 instead of 16, avoiding same values in a row
+        v = (v + 5) % 16;
+      }
+      uint8_t v1 = 0;
+      if (r + 1 < rows) {
+        v1 = static_cast<uint8_t>(v);
+        v = (v + 5) % 16;
+        if (v == 11 || v == 7 || v == 3) {
+          // making the cycle 13 instead of 16, avoiding same values in a row
+          v = (v + 5) % 16;
+        }
+      }
+
+      tensor_q_weight.at(r, c) = ElementW((v1 << 4) | v0);
+    }
+  }
+
+  q_scales.resize(meta_shape.product());
+  for (size_t i = 0; i < q_scales.size(); i++) {
+    uint32_t v = dis(gen);
+    uint32_t m = (v % 63) + 1;
+    uint32_t e = (v >> 6) % 4;
+    q_scales[i] = ElementT(m / static_cast<float>(1 << (2 + e)));
+  }
+  MatrixRef<ElementT, ColumnMajorLayout, true> tensor_scale(
+      q_scales, meta_shape);
+
+  MatrixRef<ElementQOffset, ColumnMajorLayout, true> tensor_offset;
+  if constexpr (has_offsets) {
+    q_zp.resize(zp_shape.product());
+    tensor_offset = MatrixRef<ElementQOffset, ColumnMajorLayout, true>(
+        q_zp, zp_shape);
+    for (int c = 0; c < zp_shape[1]; c++) {
+      for (int r = 0; r < zp_shape[0]; ++r) {
+        uint8_t v0 = dis(gen) % 16;
+        uint8_t v1 = 8;
+        if (r * 2 + 1 < meta_shape[0]) {
+          v1 = dis(gen) % 16;
+        }
+        tensor_offset.at(r, c) = static_cast<uint8_t>(v0 | (v1 << 4));
+      }
+    }
+  }
+
+  dequants.resize(rows * columns);
+  MatrixRef<ElementT, ColumnMajorLayout> tensor_dequant(dequants, make_Position(rows, columns));
+
+  // Dequantize weights and save into matrix B
+  for (int col = 0; col < tensor_dequant.shape()[1]; ++col) {
+    for (int row = 0; row < tensor_dequant.shape()[0]; ++row) {
+      auto weight_cord = make_Position(row / 2, col);
+      auto scale_cord = make_Position(row / QuantBlocking::kRow, col / QuantBlocking::kColumn);
+      uint8_t offset = 8;
+      if constexpr (has_offsets) {
+        if (scale_cord[0] % 2 == 0) {
+          offset = tensor_offset.at(scale_cord[0] / 2, scale_cord[1]) & 0x0f;
+        } else {
+          offset = tensor_offset.at(scale_cord[0] / 2, scale_cord[1]) >> 4;
+        }
+      }
+      int w = 0;
+      if (row % 2 == 0) {
+        w = int(tensor_q_weight.at(weight_cord) & 0x0f);
+      } else {
+        w = int(tensor_q_weight.at(weight_cord) >> 4);
+      }
+      float scale = float(tensor_scale.at(scale_cord));
+      float dequant = scale * float(w - offset);
+      tensor_dequant.at(row, col) = ElementT(dequant);
+      // Debug print to help in case of test failure:
+      // fprintf(stderr, "(%2d,%2d)= %2d, %2d, %f, %f\n", row, col, w, offset, scale, dequant);
+    }
+  }
+}
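+
+// A minimal usage sketch (illustration only; the element type and sizes below are
+// assumptions matching the calls in blkq4_fp16_gemm_sm80_test.cc):
+//
+//   std::vector<onnxruntime::MLFloat16> dequants, scales;
+//   std::vector<uint8_t> q_weights, zero_points;
+//   // 64x64 fp16 weights, 4b quantized, block_size 32, row-wise blocking, with offsets
+//   blkq4_weights_gen<onnxruntime::MLFloat16, 32, /*col_blocking=*/false, /*has_offsets=*/true>(
+//       64, 64, dequants, q_weights, scales, zero_points);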
+
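+// Implemented in blkq4_fp16_gemm_sm80_testcu.cu (which pulls in the CUTLASS headers);
+// runs the block-wise quantized 4b GEMM test kernel for the given problem size.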
+template <
+    int block_size,
+    bool column_wise_blocking,
+    bool small_m,
+    bool has_offsets>
+void run_blkq4_gemm(int m, int n, int k);
+
+}  // namespace test
+}  // namespace cuda
+}  // namespace onnxruntime
diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc
new file mode 100644
index 0000000000000..e687ae73e66f2
--- /dev/null
+++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc
@@ -0,0 +1,330 @@
+/**
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License.
+ *
+ * Module Name:
+ *    blkq4_fp16_gemm_sm80_test.cc
+ *
+ * Abstract:
+ *   Test code for block-wise quantized 4b GEMM kernels.
+ *   This part requires gtest header files, which do not play
+ *   well with CUTLASS headers.
+ */
+
+#include <random>
+
+#include "core/framework/float16.h"
+#include "core/mlas/inc/mlas_q4.h"
+
+#include "blkq4_fp16_gemm_sm80.h"
+
+#include "gtest/gtest.h"
+
+namespace onnxruntime {
+namespace test {
+
+template <bool col_blocking, bool has_offset = true>
+void testPrepack(int rows, int columns) {
+  using ElementT = MLFloat16;
+  constexpr int block_size = 32;
+  using Base = onnxruntime::cuda::BlockwiseQuantization<
+      ElementT,
+      block_size,
+      4,
+      col_blocking>;
+
+  using QuantBlocking = typename Base::QuantBlocking;
+  using ElementW = typename Base::ElementW;
+  using LayoutWPack = typename Base::LayoutWPack;
+  using ElementQOffset = typename Base::ElementQOffset;
+  using LayoutQmeta = typename Base::LayoutQmeta;
+
+  const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns);
+  const auto meta_shape = Base::get_quant_meta_shape(rows, columns);
+  const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]);
+
+  std::vector<ElementW> q_weights;
+  std::vector<ElementT> q_scales;
+  std::vector<ElementQOffset> q_zp;
+  std::vector<ElementT> dequants;
+  onnxruntime::cuda::test::blkq4_weights_gen<ElementT, block_size, col_blocking, has_offset>(
+      rows, columns, dequants, q_weights, q_scales, q_zp);
+
+  // for quantization tool, the input is row major, all outputs are column major
+  MatrixRef<ElementW, ColumnMajorLayout, true> tensor_q_weight(
+      q_weights, make_Position(rows / 2, columns));
+  MatrixRef<ElementT, ColumnMajorLayout, true> tensor_scale(
+      q_scales, meta_shape);
+  MatrixRef<ElementQOffset, ColumnMajorLayout, true> tensor_offset;
+  if constexpr (has_offset) {
+    tensor_offset = MatrixRef<ElementQOffset, ColumnMajorLayout, true>(q_zp, zp_shape);
+  }
+
+  // for quantization tool, the input is row major, test weight gen output is column major
+  std::vector<ElementT> dequants_transposed(dequants.size());
+  MatrixRef<ElementT, ColumnMajorLayout> tensor_dequant(dequants, make_Position(rows, columns));
+  MatrixRef<ElementT, RowMajorLayout> tensor_dequant_transposed(dequants_transposed, make_Position(rows, columns));
+  for (int col = 0; col < tensor_dequant.shape()[1]; ++col) {
+    for (int row = 0; row < tensor_dequant.shape()[0]; ++row) {
+      tensor_dequant_transposed.at(row, col) = tensor_dequant.at(row, col);
+    }
+  }
+
+  int q_rows, q_cols;
+  MlasBlockwiseQuantizedShape<ElementT, 4>(
+      block_size, col_blocking, rows, columns, q_rows, q_cols);
+  // To be exact, q_rows is padded to a multiple of block_size; deal with that when we care about unusual shapes.
+  EXPECT_EQ(q_rows, q_weight_shape[0]);
+  EXPECT_EQ(q_cols, q_weight_shape[1]);
+
+  //
+  // Quantization tool outputs:
+  //
+  std::vector<ElementW> o_elements(q_rows * q_cols);
+  MatrixRef<ElementW, ColumnMajorLayout, true> tensor_o_elements(o_elements, q_weight_shape);
+
+  std::vector<ElementT> o_scales(meta_shape.product());
+  MatrixRef<ElementT, ColumnMajorLayout, true> tensor_o_scales(o_scales, meta_shape);
+
+  std::vector<uint8_t> o_zp(zp_shape.product());
+  MatrixRef<uint8_t, ColumnMajorLayout, true> tensor_o_zp(o_zp, zp_shape);
+
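+  // Re-quantize the row-major dequantized weights with the MLAS quantization tool;
+  // the loops below compare the results against the outputs of blkq4_weights_gen.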
+  MlasQuantizeBlockwise<MLFloat16, 4>(o_elements.data(), o_scales.data(), has_offset ? o_zp.data() : nullptr,
+                                      dequants_transposed.data(), block_size,
+                                      col_blocking, rows, columns, columns, nullptr);
+  for (int col = 0; col < tensor_q_weight.shape()[1]; ++col) {
+    for (int row = 0; row < tensor_q_weight.shape()[0]; ++row) {
+      EXPECT_EQ(tensor_o_elements.at(row, col), tensor_q_weight.at(row, col))
+          << "quantized value mismatch at [" << row << "," << col << "]"
+          << " shape[" << rows << "," << columns << "]"
+          << (col_blocking ? "Column-wise-block" : "Row-wise-block")
+          << std::endl;
+    }
+  }
+
+  for (int col = 0; col < meta_shape[1]; ++col) {
+    for (int row = 0; row < meta_shape[0]; row += 2) {
+      if (has_offset) {
+        uint8_t pair01 = tensor_o_zp.at(row / 2, col);
+        uint8_t expected_pair01 = tensor_offset.at(row / 2, col);
+        EXPECT_EQ(expected_pair01 & 0xf, pair01 & 0xf)
+            << "quantized offset mismatch at [" << row << "," << col << "]"
+            << " shape[" << rows << "," << columns << "]"
+            << (col_blocking ? "Column-wise-block" : "Row-wise-block")
+            << std::endl;
+        if (row + 1 < meta_shape[0]) {
+          EXPECT_EQ(expected_pair01 >> 4, pair01 >> 4)
+              << "quantized offset mismatch at [" << row + 1 << "," << col << "]"
+              << " shape[" << rows << "," << columns << "]"
+              << (col_blocking ? "Column-wise-block" : "Row-wise-block")
+              << std::endl;
+        }
+      }
+
+      EXPECT_EQ(tensor_scale.at(row + 0, col), tensor_o_scales.at(row + 0, col))
+          << "quantized scale mismatch at [" << row << "," << col << "]"
+          << " shape[" << rows << "," << columns << "]"
+          << (col_blocking ? "Column-wise-block" : "Row-wise-block")
+          << std::endl;
+      if (row + 1 < meta_shape[0]) {
+        EXPECT_EQ(tensor_scale.at(row + 1, col), tensor_o_scales.at(row + 1, col))
+            << "quantized scale mismatch at [" << row + 1 << "," << col << "]"
+            << " shape[" << rows << "," << columns << "]"
+            << (col_blocking ? "Column-wise-block" : "Row-wise-block")
+            << std::endl;
+      }
+    }
+  }
+
+  //
+  // Now we just setup quantized weights tensor_q_weight, quantization scale tensor_scale
+  // and quantization offset tensor_offset. The above tests just make sure our setup is
+  // consistent with quantization tool output.
+  //
+  // Next we test the prepack code
+  //
+
+  std::vector<ElementW> packed_w_ref(q_weight_shape.product());
+  MatrixRef<ElementW, LayoutWPack, true> tensor_packed_w_ref(
+      packed_w_ref, make_Position(rows, columns / 2));
+  onnxruntime::test::sm80_prepack_weights_ref(rows, columns, tensor_q_weight, tensor_packed_w_ref);
+
+  std::vector<ElementW> packed_w(q_weight_shape.product());
+  MatrixRef<ElementW, LayoutWPack, true> tensor_packed_w(
+      packed_w, make_Position(rows, columns / 2));
+  Base::prepack_weights(rows, columns, o_elements, packed_w);
+
+  for (int col = 0; col < tensor_packed_w.shape()[1]; ++col) {
+    for (int row = 0; row < tensor_packed_w.shape()[0]; ++row) {
+      EXPECT_EQ(tensor_packed_w_ref.at(row, col), tensor_packed_w.at(row, col))
+          << "prepacked weights mismatch at [" << row << "," << col << "]"
+          << " shape[" << rows << "," << columns << "]"
+          << (col_blocking ? "Column-wise-block" : "Row-wise-block")
+          << std::endl;
+    }
+  }
+
+  std::vector<ElementT> packed_scales_ref(meta_shape.product());
+  MatrixRef<ElementT, LayoutQmeta, true> tensor_packed_s_ref =
+      make_MatrixRef<ElementT, LayoutQmeta, true>(packed_scales_ref, meta_shape);
+  if constexpr (Base::ShouldRearrangeMeta) {
+    onnxruntime::test::sm80_prepack_quant_scales_ref<ElementT, LayoutQmeta, QuantBlocking>(
+        rows, columns, tensor_scale.const_ref(), tensor_packed_s_ref);
+  } else {
+    for (int col = 0; col < tensor_packed_s_ref.shape()[1]; ++col) {
+      for (int row = 0; row < tensor_packed_s_ref.shape()[0]; ++row) {
+        tensor_packed_s_ref.at(row, col) = tensor_scale.at(row, col);
+      }
+    }
+  }
+
+  std::vector<ElementT> packed_scales(meta_shape.product());
+  MatrixRef<ElementT, LayoutQmeta, true> tensor_packed_s(
+      packed_scales, meta_shape);
+  Base::prepack_quant_scales(rows, columns, o_scales, packed_scales);
+
+  for (int col = 0; col < tensor_packed_s.shape()[1]; ++col) {
+    for (int row = 0; row < tensor_packed_s.shape()[0]; ++row) {
+      EXPECT_EQ(tensor_packed_s_ref.at(row, col), tensor_packed_s.at(row, col))
+          << "prepacked scales mismatch at [" << row << "," << col << "]"
+          << " shape[" << rows << "," << columns << "]"
+          << (col_blocking ? "Column-wise-block" : "Row-wise-block")
+          << std::endl;
+    }
+  }
+
+  if (has_offset) {
+    std::vector<ElementQOffset> packed_zp_ref(meta_shape.product());
+    MatrixRef<ElementQOffset, LayoutQmeta, true> tensor_packed_zp_ref =
+        make_MatrixRef<ElementQOffset, LayoutQmeta, true>(packed_zp_ref, meta_shape);
+    if constexpr (Base::ShouldRearrangeMeta) {
+      onnxruntime::test::sm80_prepack_quant_offsets_ref<LayoutQmeta, QuantBlocking>(
+          rows, columns, tensor_offset.const_ref(), tensor_packed_zp_ref);
+    } else {
+      for (int col = 0; col < meta_shape[1]; ++col) {
+        for (int row = 0; row < meta_shape[0]; row += 2) {
+          uint8_t pair01 = tensor_offset.at(row / 2, col);
+          tensor_packed_zp_ref.at(row, col) = pair01 & 0xf;
+          if (row + 1 < meta_shape[0]) {
+            tensor_packed_zp_ref.at(row + 1, col) = pair01 >> 4;
+          }
+        }
+      }
+    }
+
+    std::vector<ElementQOffset> packed_zp(meta_shape.product());
+    MatrixRef<ElementQOffset, LayoutQmeta, true> tensor_packed_zp(
+        packed_zp, meta_shape);
+    Base::prepack_quant_offsets(rows, columns, o_zp, packed_zp);
+
+    for (int col = 0; col < tensor_packed_zp.shape()[1]; ++col) {
+      for (int row = 0; row < tensor_packed_zp.shape()[0]; ++row) {
+        EXPECT_EQ(tensor_packed_zp_ref.at(row, col), tensor_packed_zp.at(row, col))
+            << "prepacked offsets mismatch at [" << row << "," << col << "]"
+            << " shape[" << rows << "," << columns << "]"
+            << (col_blocking ? "Column-wise-block" : "Row-wise-block")
+            << std::endl;
+      }
+    }
+  }
+}
+
+// TODO: this code runs on the CPU, but it targets sm80 only; maybe enable it only when testing on sm80.
+TEST(BlkQ4_GEMM, PrepackSm80Test) {
+  Status status = onnxruntime::cuda::test::sm80_supported();
+  if (!status.IsOK()) {
+    // skip the test if sm80 is not supported
+    return;
+  }
+
+  testPrepack<false>(32, 32);
+  testPrepack<false, false>(32, 32);
+  testPrepack<true>(32, 32);
+  testPrepack<true, false>(32, 32);
+  testPrepack<false>(32, 64);
+  testPrepack<false>(32, 128);
+  testPrepack<false>(32, 256);
+  testPrepack<false>(64, 32);
+  testPrepack<false>(128, 32);
+  testPrepack<false>(256, 32);
+  testPrepack<false>(256, 256);
+  testPrepack<false, false>(32, 128);
+  testPrepack<false, false>(128, 32);
+  testPrepack<false, false>(256, 256);
+  testPrepack<true>(32, 64);
+  testPrepack<true>(32, 128);
+  testPrepack<true>(32, 256);
+  testPrepack<true>(64, 32);
+  testPrepack<true>(128, 32);
+  testPrepack<true>(256, 32);
+  testPrepack<true>(256, 256);
+  testPrepack<true, false>(32, 128);
+  testPrepack<true, false>(128, 32);
+  testPrepack<true, false>(256, 256);
+}
+
+TEST(BlkQ4_GEMM, Sm80RowBlockingTest) {
+  Status status = onnxruntime::cuda::test::sm80_supported();
+  if (!status.IsOK()) {
+    // skip the test if sm80 is not supported
+    return;
+  }
+
+  onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(32, 32, 64);
+  onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, true>(32, 32, 64);
+
+  onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(32, 96, 64);
+  onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, true>(32, 96, 64);
+
+  onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(32, 96, 192);
+  onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, true>(32, 96, 192);
+
+  onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(256, 672, 576);
+  onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, true>(256, 672, 576);
+
+  onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(512, 2048 + 32, 960);
+  onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, true>(512, 2048 + 32, 960);
+
+  onnxruntime::cuda::test::run_blkq4_gemm<16, false, false, false>(256, 672, 576);
+  onnxruntime::cuda::test::run_blkq4_gemm<16, false, false, true>(256, 672, 576);
+
+  onnxruntime::cuda::test::run_blkq4_gemm<64, false, false, false>(256, 1024, 576);
+  onnxruntime::cuda::test::run_blkq4_gemm<64, false, false, true>(256, 1024, 576);
+}
+
+TEST(BlkQ4_GEMM, Sm80ColBlockingTest) {
+  Status status = onnxruntime::cuda::test::sm80_supported();
+  if (!status.IsOK()) {
+    // skip the test if sm80 is not supported
+    return;
+  }
+  onnxruntime::cuda::test::run_blkq4_gemm<16, true, false, false>(64, 672, 576);
+  onnxruntime::cuda::test::run_blkq4_gemm<16, true, false, true>(64, 672, 576);
+
+  onnxruntime::cuda::test::run_blkq4_gemm<64, true, false, false>(256, 1024, 576);
+  onnxruntime::cuda::test::run_blkq4_gemm<64, true, false, true>(256, 1024, 576);
+}
+
+TEST(BlkQ4_GEMM, Sm80SmallMTest) {
+  Status status = onnxruntime::cuda::test::sm80_supported();
+  if (!status.IsOK()) {
+    // skip the test if sm80 is not supported
+    return;
+  }
+
+  // small m
+  onnxruntime::cuda::test::run_blkq4_gemm<16, false, true, false>(16, 704, 576);
+  onnxruntime::cuda::test::run_blkq4_gemm<16, false, true, true>(16, 704, 576);
+
+  onnxruntime::cuda::test::run_blkq4_gemm<64, false, true, false>(16, 1024, 576);
+  onnxruntime::cuda::test::run_blkq4_gemm<64, false, true, true>(16, 1024, 576);
+
+  onnxruntime::cuda::test::run_blkq4_gemm<16, true, true, false>(16, 672, 576);
+  onnxruntime::cuda::test::run_blkq4_gemm<16, true, true, true>(16, 672, 576);
+
+  onnxruntime::cuda::test::run_blkq4_gemm<64, true, true, false>(16, 1024, 576);
+  onnxruntime::cuda::test::run_blkq4_gemm<64, true, true, true>(16, 1024, 576);
+}
+
+}  // namespace test
+}  // namespace onnxruntime
diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu
new file mode 100644
index 0000000000000..69c929d446ce4
--- /dev/null
+++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu
@@ -0,0 +1,344 @@
+/**
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License.
+ *
+ * Module Name:
+ *    blkq4_fp16_gemm_sm80_testcu.cu
+ *
+ * Abstract:
+ *   Test code for invoking block-wise quantized 4b GEMM kernels.
+ *   This part requires CUTLASS header files, which do not play
+ *   well with gtest headers.
+ */
+
+#include <random>
+#include <thrust/host_vector.h>
+#include <thrust/device_vector.h>
+
+#include "core/mickey/blk_q4/f16_gemm_sm80.h"
+
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/reference/device/gemm.h"
+#include "cutlass/util/reference/host/tensor_compare.h"
+#include "cutlass/util/reference/host/tensor_copy.h"
+#include "cutlass/util/reference/host/tensor_fill.h"
+#include "cutlass/util/tensor_view_io.h"
+
+#include "core/common/common.h"
+
+#include "blkq4_fp16_gemm_sm80.h"
+
+namespace onnxruntime {
+namespace cuda {
+namespace test {
+
+Status sm80_supported() {
+  cudaDeviceProp props;
+
+  cudaError_t error = cudaGetDeviceProperties(&props, 0);
+  if (error != cudaSuccess) {
+    std::ostringstream ss;
+    ss << "Unable to obtain GPU device properties: " << cudaGetErrorString(error);
+    return Status(common::ONNXRUNTIME, common::ENGINE_ERROR, ss.str());
+  }
+
+  if (props.major * 10 + props.minor < 80) {
+    std::ostringstream ss;
+    ss << "Device compute capability too low; need at least 8.0, got " << props.major << "." << props.minor;
+    return Status(common::ONNXRUNTIME, common::ENGINE_ERROR, ss.str());
+  }
+  return Status::OK();
+}
+
+/**
+ * @brief Reference implementation of GEMM.
+ *        Copied from cutlass util/reference/device/gemm.h because the compiler
+ *        insists on an explicit stream argument in the kernel launch.
+ */
+template <
+  typename ElementA,
+  typename LayoutA,
+  typename ElementB,
+  typename LayoutB,
+  typename ElementC,
+  typename LayoutC,
+  typename ScalarType,
+  typename AccumulatorType
+>
+void compute_gemm_ref(
+  cutlass::gemm::GemmCoord problem_size,
+  ScalarType alpha,
+  cutlass::TensorRef<ElementA, LayoutA> tensor_a,
+  cutlass::TensorRef<ElementB, LayoutB> tensor_b,
+  ScalarType beta,
+  cutlass::TensorRef<ElementC, LayoutC> tensor_c,
+  cutlass::TensorRef<ElementC, LayoutC> tensor_d,
+  AccumulatorType initial_accum = AccumulatorType(0)) {
+
+  // Blocking structure potentially improves performance of reference implementation
+  // with a minor increase in complexity.
+  //
+  // Note, this reference implementation is NOT expected to approach peak performance.
+  using OutputTile = cutlass::MatrixShape<4, 4>;
+
+  dim3 block(16, 8);
+
+  dim3 grid(
+    (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow),
+    (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn)
+  );
+
+  // Launch a GEMM kernel
+  cutlass::reference::device::kernel::Gemm<
+    cutlass::TensorRef<ElementA, LayoutA>,
+    cutlass::TensorRef<ElementB, LayoutB>,
+    cutlass::TensorRef<ElementC, LayoutC>,
+    ScalarType,
+    AccumulatorType,
+    OutputTile,
+    cutlass::multiply_add<AccumulatorType>,
+    cutlass::NumericConverter<ElementC, ScalarType>
+  ><<<grid, block, 0, 0>>>(
+    problem_size,
+    alpha,
+    tensor_a,
+    tensor_b,
+    beta,
+    tensor_c,
+    tensor_d,
+    initial_accum
+  );
+}
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+//
+// Converting cutlass tensor to MatrixRef
+//
+
+template <
+  typename Element,
+  typename LayoutCutlass,
+  typename Layout = std::conditional_t<std::is_same<LayoutCutlass, cutlass::layout::ColumnMajor>::value, ColumnMajorLayout, RowMajorLayout>
+  >
+__forceinline__
+MatrixRef<Element, Layout, true> make_MatrixRef(cutlass::HostTensor<Element, LayoutCutlass> const& tensor) {
+  static_assert(std::is_same<LayoutCutlass, cutlass::layout::ColumnMajor>::value
+                || std::is_same<LayoutCutlass, cutlass::layout::RowMajor>::value);
+  auto shape = make_Position(tensor.extent().row(), tensor.extent().column());
+  auto* ptr = const_cast<typename std::remove_const<Element>::type *>(tensor.host_data());
+  return MatrixRef<Element, Layout, true>(ptr, tensor.capacity(), shape);
+}
+
+template <
+  typename Element,
+  typename LayoutCutlass,
+  typename Layout = std::conditional_t<std::is_same<LayoutCutlass, cutlass::layout::ColumnMajor>::value, ColumnMajorLayout, RowMajorLayout>
+  >
+__forceinline__
+MatrixRef<Element const, Layout, true> make_ConstMatrixRef(cutlass::HostTensor<Element, LayoutCutlass> const& tensor) {
+  static_assert(std::is_same<LayoutCutlass, cutlass::layout::ColumnMajor>::value
+                || std::is_same<LayoutCutlass, cutlass::layout::RowMajor>::value);
+  auto shape = make_Position(tensor.extent().row(), tensor.extent().column());
+  return MatrixRef<Element const, Layout, true>(tensor.host_data(), tensor.capacity(), shape);
+}
+
+//
+// Invoking the kernel
+//
+
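+// Runs the block-wise quantized 4b GEMM kernel for an (m, n, k) problem and checks the result
+// against a fp16 reference GEMM. Template parameters (as used by the tests above):
+//   block_size           - number of weights sharing one quantization scale/offset
+//   column_wise_blocking - when true, each quantization block spans block_size rows (the k dim) of B
+//   small_m              - when true, selects the kernel variant specialized for small m
+//                          (as exercised by Sm80SmallMTest)
+//   has_offsets          - when true, dequantization uses zero-point offsets in addition to scales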
+template<
+    int block_size,
+    bool column_wise_blocking,
+    bool small_m,
+    bool has_offsets>
+void run_blkq4_gemm(int m, int n, int k) {
+  unsigned int seed = 28571;  // fixed seed so the generated test data is reproducible
+  std::seed_seq seq{seed};
+  std::mt19937 gen(seq);
+  std::uniform_int_distribution<> dis(0, 8192);
+
+  using ElementDequant = cutlass::half_t;
+  using QuantBlocking =
+    typename std::conditional<column_wise_blocking,
+                     cutlass::MatrixShape<block_size, 1>,
+                     cutlass::MatrixShape<1, block_size>>::type;
+
+  using GemmRunner = BlkQ4F16GemmImpl<ElementDequant, QuantBlocking, small_m, has_offsets>;
+
+  using ElementAccumulator = typename GemmRunner::ElementAccumulator;
+  using ElementComputeEpilogue = typename GemmRunner::ElementComputeEpilogue;
+  using ElementInputA = typename GemmRunner::ElementInputA;
+  using ElementOutput = typename GemmRunner::ElementOutput;
+  using ElementW = typename GemmRunner::ElementW;
+  using ElementWPack = typename GemmRunner::ElementWPack;
+  using ElementQScale = typename GemmRunner::ElementQScale;
+  using ElementQOffset = typename GemmRunner::ElementQOffset;
+
+  using LayoutInputA = typename GemmRunner::LayoutInputA;
+  using LayoutOutput = typename GemmRunner::LayoutOutput;
+  using LayoutInputWPack = typename GemmRunner::LayoutInputWPack;
+  using LayoutInputQScale = typename GemmRunner::LayoutInputQScale;
+
+  const cutlass::gemm::GemmCoord problem_size = {m, n, k};
+  const auto q_weight_shape = cutlass::make_Coord(problem_size.k()/2, problem_size.n());
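+  // meta_shape: one scale (and optional zero-point) per quantization block of B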
+  const auto meta_shape = cutlass::make_Coord(problem_size.k()/QuantBlocking::kRow, problem_size.n()/QuantBlocking::kColumn);
+
+  //
+  // Generate quantized and dequantized input matrix B [K, N]
+  //
+  static_assert(std::is_same<LayoutInputWPack, cutlass::layout::ColumnMajor>::value);
+  std::vector<ElementW> q_weights;
+  std::vector<ElementQScale> q_scales;
+  std::vector<ElementQOffset> q_zp;
+  std::vector<ElementDequant> dequants;
+  onnxruntime::cuda::test::blkq4_weights_gen<ElementDequant, block_size, column_wise_blocking, has_offsets>(
+      problem_size.k(), problem_size.n(), dequants, q_weights, q_scales, q_zp);
+
+  using PrepackT = onnxruntime::cuda::BlockwiseQuantization<
+      ElementDequant,
+      block_size,
+      4,
+      column_wise_blocking>;
+
+  std::vector<ElementW> packed_w(q_weight_shape.product());
+  PrepackT::prepack_weights(problem_size.k(), problem_size.n(), q_weights, packed_w);
+  std::vector<ElementQScale> packed_scales(meta_shape.product());
+  PrepackT::prepack_quant_scales(problem_size.k(), problem_size.n(), q_scales, packed_scales);
+  std::vector<ElementQOffset> packed_zp;
+  if constexpr (has_offsets) {
+    packed_zp.resize(meta_shape.product());
+    PrepackT::prepack_quant_offsets(problem_size.k(), problem_size.n(), q_zp, packed_zp);
+  }
+
+  cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
+      problem_size.mk());  // <- Create matrix A with dimensions M x K
+  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
+      problem_size.mn());  // <- Create matrix C with dimensions M x N
+  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
+      problem_size.mn());  // <- Create matrix D with dimensions M x N used to store output from
+                           // CUTLASS kernel
+
+  // Fill input and output matrices on host using CUTLASS helper functions
+  cutlass::reference::host::TensorFillRandomUniform(
+      tensor_a.host_view(),
+      1,
+      ElementInputA(4),
+      ElementInputA(-4),
+      2);  // <- Fill matrix A on host with uniform-distribution random data
+  cutlass::reference::host::TensorFillRandomUniform(
+      tensor_c.host_view(),
+      1,
+      ElementOutput(4),
+      ElementOutput(-4),
+      0);  // <- Fill matrix C on host with uniform-distribution random data
+  cutlass::reference::host::TensorFill(
+      tensor_d.host_view());  // <- fill matrix D on host with zeros
+
+  //
+  // Copy data from host to GPU...
+  //
+  thrust::device_vector<ElementW> d_packed_w(packed_w);
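+  // The 4b weights are stored two per byte (hence k/2 rows); viewing the buffer as ElementWPack
+  // uses the {k/2, n/2} shape expected by the kernel (shape taken from the TensorRef below).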
+  cutlass::TensorRef<ElementWPack const, LayoutInputWPack> ref_W(
+    reinterpret_cast<ElementWPack const *>(d_packed_w.data().get()),
+    LayoutInputWPack::packed({problem_size.k()/2, problem_size.n()/2}));
+
+  thrust::device_vector<ElementQScale> d_packed_scales(packed_scales);
+  cutlass::TensorRef<ElementQScale const, LayoutInputQScale> ref_scales(
+    d_packed_scales.data().get(), LayoutInputQScale::packed(meta_shape));
+
+  thrust::device_vector<ElementQOffset> d_packed_zp(packed_zp);
+  cutlass::TensorRef<ElementQOffset const, LayoutInputQScale> ref_zp(
+    d_packed_zp.data().get(), LayoutInputQScale::packed(meta_shape));
+
+  tensor_a.sync_device();
+  tensor_c.sync_device();
+  tensor_d.sync_device();
+
+  // run GEMM
+  cutlass::Status status;
+  if constexpr (has_offsets) {
+    status = GemmRunner::run(
+      nullptr, problem_size, tensor_a.device_ref(), ref_W,
+      ref_scales, ref_zp,
+      tensor_c.device_ref(), tensor_d.device_ref());
+  } else {
+    status = GemmRunner::run(
+      nullptr, problem_size, tensor_a.device_ref(), ref_W,
+      ref_scales,
+      tensor_c.device_ref(), tensor_d.device_ref());
+  }
+  ORT_ENFORCE(status == cutlass::Status::kSuccess, "Kernel execution failed: ", cutlassGetStatusString(status));
+
+  // Running reference kernel
+  using ElementInputB = ElementInputA;
+  using LayoutInputB = cutlass::layout::ColumnMajor;
+  thrust::device_vector<ElementInputB> d_dequants(dequants);
+  cutlass::TensorRef<ElementInputB, LayoutInputB> ref_B(
+    d_dequants.data().get(), LayoutInputB::packed(problem_size.kn()));
+  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(
+      problem_size.mn());  // <- Create matrix D with dimensions M x N used to store output from
+                           // reference kernel
+
+  cutlass::reference::host::TensorFill(
+      tensor_ref_d.host_view());  // <- fill matrix D for reference on host with zeros
+  tensor_ref_d.sync_device();
+
+  // Initialize alpha and beta for dot product computation
+  ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
+  ElementComputeEpilogue beta = ElementComputeEpilogue(0);
+
+  compute_gemm_ref<ElementInputA, LayoutInputA,
+               ElementInputB, LayoutInputB,
+               ElementOutput, LayoutOutput,
+               ElementComputeEpilogue, ElementAccumulator>(
+      problem_size,
+      alpha,
+      tensor_a.device_ref(),
+      ref_B,
+      beta,
+      tensor_c.device_ref(),
+      tensor_ref_d.device_ref());
+
+  // Wait for kernels to finish
+  cudaDeviceSynchronize();
+
+  // Copy output data from CUTLASS and reference kernel to host for comparison
+  tensor_d.sync_host();
+  tensor_ref_d.sync_host();
+
+  // Check if output from CUTLASS kernel and reference kernel are equal or not
+  bool passed = cutlass::reference::host::TensorEquals(
+    tensor_d.host_view(),
+    tensor_ref_d.host_view());
+  ORT_ENFORCE(passed, "Quantized GEMM kernel output does not match the reference kernel!");
+}
+
+template void run_blkq4_gemm<16, true, false, true>(int m, int n, int k);
+template void run_blkq4_gemm<16, true, false, false>(int m, int n, int k);
+template void run_blkq4_gemm<32, true, false, true>(int m, int n, int k);
+template void run_blkq4_gemm<32, true, false, false>(int m, int n, int k);
+template void run_blkq4_gemm<64, true, false, true>(int m, int n, int k);
+template void run_blkq4_gemm<64, true, false, false>(int m, int n, int k);
+template void run_blkq4_gemm<16, false, false, true>(int m, int n, int k);
+template void run_blkq4_gemm<16, false, false, false>(int m, int n, int k);
+template void run_blkq4_gemm<32, false, false, true>(int m, int n, int k);
+template void run_blkq4_gemm<32, false, false, false>(int m, int n, int k);
+template void run_blkq4_gemm<64, false, false, true>(int m, int n, int k);
+template void run_blkq4_gemm<64, false, false, false>(int m, int n, int k);
+template void run_blkq4_gemm<16, true, true, true>(int m, int n, int k);
+template void run_blkq4_gemm<16, true, true, false>(int m, int n, int k);
+template void run_blkq4_gemm<32, true, true, true>(int m, int n, int k);
+template void run_blkq4_gemm<32, true, true, false>(int m, int n, int k);
+template void run_blkq4_gemm<64, true, true, true>(int m, int n, int k);
+template void run_blkq4_gemm<64, true, true, false>(int m, int n, int k);
+template void run_blkq4_gemm<16, false, true, true>(int m, int n, int k);
+template void run_blkq4_gemm<16, false, true, false>(int m, int n, int k);
+template void run_blkq4_gemm<32, false, true, true>(int m, int n, int k);
+template void run_blkq4_gemm<32, false, true, false>(int m, int n, int k);
+template void run_blkq4_gemm<64, false, true, true>(int m, int n, int k);
+template void run_blkq4_gemm<64, false, true, false>(int m, int n, int k);
+
+}  // namespace test
+}  // namespace cuda
+}  // namespace onnxruntime
diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_sm80_prepack_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_sm80_prepack_test.cc
deleted file mode 100644
index aba2b0b2cb4a4..0000000000000
--- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_sm80_prepack_test.cc
+++ /dev/null
@@ -1,507 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include <random>
-
-#include "core/framework/float16.h"
-#include "core/mickey/blk_q4/prepack_sm80.h"
-#include "core/mlas/inc/mlas_q4.h"
-
-#include "gtest/gtest.h"
-
-namespace onnxruntime {
-namespace test {
-
-void prepack_weights_ref(
-    int rows,
-    int columns,
-    const MatrixRef<uint8_t const, ColumnMajorLayout, true>& tensor_weight,
-    const MatrixRef<uint8_t, ColumnMajorLayout, true>& tensor_weight_prepacked) {
-  EXPECT_TRUE(tensor_weight.shape()[0] == rows / 2 && tensor_weight.shape()[1] == columns);
-  EXPECT_TRUE(tensor_weight_prepacked.shape()[0] == rows && tensor_weight_prepacked.shape()[1] == columns / 2);
-
-  auto t0_base = make_Position(0, 0);
-  auto t1_base = make_Position(4, 0);
-  auto t2_base = make_Position(0, 8);
-  auto t3_base = make_Position(4, 8);
-  for (int col_dtile = 0; col_dtile < columns / 16; ++col_dtile) {
-    for (int row_dtile = 0; row_dtile < rows / 16; ++row_dtile) {
-      // Packing from a 8x16 tile to a 16x8 tile
-      auto dtile_base = make_Position(row_dtile * 8, col_dtile * 16);
-      auto packed_tile_base = make_Position(row_dtile * 16, col_dtile * 8);
-      for (int col = 0; col < 8; ++col) {
-        for (int row = 0; row < 4; ++row) {
-          auto cord = make_Position(row, col);
-          auto packed_cord = packed_tile_base + make_Position(row * 4, col);  // packed tile is 16x8
-          uint8_t buf[4];
-          buf[0] = tensor_weight.at(dtile_base + t0_base + cord);
-          buf[1] = tensor_weight.at(dtile_base + t1_base + cord);
-          buf[2] = tensor_weight.at(dtile_base + t2_base + cord);
-          buf[3] = tensor_weight.at(dtile_base + t3_base + cord);
-
-          // [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7] so that each pair of adjacent weights
-          // are in different b16 register at the same positions. This makes it easier to convert to
-          // fp16x2 format in a b32 register
-
-          tensor_weight_prepacked.at(packed_cord) = (buf[0] & 0x0f) | ((buf[1] & 0x0f) << 4);
-          tensor_weight_prepacked.at(packed_cord + make_Position(1, 0)) = (buf[2] & 0x0f) | ((buf[3] & 0x0f) << 4);
-          tensor_weight_prepacked.at(packed_cord + make_Position(2, 0)) = ((buf[0] & 0xf0) >> 4) | (buf[1] & 0xf0);
-          tensor_weight_prepacked.at(packed_cord + make_Position(3, 0)) = ((buf[2] & 0xf0) >> 4) | (buf[3] & 0xf0);
-        }
-      }
-    }
-  }
-}
-
-template <
-    typename ScaleElementT,
-    typename Layout,
-    typename QuantBlocking>
-void prepack_quant_scales_ref(
-    int rows,
-    int columns,
-    const MatrixRef<ScaleElementT const, Layout, true>& tensor_scale,
-    const MatrixRef<ScaleElementT, Layout, true>& tensor_scale_prepacked) {
-  EXPECT_TRUE(tensor_scale.shape()[0] == (rows / QuantBlocking::kRow) && tensor_scale.shape()[1] == (columns / QuantBlocking::kColumn));
-  EXPECT_TRUE(tensor_scale_prepacked.shape() == tensor_scale.shape());
-
-  // Only prepacking scale and offset tensors for a often used special case:
-  //    16b gemm (2 elements per 32b register, operand tile shape 8x8)
-  //    2 B operand tiles per mma instruction stacked on k dimension
-  //    (1,n) quantization blocking
-  if constexpr (sizeof(ScaleElementT) == 2 && QuantBlocking::kRow == 1) {
-    // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread
-    // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use
-    // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension,
-    // as shown below (T stands for thread):
-    // T0, T4, T8, T12
-    // T1, T5, T9, T13
-    // T2, T6, T10, T14
-    // T3, T7, T11, T15
-    // T0, T4, T8, T12
-    // T1, T5, T9, T13
-    // T2, T6, T10, T14
-    // T3, T7, T11, T15
-    //
-    // We need to deliver quantization scale and offset elements to the corresponding threads,
-    // so we can perform dequantization efficiently. With a column major layout, each thread
-    // needs two separate loads for a mma instruction, due to the tile fragment layout shown
-    // above. To reduce the number of loads, we rearrange each column as below, so we can use
-    // a single load to load fragments for two tiles:
-    // T0        T0
-    // T1        T0
-    // T2        T1
-    // T3   =>   T1
-    // T0        T2
-    // T1        T2
-    // T2        T3
-    // T3        T3
-
-    for (int col = 0; col < tensor_scale.shape()[1]; ++col) {
-      for (int row_blk = 0; row_blk < tensor_scale.shape()[0]; row_blk += 16) {
-        for (int thread_id = 0; thread_id < 4; thread_id++) {
-          const int dst_idx = row_blk + thread_id * 4;
-          const int src_idx = row_blk + thread_id * 2;
-          tensor_scale_prepacked.at(dst_idx + 0, col) = tensor_scale.at(src_idx + 0, col);
-          tensor_scale_prepacked.at(dst_idx + 1, col) = tensor_scale.at(src_idx + 1, col);
-          tensor_scale_prepacked.at(dst_idx + 2, col) = tensor_scale.at(src_idx + 8, col);
-          tensor_scale_prepacked.at(dst_idx + 3, col) = tensor_scale.at(src_idx + 9, col);
-        }
-      }
-    }
-  } else {
-    // In all other cases, we don't prepack scale or offset
-    FAIL() << "Scale prepack only supported for 16b gemm with (1,n) quantization blocking";
-  }
-}
-
-template <typename Layout, typename QuantBlocking>
-void prepack_quant_offsets_ref(
-    size_t rows,
-    size_t columns,
-    MatrixRef<uint8_t const, Layout, true> tensor_offset,
-    MatrixRef<uint8_t, Layout, true> tensor_offset_prepacked) {
-  // EXPECT_TRUE(tensor_offset.shape()[0] == (rows / QuantBlocking::kRow) && tensor_offset.shape()[1] == (columns / QuantBlocking::kColumn));
-  EXPECT_TRUE(tensor_offset_prepacked.shape() == tensor_offset.shape());
-
-  // Only prepacking scale and offset tensors for a often used special case:
-  //    16b gemm (2 elements per 32b register, operand tile shape 8x8)
-  //    2 B operand tiles per mma instruction stacked on k dimension
-  //    (1,n) quantization blocking
-  if constexpr (QuantBlocking::kRow != 1) {
-    FAIL() << "Offsets prepack only supported for 16b gemm with (1,n) quantization blocking";
-  }
-  // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread
-  // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use
-  // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension,
-  // as shown below (T stands for thread):
-  // T0, T4, T8, T12
-  // T1, T5, T9, T13
-  // T2, T6, T10, T14
-  // T3, T7, T11, T15
-  // T0, T4, T8, T12
-  // T1, T5, T9, T13
-  // T2, T6, T10, T14
-  // T3, T7, T11, T15
-  //
-  // We need to deliver quantization scale and offset elements to the corresponding threads,
-  // so we can perform dequantization efficiently. With a column major layout, each thread
-  // needs two separate loads for a mma instruction, due to the tile fragment layout shown
-  // above. To reduce the number of loads, we rearrange each column as below, so we can use
-  // a single load to load fragments for two tiles:
-  // T0        T0
-  // T1        T0
-  // T2        T1
-  // T3   =>   T1
-  // T0        T2
-  // T1        T2
-  // T2        T3
-  // T3        T3
-  if (tensor_offset_prepacked.good()) {
-    for (int col = 0; col < tensor_offset.shape()[1]; ++col) {
-      for (int row_blk = 0; row_blk < tensor_offset.shape()[0]; row_blk += 16) {
-        for (int thread_id = 0; thread_id < 4; thread_id++) {
-          const int dst_idx = row_blk + thread_id * 4;
-          const int src_idx = row_blk + thread_id * 2;
-          // [a, b, c, d] => [a, c, b, d] so that adjacent weights are in their own
-          // 16b element: [a, x, b, x] and [x, c, x, d], which makes it easier to
-          // convert to fp16x2 format in a b32 register
-          tensor_offset_prepacked.at(dst_idx + 0, col) = tensor_offset.at(src_idx + 0, col);
-          tensor_offset_prepacked.at(dst_idx + 1, col) = tensor_offset.at(src_idx + 8, col);
-          tensor_offset_prepacked.at(dst_idx + 2, col) = tensor_offset.at(src_idx + 1, col);
-          tensor_offset_prepacked.at(dst_idx + 3, col) = tensor_offset.at(src_idx + 9, col);
-        }
-      }
-    }
-  }
-}
-
-template <bool ColumnMajorQuantBlocking>
-void testPrepack(int rows, int columns, bool has_offset = true) {
-  using ElementT = MLFloat16;
-  constexpr int block_size = 32;
-  using Base = onnxruntime::cuda::BlockwiseQuantization<
-      ElementT,
-      block_size,
-      4,
-      ColumnMajorQuantBlocking>;
-
-  using QuantBlocking = typename Base::QuantBlocking;
-  using ElementW = typename Base::ElementW;
-  using LayoutWPack = typename Base::LayoutWPack;
-  using ElementQOffset = typename Base::ElementQOffset;
-  using LayoutQmeta = typename Base::LayoutQmeta;
-
-  unsigned int seed = 28571;  // Replace with desired seed value
-  std::seed_seq seq{seed};
-  std::mt19937 gen(seq);
-  std::uniform_int_distribution<> dis(0, 8192);
-
-  const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns);
-  const auto meta_shape = Base::get_quant_meta_shape(rows, columns);
-
-  //
-  // For testing quantization and dequantization, it is not straight
-  // forward to avoid flaky tests due to rounding errors. The way we
-  // try to achieve this is to:
-  // 1. Generate a set of quantized weights, scales and offsets
-  // 2. Dequantize the weights
-  // 3. Quantize the dequantized weights
-  // 4. Compare the dequantied-and-then-quantized weights with
-  //    the original quantized weights
-  //
-  // Random filling of the initial values are key to get this right.
-  // For weights, we must ensure each block gets a full range of
-  // values, i.e. must contain 0 and 15. And for scales, they must
-  // all be positive.
-  //
-
-  std::vector<ElementW> q_weights(q_weight_shape.product());
-  MatrixRef<ElementW, LayoutWPack, true> tensor_q_weight(
-      q_weights, make_Position(rows / 2, columns));
-  int v = 7;
-  for (int c = 0; c < tensor_q_weight.shape()[1]; c++) {
-    for (int r = 0; r < tensor_q_weight.shape()[0]; ++r) {
-      uint8_t v0 = static_cast<uint8_t>(v);
-      v = (v + 5) % 16;
-      if (v == 11 || v == 7 || v == 3) {
-        // making the cycle 13 instead of 16, avoiding same values in a row
-        v = (v + 5) % 16;
-      }
-      uint8_t v1 = 0;
-      if (r + 1 < rows) {
-        v1 = static_cast<uint8_t>(v);
-        v = (v + 5) % 16;
-        if (v == 11 || v == 7 || v == 3) {
-          // making the cycle 13 instead of 16, avoiding same values in a row
-          v = (v + 5) % 16;
-        }
-      }
-
-      tensor_q_weight.at(r, c) = ElementW((v1 << 4) | v0);
-    }
-  }
-
-  std::vector<ElementT> q_scales(meta_shape.product());
-  for (size_t i = 0; i < q_scales.size(); i++) {
-    q_scales[i] = ElementT(((dis(gen) % 127) + 1) / 32.0f);
-  }
-  MatrixRef<ElementT, LayoutQmeta, true> tensor_scale(
-      q_scales, meta_shape);
-
-  std::vector<ElementQOffset> q_zp(meta_shape.product());
-  for (size_t i = 0; i < q_zp.size(); i++) {
-    q_zp[i] = dis(gen) % 16;
-  }
-  MatrixRef<ElementQOffset, LayoutQmeta, true> tensor_offset(
-      q_zp, meta_shape);
-
-#if 0  // debug
-  // Fill tensor_q_weight with the patterned data, easier to debug with print
-  int loop_val = 0;
-  int offset = 3;
-  for (int col_tile = 0; col_tile < tensor_q_weight.extent().column()/8; ++col_tile) {
-    for (int row_tile = 0; row_tile < tensor_q_weight.extent().row()/4; ++row_tile) {
-      for (int col = 0; col < 8; ++col) {
-        for (int row = 0; row < 4; ++row) {
-          auto weight_cord = cutlass::make_Coord(row_tile * 4 + row, col_tile * 8 + col);
-          auto val = (loop_val + offset) % 256;
-          tensor_q_weight.at(weight_cord) = ElementW(val);
-          loop_val++;
-          if (loop_val == 256) {
-            loop_val = 0;
-            offset += 11;
-          }
-        }
-      }
-    }
-  }
-  for (int col = 0; col < tensor_scale.extent().column(); ++col){
-    int c =  col * QuantBlocking::kColumn;
-    for (int row = 0; row < tensor_scale.extent().row(); ++row){
-      int r = row * QuantBlocking::kRow;
-      auto weight_cord = cutlass::make_Coord(r/2, c);
-      int w = 0;
-      if (r % 2 == 0) {
-        w = int(tensor_q_weight.at(weight_cord) & 0x0f);
-      } else {
-        w = int(tensor_q_weight.at(weight_cord) >> 4);
-      }
-      tensor_scale.at({row, col}) = w;
-      tensor_offset.at({row, col}) = ElementQOffset(w);
-    }
-  }
-
-  int fill_val = -512;
-  int factor = 1;
-  for (int col = 0; col < tensor_scale.extent().column(); ++col){
-    for (int row = 0; row < tensor_scale.extent().row(); ++row){
-      tensor_scale.at({row, col}) = ElementQScale((float)fill_val * float(factor));
-      fill_val++;
-      if (fill_val == 512) {
-        fill_val = -512;
-        factor += 1;
-      }
-    }
-  }
-
-#endif  // debug
-
-  std::vector<ElementT> dequants(rows * columns);
-  MatrixRef<ElementT, RowMajorLayout> tensor_dequant(dequants, make_Position(rows, columns));
-
-  // Dequantize weights and save into matrix B for reference
-  for (int col = 0; col < tensor_dequant.shape()[1]; ++col) {
-    for (int row = 0; row < tensor_dequant.shape()[0]; ++row) {
-      auto weight_cord = make_Position(row / 2, col);
-      auto scale_cord = make_Position(row / QuantBlocking::kRow, col / QuantBlocking::kColumn);
-      const uint8_t offset = has_offset ? tensor_offset.at(scale_cord) : 8;
-      int w = 0;
-      if (row % 2 == 0) {
-        w = int(tensor_q_weight.at(weight_cord) & 0x0f);
-      } else {
-        w = int(tensor_q_weight.at(weight_cord) >> 4);
-      }
-      float scale = float(tensor_scale.at(scale_cord));
-      float dequant = scale * float(w - offset);
-      tensor_dequant.at(row, col) = ElementT(dequant);
-      // Prints for help debugging in case of test failure
-      // fprintf(stderr, "(%2d,%2d)= %2d, %2d, %f, %f\n", row, col, w, offset, scale, dequant);
-    }
-  }
-
-  int q_rows, q_cols;
-  MlasBlockwiseQuantizedShape<ElementT, 4>(
-      block_size, ColumnMajorQuantBlocking, rows, columns, q_rows, q_cols);
-  // to be exact, q_rows are padded to multiple of block_size, deal with it when we care about strange shapes
-  EXPECT_EQ(q_rows, q_weight_shape[0]);
-  EXPECT_EQ(q_cols, q_weight_shape[1]);
-
-  //
-  // Quantization tool outputs:
-  //
-  std::vector<ElementW> o_elements(q_rows * q_cols);
-  MatrixRef<ElementW, ColumnMajorLayout, true> tensor_o_elements(o_elements, q_weight_shape);
-
-  std::vector<ElementT> o_scales(meta_shape.product());
-  MatrixRef<ElementT, ColumnMajorLayout, true> tensor_o_scales(o_scales, meta_shape);
-
-  std::vector<uint8_t> o_zp(((meta_shape[0] + 1) / 2) * meta_shape[1], true);
-  MatrixRef<uint8_t, ColumnMajorLayout, true> tensor_o_zp(
-      o_zp, make_Position((meta_shape[0] + 1) / 2, meta_shape[1]));
-
-  MlasQuantizeBlockwise<MLFloat16, 4>(o_elements.data(), o_scales.data(), has_offset ? o_zp.data() : nullptr,
-                                      tensor_dequant.data().data(), block_size,
-                                      ColumnMajorQuantBlocking, rows, columns, columns, nullptr);
-  for (int col = 0; col < tensor_q_weight.shape()[1]; ++col) {
-    for (int row = 0; row < tensor_q_weight.shape()[0]; ++row) {
-      EXPECT_EQ(tensor_o_elements.at(row, col), tensor_q_weight.at(row, col))
-          << "quantized value mismatch at [" << row << "," << col << "]"
-          << " shape[" << rows << "," << columns << "]"
-          << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block")
-          << std::endl;
-    }
-  }
-
-  for (int col = 0; col < meta_shape[1]; ++col) {
-    for (int row = 0; row < meta_shape[0]; row += 2) {
-      if (has_offset) {
-        uint8_t pair01 = tensor_o_zp.at(row / 2, col);
-        EXPECT_EQ(tensor_offset.at(row + 0, col), pair01 & 0xf)
-            << "quantized offset mismatch at [" << row << "," << col << "]"
-            << " shape[" << rows << "," << columns << "]"
-            << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block")
-            << std::endl;
-        if (row + 1 < meta_shape[0]) {
-          EXPECT_EQ(tensor_offset.at(row + 1, col), pair01 >> 4)
-              << "quantized offset mismatch at [" << row + 1 << "," << col << "]"
-              << " shape[" << rows << "," << columns << "]"
-              << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block")
-              << std::endl;
-        }
-      }
-
-      EXPECT_EQ(tensor_scale.at(row + 0, col), tensor_o_scales.at(row + 0, col))
-          << "quantized scale mismatch at [" << row << "," << col << "]"
-          << " shape[" << rows << "," << columns << "]"
-          << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block")
-          << std::endl;
-      if (row + 1 < meta_shape[0]) {
-        EXPECT_EQ(tensor_scale.at(row + 1, col), tensor_o_scales.at(row + 1, col))
-            << "quantized scale mismatch at [" << row + 1 << "," << col << "]"
-            << " shape[" << rows << "," << columns << "]"
-            << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block")
-            << std::endl;
-      }
-    }
-  }
-
-  //
-  // Now we just setup fp16 weights tensor_dequant, quantized weights tensor_q_weight,
-  // quantization scale tensor_scale and quantization offset tensor_offset. The above
-  // testing just make sure our test setup is consistent with quantization tool output.
-  //
-  // Next we test the prepack code
-  //
-
-  std::vector<ElementW> packed_w_ref(q_weight_shape.product());
-  MatrixRef<ElementW, LayoutWPack, true> tensor_packed_w_ref(
-      packed_w_ref, make_Position(rows, columns / 2));
-  prepack_weights_ref(rows, columns, tensor_q_weight, tensor_packed_w_ref);
-
-  std::vector<ElementW> packed_w(q_weight_shape.product());
-  MatrixRef<ElementW, LayoutWPack, true> tensor_packed_w(
-      packed_w, make_Position(rows, columns / 2));
-  Base::prepack_weights(rows, columns, o_elements, packed_w);
-
-  for (int col = 0; col < tensor_packed_w.shape()[1]; ++col) {
-    for (int row = 0; row < tensor_packed_w.shape()[0]; ++row) {
-      EXPECT_EQ(tensor_packed_w_ref.at(row, col), tensor_packed_w.at(row, col))
-          << "prepacked weights mismatch at [" << row << "," << col << "]"
-          << " shape[" << rows << "," << columns << "]"
-          << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block")
-          << std::endl;
-    }
-  }
-
-  std::vector<ElementT> packed_scales_ref(meta_shape.product());
-  MatrixRef<ElementT, LayoutQmeta, true> tensor_packed_s_ref =
-      Base::ShouldRearrangeMeta ? make_MatrixRef<ElementT, LayoutQmeta, true>(packed_scales_ref, meta_shape)
-                                : tensor_scale;
-  if (Base::ShouldRearrangeMeta) {
-    prepack_quant_scales_ref<ElementT, LayoutQmeta, QuantBlocking>(
-        rows, columns, tensor_scale.const_ref(), tensor_packed_s_ref);
-  }
-
-  std::vector<ElementT> packed_scales(meta_shape.product());
-  MatrixRef<ElementT, LayoutQmeta, true> tensor_packed_s(
-      packed_scales, meta_shape);
-  Base::prepack_quant_scales(rows, columns, o_scales, packed_scales);
-
-  for (int col = 0; col < tensor_packed_s.shape()[1]; ++col) {
-    for (int row = 0; row < tensor_packed_s.shape()[0]; ++row) {
-      EXPECT_EQ(tensor_packed_s_ref.at(row, col), tensor_packed_s.at(row, col))
-          << "prepacked scales mismatch at [" << row << "," << col << "]"
-          << " shape[" << rows << "," << columns << "]"
-          << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block")
-          << std::endl;
-    }
-  }
-
-  if (has_offset) {
-    std::vector<ElementQOffset> packed_zp_ref(meta_shape.product());
-    MatrixRef<ElementQOffset, LayoutQmeta, true> tensor_packed_zp_ref =
-        Base::ShouldRearrangeMeta ? make_MatrixRef<ElementQOffset, LayoutQmeta, true>(packed_zp_ref, meta_shape)
-                                  : tensor_offset;
-    if (Base::ShouldRearrangeMeta) {
-      prepack_quant_offsets_ref<LayoutQmeta, QuantBlocking>(
-          rows, columns, tensor_offset.const_ref(), tensor_packed_zp_ref);
-    }
-
-    std::vector<ElementQOffset> packed_zp(meta_shape.product());
-    MatrixRef<ElementQOffset, LayoutQmeta, true> tensor_packed_zp(
-        packed_zp, meta_shape);
-    Base::prepack_quant_offsets(rows, columns, o_zp, packed_zp);
-
-    for (int col = 0; col < tensor_packed_zp.shape()[1]; ++col) {
-      for (int row = 0; row < tensor_packed_zp.shape()[0]; ++row) {
-        EXPECT_EQ(tensor_packed_zp_ref.at(row, col), tensor_packed_zp.at(row, col))
-            << "prepacked offsets mismatch at [" << row << "," << col << "]"
-            << " shape[" << rows << "," << columns << "]"
-            << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block")
-            << std::endl;
-      }
-    }
-  }
-}
-
-// TODO: code runs on CPU, but this is for sm80 only, maybe enable only when test on sm80
-TEST(BlkQ4_GEMM, PrepackSm80Test) {
-  testPrepack<false>(32, 32);
-  testPrepack<false>(32, 32, false);
-  testPrepack<true>(32, 32);
-  testPrepack<true>(32, 32, false);
-  testPrepack<false>(32, 64);
-  testPrepack<false>(32, 128);
-  testPrepack<false>(32, 256);
-  testPrepack<false>(64, 32);
-  testPrepack<false>(128, 32);
-  testPrepack<false>(256, 32);
-  testPrepack<false>(256, 256);
-  testPrepack<false>(32, 128, false);
-  testPrepack<false>(128, 32, false);
-  testPrepack<false>(256, 256, false);
-  testPrepack<true>(32, 64);
-  testPrepack<true>(32, 128);
-  testPrepack<true>(32, 256);
-  testPrepack<true>(64, 32);
-  testPrepack<true>(128, 32);
-  testPrepack<true>(256, 32);
-  testPrepack<true>(256, 256);
-  testPrepack<true>(32, 128, false);
-  testPrepack<true>(128, 32, false);
-  testPrepack<true>(256, 256, false);
-}
-
-}  // namespace test
-}  // namespace onnxruntime
diff --git a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc
index a70e439cdf755..8dfaaedcbb378 100644
--- a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc
+++ b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc
@@ -22,16 +22,18 @@ TEST(TestDeferredRelease, WithArena) {
   CUDAExecutionProvider ep(info);
   AllocatorPtr gpu_alloctor = ep.CreatePreferredAllocators()[0];
 
+  RunOptions run_opts;
+  run_opts.run_tag = "log1";
   // Allocator for call cudaMallocHost and cudaFreeHost
   // For details, see CUDAPinnedAllocator in cuda_allocator.cc.
   AllocatorPtr cpu_pinned_alloc = ep.CreatePreferredAllocators()[1];
   // let the CudaStream instance "own" the default stream, so we can avoid the
   // work to initialize cublas/cudnn/... It is ok since it is just a customized unit test.
-  CudaStream stream(nullptr, gpu_alloctor->Info().device, cpu_pinned_alloc, false, true, nullptr, nullptr);
+  CudaStream stream(nullptr, gpu_alloctor->Info().device, cpu_pinned_alloc, false, true, nullptr, nullptr, info);
   // 10 MB
   const size_t n_bytes = 10 * 1000000;
   const int64_t n_allocs = 64;
-  ORT_THROW_IF_ERROR(ep.OnRunStart());
+  ORT_THROW_IF_ERROR(ep.OnRunStart(run_opts));
   for (size_t i = 0; i < n_allocs; ++i) {
     // Allocate 10MB CUDA pinned memory.
     auto pinned_buffer = IAllocator::MakeUniquePtr<void>(cpu_pinned_alloc, n_bytes);
@@ -44,7 +46,7 @@ TEST(TestDeferredRelease, WithArena) {
   cpu_pinned_alloc->GetStats(&stats);
   ASSERT_EQ(stats.num_allocs, n_allocs);
   ORT_THROW_IF_ERROR(stream.CleanUpOnRunEnd());
-  ORT_THROW_IF_ERROR(ep.OnRunEnd(true));
+  ORT_THROW_IF_ERROR(ep.OnRunEnd(true, run_opts));
 }
 
 TEST(TestDeferredRelease, WithoutArena) {
@@ -52,6 +54,9 @@ TEST(TestDeferredRelease, WithoutArena) {
   CUDAExecutionProviderInfo info;
   CUDAExecutionProvider ep(info);
 
+  RunOptions run_opts;
+  run_opts.run_tag = "log1";
+
   OrtDevice pinned_device{OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, DEFAULT_CPU_ALLOCATOR_DEVICE_ID};
   // Create allocator without BFCArena
   AllocatorCreationInfo pinned_memory_info(
@@ -66,11 +71,11 @@ TEST(TestDeferredRelease, WithoutArena) {
   // For details, see CUDAPinnedAllocator in cuda_allocator.cc.
   // let the CudaStream instance "own" the default stream, so we can avoid the
   // work to initialize cublas/cudnn/... It is ok since it is just a customized unit test.
-  CudaStream stream(nullptr, gpu_alloctor->Info().device, cuda_pinned_alloc, false, true, nullptr, nullptr);
+  CudaStream stream(nullptr, gpu_alloctor->Info().device, cuda_pinned_alloc, false, true, nullptr, nullptr, info);
   // 10 MB
   const size_t n_bytes = 10 * 1000000;
   const int64_t n_allocs = 64;
-  ORT_THROW_IF_ERROR(ep.OnRunStart());
+  ORT_THROW_IF_ERROR(ep.OnRunStart(run_opts));
   for (size_t i = 0; i < n_allocs; ++i) {
     // Allocate 10MB CUDA pinned memory.
     auto pinned_buffer = IAllocator::MakeUniquePtr<void>(cuda_pinned_alloc, n_bytes);
@@ -79,7 +84,7 @@ TEST(TestDeferredRelease, WithoutArena) {
   }
 
   ORT_THROW_IF_ERROR(stream.CleanUpOnRunEnd());
-  ORT_THROW_IF_ERROR(ep.OnRunEnd(true));
+  ORT_THROW_IF_ERROR(ep.OnRunEnd(true, run_opts));
 }
 
 }  // namespace test
diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc
index 0167f7a7718b1..2e073def5d643 100644
--- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc
+++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc
@@ -220,6 +220,7 @@ InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer&
   auto compile_capabilities = utils::CreateSupportedPartitions(graph_viewer, supported_compiled_nodes, stop_ops_,
                                                                generate_metadef_name, ep_name_,
                                                                onnxruntime::utils::kInternalTestingExecutionProvider,
+                                                               /*QDQ NodeUnit map*/ nullptr,
                                                                debug_output_);
 
   if (!static_capabilities.empty()) {
diff --git a/onnxruntime/test/providers/partitioning_utils_test.cc b/onnxruntime/test/providers/partitioning_utils_test.cc
new file mode 100644
index 0000000000000..5db69489afaef
--- /dev/null
+++ b/onnxruntime/test/providers/partitioning_utils_test.cc
@@ -0,0 +1,174 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "core/common/common.h"
+#include "core/graph/graph_viewer.h"
+#include "core/graph/model.h"
+#include "core/framework/node_unit.h"
+#include "core/framework/compute_capability.h"
+#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
+#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
+#include "core/providers/partitioning_utils.h"
+
+#include "test/optimizer/graph_transform_test_builder.h"
+#include "test/optimizer/qdq_test_utils.h"
+#include "test/util/include/asserts.h"
+#include "test/util/include/test_utils.h"
+#include "test/util/include/test/test_environment.h"
+
+namespace onnxruntime {
+namespace test {
+
+// Test handling of a DQ node that is connected to an initializer at the start of the graph, but not used
+// in a QDQ node group until after an unsupported node in the graph. If we do not process QDQ node units
+// correctly, this DQ will incorrectly be placed in the first partition, with the rest of the QDQ node group in
+// the second partition.
+TEST(PartitioningUtilsTest, TestQDQHandling) {
+  constexpr const ORTCHAR_T* model_uri = ORT_TSTR("testdata/ort_github_issue_19590.onnx");
+  auto& logger = DefaultLoggingManager().DefaultLogger();
+
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, logger));
+  Graph& graph = p_model->MainGraph();
+  GraphViewer graph_viewer = GraphViewer(graph);
+
+  // we want everything but the Cast in the test model to be supported
+  const auto is_node_supported = [&](const Node& node) -> bool {
+    return node.OpType() != "Cast";
+  };
+
+  const auto on_group_closed = [&](const std::vector<const Node*>& /*group*/) -> bool {
+    return true;
+  };
+
+  const auto gen_metadef_name = [&]() {
+    static int metadef_id = 0;
+    return "TestMetaDef_" + std::to_string(metadef_id++);
+  };
+
+  std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
+  std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
+  std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
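+  // Group Q/DQ nodes into NodeUnits so the partitioner can keep each QDQ node group together.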
+
+  auto result = utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed,
+                                                 gen_metadef_name, "TEST", kCpuExecutionProvider, &node_unit_map,
+                                                 true);
+
+  // we should have 2 supported partitions, split by the Cast node.
+  // the first should have the Mul and NOT the DQ for the initializer if everything worked correctly.
+  ASSERT_EQ(result.size(), size_t(2)) << "Expected 2 partitions";
+  ASSERT_EQ(result[0]->sub_graph->nodes.size(), size_t(1)) << "First partition should only have the Mul and not a DQ";
+  ASSERT_EQ(result[1]->sub_graph->nodes.size(), size_t(5));  // everything else except the unsupported Cast
+}
+
+/// Check that CreateSupportedPartitions processes all nodes without error.
+static void CheckAllNodesProcessed(const std::function<void(ModelTestBuilder&)>& build_model) {
+  auto& logger = DefaultLoggingManager().DefaultLogger();
+  const std::unordered_map<std::string, int> domain_to_version = {{"", 15}};
+
+  Model model("PartitioningUtils_TestModel", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
+              domain_to_version, {}, logger);
+
+  Graph& graph = model.MainGraph();
+  ModelTestBuilder helper(graph);
+  build_model(helper);
+  helper.SetGraphOutputs();
+  ASSERT_STATUS_OK(model.MainGraph().Resolve());
+
+  GraphViewer graph_viewer = GraphViewer(graph);
+
+  std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
+  std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
+  std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
+
+  const auto is_node_supported = [&](const Node& /*node*/) -> bool {
+    return true;
+  };
+
+  const auto on_group_closed = [&](const std::vector<const Node*>& /*group*/) -> bool {
+    return true;
+  };
+
+  const auto gen_metadef_name = [&]() {
+    static int metadef_id = 0;
+    return "TestMetaDef_" + std::to_string(metadef_id++);
+  };
+
+  auto result = utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed,
+                                                 gen_metadef_name, "TEST", kCpuExecutionProvider, &node_unit_map,
+                                                 true);
+
+  // the 'real' test is that CreateSupportedPartitions doesn't throw due to a mismatch between expected and processed nodes.
+  // As all ops are supported, there should only ever be 1 partition.
+  ASSERT_EQ(result.size(), size_t(1)) << "Expected 1 partition";
+}
+
+TEST(PartitioningUtilsTest, TestHandlingQDQNodeUnitWithNoQNodes) {
+  // build graph with QDQ node unit for logical operator (Equal) that has no Q node and a downstream node (Cast).
+  auto build_model = [](ModelTestBuilder& builder) {
+    constexpr uint8_t zero_point = 0;
+    constexpr float qdq_scale = 0.0038f;
+    const std::vector<int64_t> input_shape = {1, 3, 8, 8};
+
+    auto* input0 = builder.MakeInput<float>(input_shape, -1.0f, 1.0f);
+    auto* input1 = builder.MakeInput<float>(input_shape, -1.0f, 1.0f);
+    auto* output = builder.MakeOutput();
+
+    // input -> Q -> DQ -> Op
+    auto* qdq0_output = AddQDQNodePair<uint8_t>(builder, input0, qdq_scale, zero_point);
+    auto* qdq1_output = AddQDQNodePair<uint8_t>(builder, input1, qdq_scale, zero_point);
+
+    // Equal ->
+    auto* equal_output = builder.MakeIntermediate();
+    builder.AddNode("Equal", {qdq0_output, qdq1_output}, {equal_output});
+
+    // -> Cast -> output
+    Node& cast_node = builder.AddNode("Cast", {equal_output}, {output});
+    cast_node.AddAttribute("to",
+                           static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT));
+  };
+
+  CheckAllNodesProcessed(build_model);
+}
+
+// TopK produces 2 outputs, one of which is used in a QDQ node group (Q of values output)
+// and the other (indices output) is not. A downstream node consuming the indices output has an edge from the target
+// node and not a Q node.
+// To process this correctly, the QDQ NodeUnit must return output edges for both the Q node/s of the values output,
+// and the downstream node (Cast in this case) of the indices output.
+TEST(PartitioningUtilsTest, TestQDQNodeGroupWithOutputFromTargetNode) {
+  const auto build_model = [](ModelTestBuilder& builder) {
+    constexpr uint8_t zero_point = 0;
+    constexpr float qdq_scale = 0.0038f;
+    const std::vector<int64_t> input_shape = {1, 3, 8, 8};
+
+    auto* input0 = builder.MakeInput<float>(input_shape, -1.0f, 1.0f);
+
+    // input -> Q -> DQ ->
+    auto* qdq0_output = AddQDQNodePair<uint8_t>(builder, input0, qdq_scale, zero_point);
+
+    // K input
+    NodeArg* k_input = builder.MakeInput<int64_t>({1}, {10});
+
+    // TopK op
+    NodeArg* values_output = builder.MakeIntermediate();
+    NodeArg* indices_output = builder.MakeIntermediate();
+    builder.AddNode("TopK", {qdq0_output, k_input}, {values_output, indices_output});
+
+    // values -> Q -> DQ -> graph output
+    AddQDQNodePairWithOutputAsGraphOutput<uint8_t>(builder, values_output, qdq_scale, zero_point);
+
+    // indices -> Cast -> graph output
+    auto* i_output = builder.MakeOutput();
+    Node& cast_node = builder.AddNode("Cast", {indices_output}, {i_output});
+    const auto dst_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32;
+    cast_node.AddAttribute("to", static_cast<int64_t>(dst_type));
+  };
+
+  CheckAllNodesProcessed(build_model);
+}
+}  // namespace test
+}  // namespace onnxruntime
diff --git a/onnxruntime/test/providers/qnn/clip_op_test.cc b/onnxruntime/test/providers/qnn/clip_op_test.cc
index 15ba3b5de2fa1..e899f870f9e78 100644
--- a/onnxruntime/test/providers/qnn/clip_op_test.cc
+++ b/onnxruntime/test/providers/qnn/clip_op_test.cc
@@ -182,6 +182,44 @@ TEST_F(QnnHTPBackendTests, Clip_U8_Rank5) {
                   ExpectedEPNodeAssignment::All);
 }
 
+// Test FP16 Clip with min (FP16)
+TEST_F(QnnHTPBackendTests, Clip_FP16) {
+  ProviderOptions provider_options;
+
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+
+  auto f32_input = TestInputDef<float>({1, 3, 2, 2}, false,
+                                       {-10.0f, -8.0f, -3.5f, 2.2f,
+                                        1.3f, 1.5f, 3.2f, 5.8f,
+                                        5.8f, 9.7f, 8.5f, 8.9f});
+  std::vector<MLFloat16> f16_data;
+  std::for_each(f32_input.GetRawData().begin(), f32_input.GetRawData().end(),
+                [&f16_data](const float data) {
+                  f16_data.push_back(static_cast<MLFloat16>(data));
+                });
+  auto f16_input = TestInputDef<MLFloat16>({1, 3, 2, 2}, false, f16_data);
+
+  const float min_f32 = 1.2f;
+  const MLFloat16 min_f16 = static_cast<MLFloat16>(min_f32);
+  auto f32_min_input = TestInputDef<float>({}, true, {min_f32});
+  auto f16_min_input = TestInputDef<MLFloat16>({}, true, {min_f16});
+
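+  // Build equivalent fp32 and fp16 Clip models over the same data; the accuracy helper below runs
+  // both and compares the fp16 model's output against the fp32 model's output.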
+  auto f32_model_builder = BuildOpTestCase<float, float>("Clip", {f32_input}, {f32_min_input}, {});
+  auto f16_model_builder = BuildOpTestCase<MLFloat16, MLFloat16>("Clip", {f16_input}, {f16_min_input}, {});
+  int opset = 13;
+  ExpectedEPNodeAssignment expected_ep_assignment = ExpectedEPNodeAssignment::All;
+
+  TestFp16ModelAccuracy(f32_model_builder,
+                        f16_model_builder,
+                        provider_options,
+                        opset,
+                        expected_ep_assignment);
+}
+
 #endif  // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc
index 4e1aef2c40b2b..4f294f899c170 100644
--- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc
+++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc
@@ -7,6 +7,7 @@
 
 #include "core/session/onnxruntime_cxx_api.h"
 #include "core/session/onnxruntime_session_options_config_keys.h"
+#include "core/session/onnxruntime_run_options_config_keys.h"
 #include "core/providers/cpu/cpu_provider_factory.h"  // For OrtSessionOptionsAppendExecutionProvider_CPU
 #include "core/session/inference_session.h"
 
@@ -332,19 +333,23 @@ static void CreateModelInMemory(std::unique_ptr<ModelAndBuilder>& result,
 static void RunSessionAndVerify(InferenceSession& session, const RunOptions& run_options, const NameMLValMap& feeds,
                                 const std::vector<std::string>& output_names,
                                 const std::vector<std::vector<int64_t>>& output_shapes,
-                                const std::vector<std::vector<float>>& expected_values) {
-  std::vector<OrtValue> fetches;
-  auto status = session.Run(run_options, feeds, output_names, &fetches);
-  ASSERT_TRUE(status.IsOK());
-
-  for (size_t i = 0; i < fetches.size(); i++) {
-    auto& tensor = fetches[i].Get<Tensor>();
-    TensorShape expected_shape(output_shapes[i]);
-    ASSERT_EQ(expected_shape, tensor.Shape());
-
-    gsl::span<const float> actual = tensor.DataAsSpan<float>();
-    gsl::span<const float> expected(expected_values[i].data(), expected_values[i].size());
-    ASSERT_EQ(expected, actual);
+                                const std::vector<std::vector<float>>& expected_values,
+                                int loop_count = 10) {
+  // Run the session repeatedly so the multithreaded tests have overlapping in-flight runs
+  for (int it = 0; it < loop_count; ++it) {
+    std::vector<OrtValue> fetches;
+    auto status = session.Run(run_options, feeds, output_names, &fetches);
+    ASSERT_TRUE(status.IsOK());
+
+    for (size_t i = 0; i < fetches.size(); i++) {
+      auto& tensor = fetches[i].Get<Tensor>();
+      TensorShape expected_shape(output_shapes[i]);
+      ASSERT_EQ(expected_shape, tensor.Shape());
+
+      gsl::span<const float> actual = tensor.DataAsSpan<float>();
+      gsl::span<const float> expected(expected_values[i].data(), expected_values[i].size());
+      ASSERT_EQ(expected, actual);
+    }
   }
 }
 
@@ -404,11 +409,11 @@ TEST_F(QnnCPUBackendTests, MultithreadSessionRun) {
 
   std::vector<std::thread> threads;
   constexpr int num_threads = 5;
-
+  constexpr int loop_count = 10;
   for (int i = 0; i < num_threads; i++) {
     threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts,
                                   model->builder.feeds_, model->builder.output_names_,
-                                  output_shapes, output_values));
+                                  output_shapes, output_values, loop_count));
   }
 
   for (auto& th : threads) {
@@ -484,11 +489,191 @@ TEST_F(QnnHTPBackendTests, MultithreadSessionRun) {
 
   std::vector<std::thread> threads;
   constexpr int num_threads = 5;
+  constexpr int loop_count = 10;
 
   for (int i = 0; i < num_threads; i++) {
     threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts,
                                   model->builder.feeds_, model->builder.output_names_,
-                                  output_shapes, output_values));
+                                  output_shapes, output_values, loop_count));
+  }
+
+  for (auto& th : threads) {
+    th.join();
+  }
+}
+
+// Tests running a single session in multiple threads on the HTP backend with run option to set power config
+TEST_F(QnnHTPBackendTests, MultithreadHtpPowerCfgSessionRunOption) {
+  std::unique_ptr<ModelAndBuilder> model;
+  std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  std::vector<int64_t> shape = {1, 3, 2};
+  std::vector<std::vector<int64_t>> output_shapes = {shape};
+  std::vector<std::vector<float>> output_values = {{3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}};
+
+  CreateModelInMemory(model,
+                      QDQBuildAdd3Tensors<uint8_t>(TestInputDef<float>(shape, false, input_data),
+                                                   TestInputDef<float>(shape, false, input_data),
+                                                   TestInputDef<float>(shape, false, input_data)),
+                      "add3.qdq");
+
+  SessionOptions session_opts;
+  session_opts.session_logid = "logger0";
+
+  InferenceSession session_obj{session_opts, GetEnvironment()};
+  onnxruntime::ProviderOptions options;
+
+#if defined(_WIN32)
+  options["backend_path"] = "QnnHtp.dll";
+#else
+  options["backend_path"] = "libQnnHtp.so";
+#endif
+
+  auto qnn_ep = QnnExecutionProviderWithOptions(options, &session_opts);
+  EXPECT_TRUE(session_obj.RegisterExecutionProvider(std::move(qnn_ep)).IsOK());
+
+  auto status = session_obj.Load(model->model_data.data(), static_cast<int>(model->model_data.size()));
+  ASSERT_TRUE(status.IsOK());
+  status = session_obj.Initialize();
+  ASSERT_TRUE(status.IsOK());
+
+  std::vector<std::thread> threads;
+  constexpr int num_threads = 5;
+  constexpr int loop_count = 10;
+
+  std::vector<std::string> perf_modes{
+      "burst", "balanced", "default", "high_performance", "high_power_saver",
+      "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver"};
+
+  size_t post_i = perf_modes.size() - 1;
+  ASSERT_TRUE(post_i > num_threads);
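+  // Each thread sets a distinct perf mode for its runs and another mode to apply after the run
+  // (via kOrtRunOptionsConfigQnnPerfModePostRun).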
+  for (int i = 0; i < num_threads; ++i, --post_i) {
+    RunOptions run_opts;
+    run_opts.run_tag = session_opts.session_logid;
+    auto rt = run_opts.config_options.AddConfigEntry(kOrtRunOptionsConfigQnnPerfMode, perf_modes[i].c_str());
+    ASSERT_TRUE(rt.IsOK());
+    rt = run_opts.config_options.AddConfigEntry(kOrtRunOptionsConfigQnnPerfModePostRun, perf_modes[post_i].c_str());
+    ASSERT_TRUE(rt.IsOK());
+
+    threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts,
+                                  model->builder.feeds_, model->builder.output_names_,
+                                  output_shapes, output_values, loop_count));
+  }
+
+  for (auto& th : threads) {
+    th.join();
+  }
+}
+
+// Tests running a single session in multiple threads on the HTP backend with EP option to set default power config
+TEST_F(QnnHTPBackendTests, MultithreadDefaultHtpPowerCfgFromEpOption) {
+  std::unique_ptr<ModelAndBuilder> model;
+  std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  std::vector<int64_t> shape = {1, 3, 2};
+  std::vector<std::vector<int64_t>> output_shapes = {shape};
+  std::vector<std::vector<float>> output_values = {{3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}};
+
+  CreateModelInMemory(model,
+                      QDQBuildAdd3Tensors<uint8_t>(TestInputDef<float>(shape, false, input_data),
+                                                   TestInputDef<float>(shape, false, input_data),
+                                                   TestInputDef<float>(shape, false, input_data)),
+                      "add3.qdq");
+
+  SessionOptions session_opts;
+  session_opts.session_logid = "logger0";
+
+  RunOptions run_opts;
+  run_opts.run_tag = session_opts.session_logid;
+
+  InferenceSession session_obj{session_opts, GetEnvironment()};
+  onnxruntime::ProviderOptions options;
+
+#if defined(_WIN32)
+  options["backend_path"] = "QnnHtp.dll";
+#else
+  options["backend_path"] = "libQnnHtp.so";
+#endif
+  options["htp_performance_mode"] = "burst";
+
+  auto qnn_ep = QnnExecutionProviderWithOptions(options, &session_opts);
+  EXPECT_TRUE(session_obj.RegisterExecutionProvider(std::move(qnn_ep)).IsOK());
+
+  auto status = session_obj.Load(model->model_data.data(), static_cast<int>(model->model_data.size()));
+  ASSERT_TRUE(status.IsOK());
+  status = session_obj.Initialize();
+  ASSERT_TRUE(status.IsOK());
+
+  std::vector<std::thread> threads;
+  constexpr int num_threads = 5;
+  constexpr int loop_count = 10;
+
+  for (int i = 0; i < num_threads; i++) {
+    threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts,
+                                  model->builder.feeds_, model->builder.output_names_,
+                                  output_shapes, output_values, loop_count));
+  }
+
+  for (auto& th : threads) {
+    th.join();
+  }
+}
+
+// Tests running a single session in multiple threads on the HTP backend with an
+// EP option that sets the default power config plus a run option that sets the power config for each run
+TEST_F(QnnHTPBackendTests, MultithreadHtpPowerCfgDefaultAndRunOption) {
+  std::unique_ptr<ModelAndBuilder> model;
+  std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  std::vector<int64_t> shape = {1, 3, 2};
+  std::vector<std::vector<int64_t>> output_shapes = {shape};
+  std::vector<std::vector<float>> output_values = {{3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}};
+
+  CreateModelInMemory(model,
+                      QDQBuildAdd3Tensors<uint8_t>(TestInputDef<float>(shape, false, input_data),
+                                                   TestInputDef<float>(shape, false, input_data),
+                                                   TestInputDef<float>(shape, false, input_data)),
+                      "add3.qdq");
+
+  SessionOptions session_opts;
+  session_opts.session_logid = "logger0";
+
+  InferenceSession session_obj{session_opts, GetEnvironment()};
+  onnxruntime::ProviderOptions options;
+
+#if defined(_WIN32)
+  options["backend_path"] = "QnnHtp.dll";
+#else
+  options["backend_path"] = "libQnnHtp.so";
+#endif
+  options["htp_performance_mode"] = "burst";
+
+  auto qnn_ep = QnnExecutionProviderWithOptions(options, &session_opts);
+  EXPECT_TRUE(session_obj.RegisterExecutionProvider(std::move(qnn_ep)).IsOK());
+
+  auto status = session_obj.Load(model->model_data.data(), static_cast<int>(model->model_data.size()));
+  ASSERT_TRUE(status.IsOK());
+  status = session_obj.Initialize();
+  ASSERT_TRUE(status.IsOK());
+
+  std::vector<std::thread> threads;
+  constexpr int num_threads = 5;
+  constexpr int loop_count = 10;
+
+  std::vector<std::string> perf_modes{
+      "burst", "balanced", "default", "high_performance", "high_power_saver",
+      "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver"};
+
+  size_t post_i = perf_modes.size() - 1;
+  ASSERT_TRUE(post_i > num_threads);
+  for (int i = 0; i < num_threads; ++i, --post_i) {
+    RunOptions run_opts;
+    run_opts.run_tag = session_opts.session_logid;
+    auto rt = run_opts.config_options.AddConfigEntry(kOrtRunOptionsConfigQnnPerfMode, perf_modes[i].c_str());
+    ASSERT_TRUE(rt.IsOK());
+    rt = run_opts.config_options.AddConfigEntry(kOrtRunOptionsConfigQnnPerfModePostRun, perf_modes[post_i].c_str());
+    ASSERT_TRUE(rt.IsOK());
+
+    threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts,
+                                  model->builder.feeds_, model->builder.output_names_,
+                                  output_shapes, output_values, loop_count));
   }
 
   for (auto& th : threads) {
@@ -630,6 +815,25 @@ TEST_F(QnnHTPBackendTests, DISABLED_CastAddHTPAccuracyTest) {
                   ExpectedEPNodeAssignment::All);
 }
 
+// Test float32 model with FP16 precision
+TEST_F(QnnHTPBackendTests, Float32ModelWithFP16PrecisionTest) {
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+  provider_options["enable_htp_fp16_precision"] = "1";
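+  // With enable_htp_fp16_precision, QNN EP is expected to run the float32 graph using fp16 math on HTP
+  // (assumption based on the option name), hence the looser tolerance (0.008) passed below.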
+
+  auto input_defs = {TestInputDef<float>({1, 2, 2, 2}, false, -10.0f, 10.0f),
+                     TestInputDef<float>({1, 2, 2, 2}, false, -10.0f, 10.0f)};
+  RunQnnModelTest(BuildOpTestCase<float>("Add", input_defs, {}, {}, kOnnxDomain),
+                  provider_options,
+                  13,
+                  ExpectedEPNodeAssignment::All,
+                  0.008f);
+}
+
 #endif  // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
 #endif  // !defined(ORT_MINIMAL_BUILD)
 
diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
index b1f3b52e77553..9eb75d297ef78 100644
--- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
+++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
@@ -123,6 +123,8 @@ void QnnContextBinaryMultiPartitionTestBody(bool single_ep_node = true) {
   for (auto& node : ctx_graph.Nodes()) {
     if (node.OpType() == "EPContext") {
       ++ep_context_node_count;
+      // Validate the fix for the partition issue related to QDQ models
+      ASSERT_EQ(node.InputDefs().size(), 1);
     } else {
       ++non_ep_context_node_count;
     }
@@ -463,7 +465,6 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_InvalidGraph) {
 
   InferenceSessionWrapper session_object{so, GetEnvironment()};
 
-  std::string provider_type = kCpuExecutionProvider;
   ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
   ASSERT_STATUS_OK(session_object.Load(qnn_ctx_model_data.data(), static_cast<int>(qnn_ctx_model_data.size())));
   // Verify the return status with code INVALID_GRAPH
@@ -486,7 +487,6 @@ std::string CreateQnnCtxModelWithNonEmbedMode(std::string external_bin_path) {
   auto* graph_output = helper.MakeOutput<float>(shape);
   Node& ep_context_node = helper.AddNode("EPContext", {graph_input}, {graph_output}, kMSDomain);
   ep_context_node.AddAttribute("embed_mode", static_cast<int64_t>(0));
-  // The .. in the path will cause INVALID_GRAPH
   ep_context_node.AddAttribute("ep_cache_context", external_bin_path);
   ep_context_node.AddAttribute("partition_name", "QNNExecutionProvider_QNN_1110111000111000111_1_0");
   ep_context_node.AddAttribute("source", "QNN");
@@ -651,6 +651,87 @@ TEST_F(QnnHTPBackendTests, QnnContextBinary2InputsTest) {
   ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
 }
 
+// The context binary contains a single QNN graph, so the generated context cache model (non-embed mode)
+// has only 1 EPContext node. Create another Onnx model that also references the bin file,
+// but whose node name does not match the QNN graph name inside the bin file.
+// This preserves backward compatibility for models generated before the PR that
+// made context generation support multi-partition.
+TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_SingleNodeNameNotMatchGraphNameInCtx) {
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+  const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx";
+  std::filesystem::path context_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin";
+  std::remove(context_binary_file.c_str());
+  std::remove(context_bin.string().c_str());
+
+  std::unordered_map<std::string, std::string> session_option_pairs;
+  session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1");
+  session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file);
+  session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0");
+
+  const TestInputDef<float> input_def({1, 2, 3}, false, -10.0f, 10.0f);
+  const std::string op_type = "Atan";
+
+  // Runs the model with DQ -> Atan -> Q and compares the outputs of the CPU and QNN EPs.
+  // The 1st run generates the Onnx skeleton file + QNN context cache binary file.
+  TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
+                       BuildQDQOpTestCase<uint8_t>(op_type, {input_def}, {}, {}),
+                       provider_options,
+                       14,
+                       ExpectedEPNodeAssignment::All,
+                       QDQTolerance(),
+                       logging::Severity::kERROR,
+                       "",  // context model file path, not required for this inference
+                       session_option_pairs);
+
+  // Check the Onnx skeleton file is generated
+  EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
+  // Check the Qnn context cache binary file is generated
+  EXPECT_TRUE(std::filesystem::exists(context_bin));
+
+  const std::unordered_map<std::string, int> domain_to_version = {{"", 11}, {kMSDomain, 1}};
+  auto& logging_manager = DefaultLoggingManager();
+  onnxruntime::Model model("QNN_ctx_model", false, ModelMetaData(), PathString(),
+                           IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
+                           logging_manager.DefaultLogger());
+  Graph& graph = model.MainGraph();
+  ModelTestBuilder helper(graph);
+  std::vector<int64_t> shape = {1, 2, 3};
+  NodeArg* graph_input = MakeTestInput(helper, TestInputDef<float>(shape, false, {0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f}));
+  auto* graph_output = helper.MakeOutput<float>(shape);
+  Node& ep_context_node = helper.AddNode("EPContext", {graph_input}, {graph_output}, kMSDomain);
+  ep_context_node.AddAttribute("embed_mode", static_cast<int64_t>(0));
+  ep_context_node.AddAttribute("ep_cache_context", context_bin.string());
+  ep_context_node.AddAttribute("partition_name", "QNNExecutionProvider_QNN_1110111000111000111_1_0");
+  ep_context_node.AddAttribute("source", "QNNExecutionProvider");
+  helper.SetGraphOutputs();
+  ASSERT_STATUS_OK(graph.Resolve());
+  std::string model_data;
+  model.ToProto().SerializeToString(&model_data);
+
+  // Load and run from the Onnx skeleton file + QNN context cache binary file
+
+  SessionOptions so;
+  so.session_logid = "qnn_ctx_model_logger";
+  RunOptions run_options;
+  run_options.run_tag = so.session_logid;
+
+  InferenceSessionWrapper session_object{so, GetEnvironment()};
+
+  ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options)));
+  ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast<int>(model_data.size())));
+  // Verify initialization succeeds even though the EPContext node name does not match the QNN graph name in the context binary
+  ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::OK);
+
+  // Clean up
+  ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
+  ASSERT_EQ(std::remove(context_bin.string().c_str()), 0);
+}
+
 #endif  // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
 
 }  // namespace test
diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h
index f4febd99ddae7..c0cfe3f0342fd 100644
--- a/onnxruntime/test/providers/qnn/qnn_test_utils.h
+++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h
@@ -467,6 +467,187 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe
   }
 }
 
+/**
+ * Tests the accuracy of an FP16 model on QNN EP by running 3 inferences:
+ *
+ * 1. float32 model on CPU EP (baseline)
+ * 2. FP16 model on CPU EP
+ * 3. FP16 model on QNN EP
+ *
+ * This function checks that running the FP16 model on QNN EP (#3) is at least as accurate (+- small tolerance)
+ * as running the FP16 model on CPU EP (#2). Accuracy is measured against the float32 baseline (#1).
+ *
+ * \param f32_model_fn Function that builds the float model (baseline for comparison).
+ * \param f16_model_fn Function that builds the FP16 model (run by CPU EP and QNN EP).
+ * \param qnn_options QNN EP provider options.
+ * \param opset_version The opset version.
+ * \param expected_ep_assignment Describes "which nodes" should be assigned to the EP.
+ * \param tolerance The tolerance (as a fraction) by which the relative error of the FP16 model on QNN EP may
+ *                  differ from the relative error of the FP16 model on CPU EP (both measured against #1).
+ * \param log_severity The logger's severity setting.
+ */
+inline void TestFp16ModelAccuracy(const GetTestModelFn& f32_model_fn,
+                                  const GetTestModelFn& f16_model_fn,
+                                  ProviderOptions qnn_options,
+                                  int opset_version,
+                                  ExpectedEPNodeAssignment expected_ep_assignment,
+                                  float tolerance = 0.004f,
+                                  logging::Severity log_severity = logging::Severity::kERROR,
+                                  const std::string& qnn_ctx_model_path = "",
+                                  const std::unordered_map<std::string, std::string>& session_option_pairs = {}) {
+  // Add kMSDomain to cover contrib ops like Gelu
+  const std::unordered_map<std::string, int> domain_to_version = {{"", opset_version}, {kMSDomain, 1}};
+
+  auto& logging_manager = DefaultLoggingManager();
+  logging_manager.SetDefaultLoggerSeverity(log_severity);
+
+  // Create float model and serialize it to a string.
+  onnxruntime::Model f32_model("f32_model", false, ModelMetaData(), PathString(),
+                               IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
+                               logging_manager.DefaultLogger());
+  ModelTestBuilder f32_helper(f32_model.MainGraph());
+  std::string f32_model_data;
+  f32_model_fn(f32_helper);
+  f32_helper.SetGraphOutputs();
+  ASSERT_STATUS_OK(f32_model.MainGraph().Resolve());
+  f32_model.ToProto().SerializeToString(&f32_model_data);
+
+  // Run f32 model on CPU EP and collect outputs.
+  std::vector<OrtValue> cpu_f32_outputs;
+  InferenceModel(f32_model_data, "f32_model_logger", {}, ExpectedEPNodeAssignment::All,
+                 f32_helper.feeds_, cpu_f32_outputs);
+  ASSERT_FALSE(cpu_f32_outputs.empty());
+
+  const size_t num_outputs = cpu_f32_outputs.size();
+
+  // Compute output range(s) and quantization params.
+  std::vector<gsl::span<const float>> output_vals;
+  std::vector<int32_t> output_types;
+  output_vals.resize(num_outputs);
+  output_types.resize(num_outputs);
+
+  for (size_t i = 0; i < num_outputs; i++) {
+    auto& tensor = cpu_f32_outputs[i].Get<Tensor>();
+    int32_t elem_type = tensor.GetElementType();
+
+    if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
+      output_vals[i] = tensor.DataAsSpan<float>();
+    }
+
+    output_types[i] = elem_type;
+  }
+
+  // Create FP16 model and serialize it to a string.
+  onnxruntime::Model f16_model("fp16_model", false, ModelMetaData(), PathString(),
+                               IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
+                               logging_manager.DefaultLogger());
+  ModelTestBuilder f16_helper(f16_model.MainGraph());
+  std::string f16_model_data;
+  f16_model_fn(f16_helper);
+  f16_helper.SetGraphOutputs();
+  ASSERT_STATUS_OK(f16_model.MainGraph().Resolve());
+  f16_model.ToProto().SerializeToString(&f16_model_data);
+
+  bool is_qnn_ep = true;
+  TryEnableQNNSaver(qnn_options);
+  std::vector<OrtValue> qnn_f16_outputs;
+  if (!qnn_ctx_model_path.empty()) {
+    onnx::ModelProto model_proto;
+    onnxruntime::Model qnn_ctx_model;
+    // Load the QNN context cache model from path specified
+    ASSERT_STATUS_OK(qnn_ctx_model.Load(ToPathString(qnn_ctx_model_path), model_proto));
+    std::string qnn_ctx_model_data;
+    model_proto.SerializeToString(&qnn_ctx_model_data);
+    // Run QNN context cache model on QNN EP and collect outputs.
+    InferenceModel(qnn_ctx_model_data, "qnn_ctx_model_logger", qnn_options,
+                   expected_ep_assignment, f16_helper.feeds_, qnn_f16_outputs, is_qnn_ep, session_option_pairs);
+  } else {
+    // Run FP16 model on QNN EP and collect outputs.
+    // Only need to apply the extra session options to this FP16 model inference on QNN EP.
+    InferenceModel(f16_model_data, "fp16_model_logger", qnn_options, expected_ep_assignment,
+                   f16_helper.feeds_, qnn_f16_outputs, is_qnn_ep, session_option_pairs);
+  }
+
+  if (expected_ep_assignment != ExpectedEPNodeAssignment::None) {
+    // Run FP16 model on CPU EP and collect outputs.
+    std::vector<OrtValue> cpu_f16_outputs;
+    InferenceModel(f16_model_data, "fp16_model_logger", {}, ExpectedEPNodeAssignment::All,
+                   f16_helper.feeds_, cpu_f16_outputs);
+    ASSERT_EQ(cpu_f16_outputs.size(), num_outputs);
+    ASSERT_EQ(qnn_f16_outputs.size(), num_outputs);
+
+    // Limit the error message count in case a test with large data fails.
+    size_t max_error_count = 10;
+    size_t error_count = 0;
+
+    // Compare accuracy of the FP16 results with the float32 model.
+    // QNN EP must be at least as accurate as CPU EP when running the FP16 model.
+    const std::string base_output_name = "output_";
+    for (size_t i = 0; i < num_outputs; i++) {
+      std::string debug_output_name = base_output_name + std::to_string(i);
+      auto& cpu_f16_tensor = cpu_f16_outputs[i].Get<Tensor>();
+      auto& qnn_f16_tensor = qnn_f16_outputs[i].Get<Tensor>();
+
+      ASSERT_EQ(cpu_f16_tensor.GetElementType(), ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
+      ASSERT_EQ(qnn_f16_tensor.GetElementType(), ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
+      ASSERT_EQ(output_types[i], ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+
+      const size_t num_vals = output_vals[i].size();
+      gsl::span<const float> cpu_f32_vals = output_vals[i];
+      gsl::span<const MLFloat16> cpu_f16_vals = cpu_f16_tensor.DataAsSpan<MLFloat16>();
+      gsl::span<const MLFloat16> qnn_f16_vals = qnn_f16_tensor.DataAsSpan<MLFloat16>();
+
+      ASSERT_EQ(num_vals, cpu_f16_vals.size());
+      ASSERT_EQ(num_vals, qnn_f16_vals.size());
+
+      float max_f16_cpu_err = 0.0f;
+      float max_f16_qnn_err = 0.0f;
+
+      for (size_t j = 0; j < num_vals && error_count < max_error_count; j++) {
+        const float expected_val = cpu_f32_vals[j];           // f32@CPU_EP val ("ground-truth")
+        const float qnn_f16_val = qnn_f16_vals[j].ToFloat();  // f16@QNN_EP val
+        const float cpu_f16_val = cpu_f16_vals[j].ToFloat();  // f16@CPU_EP val
+
+        // Get errors of f16@CPU_EP and f16@QNN_EP against f32@CPU_EP.
+        const float cpu_relative_err = std::fabs(expected_val - cpu_f16_val) / expected_val;
+        const float qnn_relative_err = std::fabs(expected_val - qnn_f16_val) / expected_val;
+
+        // Also compare the two FP16 results against each other via the difference
+        // of their relative errors w.r.t. the f32@CPU_EP value.
+        const float f16_vals_err = std::fabs(qnn_relative_err - cpu_relative_err);
+
+        // True if f16@QNN_EP is at least as accurate as f16@CPU_EP when compared to expected f32@CPU_EP value.
+        const bool is_as_accurate_as_cpu_ep = qnn_relative_err <= cpu_relative_err;
+
+        // True if the normalized difference between f16@QNN_EP and f16@CPU_EP is within tolerance.
+        const bool f16_vals_diff_within_tolerance = f16_vals_err <= tolerance;
+
+        const bool passed_test = is_as_accurate_as_cpu_ep || f16_vals_diff_within_tolerance;
+        if (!passed_test) {
+          ++error_count;
+        }
+        EXPECT_TRUE(passed_test)
+            << "Inaccuracy detected for output '" << debug_output_name
+            << "', element " << j << ", tolerance=" << (tolerance * 100) << "%"
+            << ".\nExpected val (f32@CPU_EP): " << expected_val << "\n"
+            << "f16@QNN_EP val: " << qnn_f16_val << " (err: " << qnn_relative_err << ")\n"
+            << "f16@CPU_EP val: " << cpu_f16_val << " (err: " << cpu_relative_err << ")\n";
+
+        max_f16_cpu_err = std::max(max_f16_cpu_err, cpu_relative_err);
+        max_f16_qnn_err = std::max(max_f16_qnn_err, qnn_relative_err);
+      }
+
+      if (error_count > 0) {
+        std::cerr << std::endl
+                  << "[WARNING]: Output " << i
+                  << " required larger tolerance to pass accuracy checks" << std::endl
+                  << "Max relative error of f16@CPU_EP vs f32@CPU_EP = " << max_f16_cpu_err << std::endl
+                  << "Max relative error of f16@QNN_EP vs f32@CPU_EP = " << max_f16_qnn_err << std::endl;
+      }
+    }
+  }
+}
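+
+// Rough illustration of the accuracy check above with made-up numbers (not taken from any test):
+//   expected (f32@CPU_EP) = 2.0, f16@CPU_EP = 1.998, f16@QNN_EP = 1.994
+//   cpu_relative_err = |2.0 - 1.998| / 2.0 = 0.001
+//   qnn_relative_err = |2.0 - 1.994| / 2.0 = 0.003
+//   f16_vals_err     = |0.003 - 0.001|     = 0.002
+// QNN EP is less accurate than CPU EP here, but the 0.002 difference is within the default
+// tolerance of 0.004, so the element still passes.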
+
 /**
  * Creates and returns an input in a test model graph. The input's characteristics are defined
  * by the provided input definition.
diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc
index 2f3b0e84a123e..a6422407d79fd 100644
--- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc
+++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc
@@ -1110,6 +1110,61 @@ TEST_F(QnnHTPBackendTests, LpNormalization_u16_rank4) {
                          kOnnxDomain,
                          true);
 }
+
+static GetTestQDQModelFn<uint16_t> BuildQDQConvertAddTestCase(const TestInputDef<float>& input0_def,
+                                                              const TestInputDef<float>& input1_def) {
+  return [input0_def, input1_def](ModelTestBuilder& builder, std::vector<QuantParams<uint16_t>>& output_qparams) {
+    constexpr bool use_contrib_qdq = true;
+
+    // Input0 -> Quantize(u8) -> Dequantize(u8 to float) -> input0_after_qdq
+    NodeArg* input0 = MakeTestInput<float>(builder, input0_def);
+    QuantParams<uint8_t> input0_u8_qparams = GetTestInputQuantParams<uint8_t>(input0_def);
+    NodeArg* input0_after_qdq = AddQDQNodePair<uint8_t>(builder, input0, input0_u8_qparams.scale,
+                                                        input0_u8_qparams.zero_point, use_contrib_qdq);
+
+    // input0_after_qdq -> Quantize(u16) -> Dequantize(u16 to float)
+    QuantParams<uint16_t> input0_u16_qparams = GetTestInputQuantParams<uint16_t>(input0_def);
+    NodeArg* input0_after_convert = AddQDQNodePair<uint16_t>(builder, input0_after_qdq, input0_u16_qparams.scale,
+                                                             input0_u16_qparams.zero_point, use_contrib_qdq);
+
+    // Input1 -> Quantize(u16) -> Dequantize(u16 to float) -> input1_after_qdq
+    NodeArg* input1 = MakeTestInput<float>(builder, input1_def);
+    QuantParams<uint16_t> input1_qparams = GetTestInputQuantParams<uint16_t>(input1_def);
+    NodeArg* input1_after_qdq = AddQDQNodePair<uint16_t>(builder, input1, input1_qparams.scale,
+                                                         input1_qparams.zero_point, use_contrib_qdq);
+
+    // Add op -> op_output
+    auto* op_output = builder.MakeIntermediate();
+    builder.AddNode("Add", {input0_after_convert, input1_after_qdq}, {op_output});
+
+    // op_output -> Q -> DQ -> output
+    AddQDQNodePairWithOutputAsGraphOutput<uint16_t>(builder, op_output, output_qparams[0].scale,
+                                                    output_qparams[0].zero_point, use_contrib_qdq);
+  };
+}
+
+// Test quantization type conversion (mixed precision) with Add.
+// First input is converted from uint8_t to uint16_t.
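+// Note (illustrative interpretation of the builder above): the DQ(u8) -> Q(u16) pair around the first input
+// models a quantization type conversion, which QNN EP is expected to handle as a Convert between the u8 and
+// u16 quantization domains.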
+TEST_F(QnnHTPBackendTests, Add_U8_U16_Convert) {
+  std::vector<float> input0_data = GetFloatDataInRange(-10.0f, 10.0f, 8);
+  std::vector<float> input1_data = GetFloatDataInRange(-20.0f, 20.0f, 8);
+  TestInputDef<float> input0_def({1, 2, 2, 2}, false, input0_data);
+  TestInputDef<float> input1_def({1, 2, 2, 2}, false, input1_data);
+
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+
+  TestQDQModelAccuracy(BuildOpTestCase<float>("Add", {input0_def, input1_def}, {}, {}, kOnnxDomain),
+                       BuildQDQConvertAddTestCase(input0_def, input1_def),
+                       provider_options,
+                       18,
+                       ExpectedEPNodeAssignment::All);
+}
+
 #endif  // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
 
 }  // namespace test
diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py
index c48b07422d452..e441230537410 100644
--- a/onnxruntime/test/python/onnx_backend_test_series.py
+++ b/onnxruntime/test/python/onnx_backend_test_series.py
@@ -140,8 +140,8 @@ def create_backend_test(test_name=None):
         if backend.supports_device("OPENVINO_CPU_FP16"):
             current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_CPU_FP16")
 
-        if backend.supports_device("OPENVINO_NPU_FP16"):
-            current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_NPU_FP16")
+        if backend.supports_device("OPENVINO_NPU"):
+            current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_NPU")
 
         if backend.supports_device("OPENVINO"):
             current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_opset18")
diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py
index 91b6c71e735a8..ab56f3fa0f37f 100644
--- a/onnxruntime/test/python/onnxruntime_test_python.py
+++ b/onnxruntime/test/python/onnxruntime_test_python.py
@@ -559,6 +559,16 @@ def test_get_and_set_option_with_values(option_name, option_values):
 
                 test_get_and_set_option_with_values("enable_hip_graph", ["1", "0"])
 
+                # test for user_compute_stream
+                option = options["ROCMExecutionProvider"]
+                option["user_compute_stream"] = "1"
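+                # "1" is only a placeholder stream address for this test; a real application would pass
+                # the integer address of an actual stream (e.g. a stream handle obtained from the GPU runtime).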
+                sess.set_providers(["ROCMExecutionProvider"], [option])
+                new_options = sess.get_provider_options()
+                new_option = new_options["ROCMExecutionProvider"]
+                self.assertEqual(new_option["user_compute_stream"], "1")
+                # Setting user_compute_stream also sets has_user_compute_stream to 1
+                self.assertEqual(new_option["has_user_compute_stream"], "1")
+
             run_rocm_options_test()
 
     def test_invalid_set_providers(self):
diff --git a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py
index c4e13e773535d..ce04dff2aecb0 100644
--- a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py
+++ b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py
@@ -84,6 +84,7 @@ def test_select_ep_to_run_cuda_graph(self):
         elif "CUDAExecutionProvider" in onnxrt.get_available_providers():
             providers = [("CUDAExecutionProvider", {"enable_cuda_graph": True})]
             self.run_model_with_cuda_graph(providers)
+            self.run_model_with_cuda_graph_annotation(providers)
 
     def run_model_with_cuda_graph(self, providers):
         INPUT_SIZE = 1280  # noqa: N806
@@ -100,13 +101,15 @@ def run_model_with_cuda_graph(self, providers):
         io_binding.bind_ortvalue_input("X", x_ortvalue)
         io_binding.bind_ortvalue_output("Y", y_ortvalue)
 
+        ro = onnxrt.RunOptions()
+
         # One regular run for the necessary memory allocation and cuda graph capturing
-        session.run_with_iobinding(io_binding)
+        session.run_with_iobinding(io_binding, ro)
         expected_y = np.array([[5.0], [11.0], [17.0]] * INPUT_SIZE, dtype=np.float32)
         np.testing.assert_allclose(expected_y, y_ortvalue.numpy(), rtol=1e-05, atol=1e-05)
 
         # After capturing, CUDA graph replay happens from this Run onwards
-        session.run_with_iobinding(io_binding)
+        session.run_with_iobinding(io_binding, ro)
         np.testing.assert_allclose(expected_y, y_ortvalue.numpy(), rtol=1e-05, atol=1e-05)
 
         # Update input and then replay CUDA graph
@@ -116,7 +119,7 @@ def run_model_with_cuda_graph(self, providers):
                 dtype=np.float32,
             )
         )
-        session.run_with_iobinding(io_binding)
+        session.run_with_iobinding(io_binding, ro)
         np.testing.assert_allclose(
             np.array([[50.0], [110.0], [170.0]] * INPUT_SIZE, dtype=np.float32),
             y_ortvalue.numpy(),
@@ -124,6 +127,58 @@ def run_model_with_cuda_graph(self, providers):
             atol=1e-05,
         )
 
+    def run_model_with_cuda_graph_annotation(self, providers):
+        INPUT_SIZE = 1280  # noqa: N806
+
+        x_base = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
+        y_base = [[0.0], [0.0], [0.0], [0.0]]
+        expected_y_base = [[5.0], [11.0], [17.0], [23.0]]
+
+        x_base_mul_10 = [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0], [70.0, 80.0]]
+        expected_y_base_mul_10 = [[50.0], [110.0], [170.0], [230.0]]
+
+        test_num = 4
+
+        x_ortvalues = []
+        y_ortvalues = []
+        for i in range(test_num):
+            x = np.array(x_base[: i + 1][:] * INPUT_SIZE, dtype=np.float32)
+            y = np.array(y_base[: i + 1][:] * INPUT_SIZE, dtype=np.float32)
+            x_ortvalues.append(onnxrt.OrtValue.ortvalue_from_numpy(x, "cuda", 0))
+            y_ortvalues.append(onnxrt.OrtValue.ortvalue_from_numpy(y, "cuda", 0))
+
+        onnxrt.set_default_logger_severity(0)
+        session = onnxrt.InferenceSession(get_name("matmul_2.onnx"), providers=providers)
+        io_bindings = [session.io_binding() for _ in range(test_num)]  # one io_binding per annotated graph
+        ro = onnxrt.RunOptions()
+
+        # Regular run to capture CUDA graph
+        for i in range(test_num):
+            io_bindings[i].bind_ortvalue_input("X", x_ortvalues[i])
+            io_bindings[i].bind_ortvalue_output("Y", y_ortvalues[i])
+            # TODO: Temporarily remove the default cuda graph capture test for the first regular run
+            # because it fails on a training CI. Need to investigate the root cause.
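+            # Each iteration is annotated with its own gpu_graph_id so that a separate CUDA graph is
+            # captured here and replayed below for each io_binding.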
+            ro.add_run_config_entry("gpu_graph_id", str(i + 1))
+            io_bindings[i].synchronize_inputs()
+            session.run_with_iobinding(io_bindings[i], ro)
+            io_bindings[i].synchronize_outputs()
+            expected_y = np.array(expected_y_base[: i + 1][:] * INPUT_SIZE, dtype=np.float32)
+            np.testing.assert_allclose(expected_y, y_ortvalues[i].numpy(), rtol=1e-05, atol=1e-05)
+
+        del ro
+        ro = onnxrt.RunOptions()
+
+        # After capturing, CUDA graph replay happens from this Run onwards
+        for i in range(test_num):
+            # Update input and then replay CUDA graph
+            x_ortvalues[i].update_inplace(np.array(x_base_mul_10[: i + 1][:] * INPUT_SIZE, dtype=np.float32))
+            ro.add_run_config_entry("gpu_graph_id", str(i + 1))
+            io_bindings[i].synchronize_inputs()
+            session.run_with_iobinding(io_bindings[i], ro)
+            io_bindings[i].synchronize_outputs()
+            expected_y = np.array(expected_y_base_mul_10[: i + 1][:] * INPUT_SIZE, dtype=np.float32)
+            np.testing.assert_allclose(expected_y, y_ortvalues[i].numpy(), rtol=1e-05, atol=1e-05)
+
     def test_arena_with_cuda_graph(self):
         if "CUDAExecutionProvider" in onnxrt.get_available_providers():
             # To test cuda graph catpure, we set Arena extend strategy to be SameAsRequested so as to detect any
diff --git a/onnxruntime/test/python/quantization/op_test_utils.py b/onnxruntime/test/python/quantization/op_test_utils.py
index c1bbb49f10c7e..b30282f2ab41f 100644
--- a/onnxruntime/test/python/quantization/op_test_utils.py
+++ b/onnxruntime/test/python/quantization/op_test_utils.py
@@ -358,6 +358,7 @@ def check_model_correctness(
         model_onnx = onnx.load(f)
     ops_set = set(node.op_type for node in model_onnx.graph.node)
     check_reference_evaluator = not (ops_set & {"EmbedLayerNormalization", "Conv", "Attention", "Transpose"})
+    check_target_evaluator = False
 
     with open(model_path_to_check, "rb") as f:
         model_check = onnx.load(f)
@@ -413,7 +414,7 @@ def check_model_correctness(
             check_sign_f8_quantization(model_path_origin, model_path_to_check)
 
     # Verifies the expected outputs.
-    if check_reference_evaluator and onnx_recent_enough:
+    if check_target_evaluator and onnx_recent_enough:
         if op_matmul:
             reference_new_ops = [QLinearMatMul]
         else:
diff --git a/onnxruntime/test/python/quantization/test_fusions.py b/onnxruntime/test/python/quantization/test_fusions.py
new file mode 100644
index 0000000000000..bea110e566fb9
--- /dev/null
+++ b/onnxruntime/test/python/quantization/test_fusions.py
@@ -0,0 +1,401 @@
+#!/usr/bin/env python
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import math
+import unittest
+
+import numpy as np
+import onnx
+
+import onnxruntime
+from onnxruntime.quantization.execution_providers.qnn.fusion_lpnorm import FusionLpNormalization
+from onnxruntime.quantization.fusions import FusionGelu, FusionLayerNormalization
+from onnxruntime.quantization.onnx_model import ONNXModel
+
+
+class TestFusions(unittest.TestCase):
+    def check_fused_model_correctness(self, orig_model, fused_model, inputs, rtol=1e-7, atol=0):
+        """
+        Checks that the output of the fused model matches the output of the original model.
+        """
+        orig_session = onnxruntime.InferenceSession(orig_model.SerializeToString(), providers=["CPUExecutionProvider"])
+        orig_results = orig_session.run(None, inputs)
+
+        fused_session = onnxruntime.InferenceSession(
+            fused_model.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+        fused_results = fused_session.run(None, inputs)
+
+        self.assertEqual(len(orig_results), len(fused_results), "Number of outputs for fused model differs")
+        for idx, expected_output in enumerate(orig_results):
+            actual_output = fused_results[idx]
+            np.testing.assert_allclose(
+                expected_output,
+                actual_output,
+                rtol=rtol,
+                atol=atol,
+                err_msg=f"Fused model output {idx} differs",
+            )
+
+    def build_erf_sequence_1_model(self, shape):
+        """
+        Erf sequence that fuses into Gelu:
+           +-------Mul(0.5)---------------------+
+           |                                    |
+           |                                    v
+        [root] --> Div -----> Erf  --> Add --> Mul -->
+                  (B=1.4142...)       (1)
+
+        This method builds 2 of these Erf sequences:
+
+        [root] -> ERF_SEQUENCE1 -> ERF_SEQUENCE2 -> output
+        """
+        root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape)
+        output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape)
+        one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const")
+        half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const")
+        root2_const = onnx.numpy_helper.from_array(np.array(math.sqrt(2.0), dtype=np.float32), "root2_const")
+
+        # First Erf sequence
+        mul0_node = onnx.helper.make_node("Mul", ["root", "half_const"], ["mul0_out"])
+        div_node = onnx.helper.make_node("Div", ["root", "root2_const"], ["div_out"])
+        erf_node = onnx.helper.make_node("Erf", ["div_out"], ["erf_out"])
+        add_node = onnx.helper.make_node("Add", ["erf_out", "one_const"], ["add_out"])
+        mul1_node = onnx.helper.make_node("Mul", ["add_out", "mul0_out"], ["seq1_output"])
+
+        # Second Erf sequence
+        mul0_node_dup = onnx.helper.make_node("Mul", ["seq1_output", "half_const"], ["mul0_out_dup"])
+        div_node_dup = onnx.helper.make_node("Div", ["seq1_output", "root2_const"], ["div_out_dup"])
+        erf_node_dup = onnx.helper.make_node("Erf", ["div_out_dup"], ["erf_out_dup"])
+        add_node_dup = onnx.helper.make_node("Add", ["erf_out_dup", "one_const"], ["add_out_dup"])
+        mul1_node_dup = onnx.helper.make_node("Mul", ["add_out_dup", "mul0_out_dup"], ["output"])
+
+        graph = onnx.helper.make_graph(
+            [
+                mul0_node,
+                div_node,
+                erf_node,
+                add_node,
+                mul1_node,
+                mul0_node_dup,
+                div_node_dup,
+                erf_node_dup,
+                add_node_dup,
+                mul1_node_dup,
+            ],
+            "two_erf_sequences",
+            [root_inp],
+            [output],
+            initializer=[one_const, half_const, root2_const],
+        )
+        opset_imports = [
+            onnx.helper.make_opsetid("", 18),
+            onnx.helper.make_opsetid("com.microsoft", 1),
+        ]
+        model = onnx.helper.make_model(graph, opset_imports=opset_imports)
+        return ONNXModel(model)
+
+    def build_erf_sequence_2_model(self, shape):
+        """
+           +------------------------------------+
+           |                                    |
+           |                                    v
+        [root] --> Div -----> Erf  --> Add --> Mul -->Mul -->
+                  (B=1.4142...)       (1)            (0.5)
+
+        """
+        root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape)
+        output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape)
+        one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const")
+        half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const")
+        root2_const = onnx.numpy_helper.from_array(np.array(math.sqrt(2.0), dtype=np.float32), "root2_const")
+
+        div_node = onnx.helper.make_node("Div", ["root", "root2_const"], ["div_out"])
+        erf_node = onnx.helper.make_node("Erf", ["div_out"], ["erf_out"])
+        add_node = onnx.helper.make_node("Add", ["erf_out", "one_const"], ["add_out"])
+        mul0_node = onnx.helper.make_node("Mul", ["add_out", "root"], ["mul0_out"])
+        mul1_node = onnx.helper.make_node("Mul", ["mul0_out", "half_const"], ["output"])
+
+        graph = onnx.helper.make_graph(
+            [div_node, erf_node, add_node, mul0_node, mul1_node],
+            "erf_sequence_2",
+            [root_inp],
+            [output],
+            initializer=[one_const, half_const, root2_const],
+        )
+        opset_imports = [
+            onnx.helper.make_opsetid("", 18),
+            onnx.helper.make_opsetid("com.microsoft", 1),
+        ]
+        model = onnx.helper.make_model(graph, opset_imports=opset_imports)
+        return ONNXModel(model)
+
+    def build_erf_sequence_3_model(self, shape):
+        """
+           +------------------------------------------+
+           |                                          |
+           |                                          v
+        [root] --> Div -----> Erf  --> Add --> Mul -->Mul
+                  (B=1.4142...)       (A=1)   (A=0.5)
+
+        """
+        root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape)
+        output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape)
+        one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const")
+        half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const")
+        root2_const = onnx.numpy_helper.from_array(np.array(math.sqrt(2.0), dtype=np.float32), "root2_const")
+
+        div_node = onnx.helper.make_node("Div", ["root", "root2_const"], ["div_out"])
+        erf_node = onnx.helper.make_node("Erf", ["div_out"], ["erf_out"])
+        add_node = onnx.helper.make_node("Add", ["erf_out", "one_const"], ["add_out"])
+        mul0_node = onnx.helper.make_node("Mul", ["add_out", "half_const"], ["mul0_out"])
+        mul1_node = onnx.helper.make_node("Mul", ["mul0_out", "root"], ["output"])
+
+        graph = onnx.helper.make_graph(
+            [div_node, erf_node, add_node, mul0_node, mul1_node],
+            "erf_sequence_3",
+            [root_inp],
+            [output],
+            initializer=[one_const, half_const, root2_const],
+        )
+        opset_imports = [
+            onnx.helper.make_opsetid("", 18),
+            onnx.helper.make_opsetid("com.microsoft", 1),
+        ]
+        model = onnx.helper.make_model(graph, opset_imports=opset_imports)
+        return ONNXModel(model)
+
+    def build_erf_sequence_4_model(self, shape):
+        """
+           +----------------------------------------------+
+           |                                              |
+           |                                              v
+        [root] --> Mul -----> Erf    -->   Add --> Mul -->Mul
+                   (A=0.7071067690849304)  (B=1)  (B=0.5)
+
+        """
+        root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape)
+        output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape)
+        one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const")
+        half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const")
+        frac_const = onnx.numpy_helper.from_array(np.array(0.7071067690849304, dtype=np.float32), "frac_const")
+
+        mul0_node = onnx.helper.make_node("Mul", ["root", "frac_const"], ["mul0_out"])
+        erf_node = onnx.helper.make_node("Erf", ["mul0_out"], ["erf_out"])
+        add_node = onnx.helper.make_node("Add", ["erf_out", "one_const"], ["add_out"])
+        mul1_node = onnx.helper.make_node("Mul", ["add_out", "half_const"], ["mul1_out"])
+        mul2_node = onnx.helper.make_node("Mul", ["mul1_out", "root"], ["output"])
+
+        graph = onnx.helper.make_graph(
+            [mul0_node, erf_node, add_node, mul1_node, mul2_node],
+            "erf_sequence_4",
+            [root_inp],
+            [output],
+            initializer=[one_const, half_const, frac_const],
+        )
+        opset_imports = [
+            onnx.helper.make_opsetid("", 18),
+            onnx.helper.make_opsetid("com.microsoft", 1),
+        ]
+        model = onnx.helper.make_model(graph, opset_imports=opset_imports)
+        return ONNXModel(model)
+
+    def build_reduce_mean_sequence_model(self, shape, scale_val, bias_val, axis=-1):
+        """
+            +----------------------+
+            |                      |
+            |                      v
+        [Root] --> ReduceMean -->  Sub  --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
+                   (axis=2 or -1)  |      (Y=2)   (axis=2 or -1)  (E-6 or E-12 or 0) ^       ^       ^
+                                   |                                                 |       |       |
+                                   +-------------------------------------------------+    [Scale]  [Bias]
+        """
+        root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape)
+        output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape)
+        scale_const = onnx.numpy_helper.from_array(np.array(scale_val, dtype=np.float32), "scale_const")
+        bias_const = onnx.numpy_helper.from_array(np.array(bias_val, dtype=np.float32), "bias_const")
+        axes_const = onnx.numpy_helper.from_array(np.array([axis], dtype=np.int64), "axes_const")
+        two_const = onnx.numpy_helper.from_array(np.array(2.0, dtype=np.float32), "two_const")
+        eps_const = onnx.numpy_helper.from_array(np.array(1.0e-8, dtype=np.float32), "eps_const")
+
+        rm0_node = onnx.helper.make_node("ReduceMean", ["root", "axes_const"], ["rm0_out"])
+        sub_node = onnx.helper.make_node("Sub", ["root", "rm0_out"], ["sub_out"])
+        pow_node = onnx.helper.make_node("Pow", ["sub_out", "two_const"], ["pow_out"])
+        rm1_node = onnx.helper.make_node("ReduceMean", ["pow_out", "axes_const"], ["rm1_out"])
+        add0_node = onnx.helper.make_node("Add", ["rm1_out", "eps_const"], ["add0_out"])
+        sqrt_node = onnx.helper.make_node("Sqrt", ["add0_out"], ["sqrt_out"])
+        div_node = onnx.helper.make_node("Div", ["sub_out", "sqrt_out"], ["div_out"])
+        mul_node = onnx.helper.make_node("Mul", ["div_out", "scale_const"], ["mul_out"])
+        add1_node = onnx.helper.make_node("Add", ["mul_out", "bias_const"], ["output"])
+
+        graph = onnx.helper.make_graph(
+            [rm0_node, sub_node, pow_node, rm1_node, add0_node, sqrt_node, div_node, mul_node, add1_node],
+            "reduce_mean_sequence",
+            [root_inp],
+            [output],
+            initializer=[scale_const, bias_const, axes_const, two_const, eps_const],
+        )
+        opset_imports = [
+            onnx.helper.make_opsetid("", 18),
+        ]
+        model = onnx.helper.make_model(graph, opset_imports=opset_imports)
+        return ONNXModel(model)
+
+    def build_reduce_l2_sequence_model(self, shape, epsilon_val, axis=-1):
+        """
+        [root] --> ReduceL2 -----> Clip  --> Expand ----> Div -->
+           |      (axis=-1)    (min=epsilon) (shape=root)  ^
+           |   (keepdims=True)                             |
+           |                                               |
+           +-----------------------------------------------+
+        """
+        root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape)
+        output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape)
+        axes_const = onnx.numpy_helper.from_array(np.array([axis], dtype=np.int64), "axes_const")
+        eps_const = onnx.numpy_helper.from_array(np.array(epsilon_val, dtype=np.float32), "eps_const")
+        shape_const = onnx.numpy_helper.from_array(np.array(list(shape), dtype=np.int64), "shape_const")
+
+        rl2_node = onnx.helper.make_node("ReduceL2", ["root", "axes_const"], ["rl2_out"], keepdims=1)
+        clip_node = onnx.helper.make_node("Clip", ["rl2_out", "eps_const"], ["clip_out"])
+        expand_node = onnx.helper.make_node("Expand", ["clip_out", "shape_const"], ["expand_out"])
+        div_node = onnx.helper.make_node("Div", ["root", "expand_out"], ["output"])
+
+        graph = onnx.helper.make_graph(
+            [rl2_node, clip_node, expand_node, div_node],
+            "reducel2_sequence",
+            [root_inp],
+            [output],
+            initializer=[axes_const, eps_const, shape_const],
+        )
+        opset_imports = [
+            onnx.helper.make_opsetid("", 18),
+        ]
+        model = onnx.helper.make_model(graph, opset_imports=opset_imports)
+        return ONNXModel(model)
+
+    def test_fuse_erf_to_gelu_1(self):
+        shape = (1, 2, 3)
+        model = self.build_erf_sequence_1_model(shape)
+        orig_model = onnx.ModelProto()
+        orig_model.CopyFrom(model.model)
+
+        # Check that fusion simplified model to 2 Gelu nodes.
+        modified = FusionGelu(model).apply()
+        self.assertTrue(modified)
+        self.assertEqual(len(model.model.graph.node), 2)
+
+        gelu_node_0 = model.model.graph.node[0]
+        gelu_node_1 = model.model.graph.node[1]
+        self.assertEqual(gelu_node_0.op_type, "Gelu")
+        self.assertEqual(gelu_node_1.op_type, "Gelu")
+
+        self.assertTrue(gelu_node_0.name)
+        self.assertTrue(gelu_node_1.name)
+        self.assertNotEqual(gelu_node_0.name, gelu_node_1.name)  # Generated names should not be equal
+
+        # Check that fusion is equivalent to original Erf model.
+        inputs = {"root": np.ones(shape, dtype=np.float32)}
+        self.check_fused_model_correctness(orig_model, model.model, inputs)
+
+    def test_fuse_erf_to_gelu_2(self):
+        shape = (1, 2, 3)
+        model = self.build_erf_sequence_2_model(shape)
+        orig_model = onnx.ModelProto()
+        orig_model.CopyFrom(model.model)
+
+        # Check that fusion simplified model to 1 Gelu node.
+        modified = FusionGelu(model).apply()
+        self.assertTrue(modified)
+        self.assertEqual(len(model.model.graph.node), 1)
+
+        gelu_node = model.model.graph.node[0]
+        self.assertEqual(gelu_node.op_type, "Gelu")
+        self.assertTrue(gelu_node.name)
+
+        # Check that fusion is equivalent to original Erf model.
+        inputs = {"root": np.ones(shape, dtype=np.float32)}
+        self.check_fused_model_correctness(orig_model, model.model, inputs)
+
+    def test_fuse_erf_to_gelu_3(self):
+        shape = (1, 2, 3)
+        model = self.build_erf_sequence_3_model(shape)
+        orig_model = onnx.ModelProto()
+        orig_model.CopyFrom(model.model)
+
+        # Check that fusion simplified model to 1 Gelu node.
+        modified = FusionGelu(model).apply()
+        self.assertTrue(modified)
+        self.assertEqual(len(model.model.graph.node), 1)
+
+        gelu_node = model.model.graph.node[0]
+        self.assertEqual(gelu_node.op_type, "Gelu")
+        self.assertTrue(gelu_node.name)
+
+        # Check that fusion is equivalent to original Erf model.
+        inputs = {"root": np.ones(shape, dtype=np.float32)}
+        self.check_fused_model_correctness(orig_model, model.model, inputs)
+
+    def test_fuse_erf_to_gelu_4(self):
+        shape = (1, 2, 3)
+        model = self.build_erf_sequence_4_model(shape)
+        orig_model = onnx.ModelProto()
+        orig_model.CopyFrom(model.model)
+
+        # Check that fusion simplified model to 1 Gelu node.
+        modified = FusionGelu(model).apply()
+        self.assertTrue(modified)
+        self.assertEqual(len(model.model.graph.node), 1)
+
+        gelu_node = model.model.graph.node[0]
+        self.assertEqual(gelu_node.op_type, "Gelu")
+        self.assertTrue(gelu_node.name)
+
+        # Check that fusion is equivalent to original Erf model.
+        inputs = {"root": np.ones(shape, dtype=np.float32)}
+        self.check_fused_model_correctness(orig_model, model.model, inputs)
+
+    def test_fuse_reduce_l2_to_lpnorm(self):
+        shape = (1, 2, 3)
+        model = self.build_reduce_l2_sequence_model(shape, 1e-12, axis=-1)
+        orig_model = onnx.ModelProto()
+        orig_model.CopyFrom(model.model)
+
+        # Check that fusion simplified model to 1 LpNormalization node.
+        modified = FusionLpNormalization(model).apply()
+        self.assertTrue(modified)
+        self.assertEqual(len(model.model.graph.node), 1)
+
+        lpnorm_node = model.model.graph.node[0]
+        self.assertEqual(lpnorm_node.op_type, "LpNormalization")
+        self.assertTrue(lpnorm_node.name)
+
+        # LpNorm's p attribute should be set to 2
+        p_attr = next(attr for attr in lpnorm_node.attribute if attr.name == "p")
+        self.assertEqual(p_attr.i, 2)
+
+    def test_fuse_reduce_mean_to_layer_norm(self):
+        shape = (1, 2, 3)
+        model = self.build_reduce_mean_sequence_model(shape, [2.0, 2.0, 2.0], [1.0, 1.0, 1.0], axis=-1)
+        orig_model = onnx.ModelProto()
+        orig_model.CopyFrom(model.model)
+
+        # Check that fusion simplified model to 1 LayerNormalization node.
+        modified = FusionLayerNormalization(model).apply()
+        self.assertTrue(modified)
+        self.assertEqual(len(model.model.graph.node), 1)
+
+        layer_norm_node = model.model.graph.node[0]
+        self.assertEqual(layer_norm_node.op_type, "LayerNormalization")
+        self.assertTrue(layer_norm_node.name)
+
+        # Check that fused model is equivalent to original model.
+        inputs = {"root": np.ones(shape, dtype=np.float32)}
+        self.check_fused_model_correctness(orig_model, model.model, inputs)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py
index 73dae08af8ece..88e5052db4e2e 100644
--- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py
+++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py
@@ -125,7 +125,10 @@ def quant_test(
         from onnxruntime.quantization import matmul_4bits_quantizer
 
         model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path))
-        quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, block_size, is_symmetric)
+        quant_config = matmul_4bits_quantizer.DefaultWeightOnlyQuantConfig(
+            block_size=block_size, is_symmetric=is_symmetric
+        )
+        quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, algo_config=quant_config)
         quant.process()
         quant.model.save_model_to_file(model_int4_path, False)
 
@@ -165,6 +168,9 @@ def quant_test_with_algo(
         elif algorithm == "GPTQ":
             # test GPTQ algorithm
             algo_config = matmul_4bits_quantizer.GPTQWeightOnlyQuantConfig(calibration_data_reader=data_reader)
+        elif algorithm == "HQQ":
+            # test HQQ algorithm
+            algo_config = matmul_4bits_quantizer.HQQWeightOnlyQuantConfig(block_size=block_size)
 
         model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path))
         quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, block_size, is_symmetric, algo_config=algo_config)
@@ -227,6 +233,17 @@ def test_quantize_matmul_int4_using_gptq_algo(self):
         data_reader = self.input_feeds(1, {"input": [100, 52]})
         self.quant_test_with_algo("GPTQ", model_fp32_path, data_reader, 32, False)
 
+    @unittest.skipIf(
+        find_spec("onnxruntime.training"), "Skip because training package doesn't have quantize_matmul_4bits"
+    )
+    def test_quantize_matmul_int4_using_hqq_algo(self):
+        if not find_spec("torch"):
+            self.skipTest("skip test_hqq_quant since torch is not installed")
+        model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute())
+        self.construct_model_matmul(model_fp32_path, symmetric=False)
+        data_reader = self.input_feeds(1, {"input": [100, 52]})
+        self.quant_test_with_algo("HQQ", model_fp32_path, data_reader, 32, False)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/onnxruntime/test/python/quantization/test_op_pad.py b/onnxruntime/test/python/quantization/test_op_pad.py
index 03e29dd64f8a7..291bf42405d58 100644
--- a/onnxruntime/test/python/quantization/test_op_pad.py
+++ b/onnxruntime/test/python/quantization/test_op_pad.py
@@ -222,12 +222,8 @@ def verify_quantize_with_pad_mode(
         activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8
         activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8"
         weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8"
-        model_i8_path = "qop_pad_{}_i8_{}{}_{}{}.onnx".format(
-            quantize_mode,
-            tag_pad_mode,
-            tag_constant_value,
-            activation_type_str,
-            weight_type_str,
+        model_i8_path = (
+            f"qop_pad_{quantize_mode}_i8_{tag_pad_mode}{tag_constant_value}_{activation_type_str}{weight_type_str}.onnx"
         )
         data_reader.rewind()
         self.quantize_model(
diff --git a/onnxruntime/test/python/quantization/test_qdq.py b/onnxruntime/test/python/quantization/test_qdq.py
index 223f405e8947a..9e7a4a125121d 100644
--- a/onnxruntime/test/python/quantization/test_qdq.py
+++ b/onnxruntime/test/python/quantization/test_qdq.py
@@ -20,7 +20,7 @@
     create_clip_node,
 )
 
-from onnxruntime.quantization import QDQQuantizer, QuantFormat, QuantizationMode, QuantType, quantize_static
+from onnxruntime.quantization import QDQQuantizer, QuantFormat, QuantType, quantize_static
 from onnxruntime.quantization.calibrate import TensorData
 
 
@@ -87,14 +87,11 @@ def td(vals):
 
         op_types_to_quantize = ["Add"]
 
-        mode = QuantizationMode.QLinearOps
         model = onnx.load_model(test_model_path)
         quantizer = QDQQuantizer(
             model,
             True,  # per_channel
             False,  # reduce_range
-            mode,
-            True,  # static
             QuantType.QInt8,  # weight_type
             QuantType.QInt8,  # activation_type
             compute_data,
@@ -191,14 +188,11 @@ def td(vals):
 
         op_types_to_quantize = ["Add", "MatMul"]
 
-        mode = QuantizationMode.QLinearOps
         model = onnx.load_model(test_model_path)
         quantizer = QDQQuantizer(
             model,
             True,  # per_channel
             False,  # reduce_range
-            mode,
-            True,  # static
             QuantType.QInt8,  # weight_type
             QuantType.QInt8,  # activation_type
             compute_data,
diff --git a/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py b/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py
new file mode 100644
index 0000000000000..6503b3223b828
--- /dev/null
+++ b/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py
@@ -0,0 +1,263 @@
+#!/usr/bin/env python
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import math
+import unittest
+from pathlib import Path
+
+import numpy as np
+import onnx
+
+import onnxruntime
+from onnxruntime.quantization.execution_providers.qnn import qnn_preprocess_model
+from onnxruntime.quantization.quant_utils import model_has_external_data, ms_domain
+
+
+class TestQnnPreprocessModel(unittest.TestCase):
+    def build_model(self, shape, scale_val, bias_val):
+        """
+        Build a model that supports 3 kinds of fusions:
+        - Erf sequence to Gelu
+        - ReduceL2 sequence to LpNormalization
+        - ReduceMean sequence to LayerNormalization
+        """
+        root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape)
+        output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape)
+
+        # Erf sequence
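+        # Computes 0.5 * x * (1 + Erf(x / sqrt(2))), which should fuse into a single Gelu node.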
+        one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const")
+        half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const")
+        root2_const = onnx.numpy_helper.from_array(np.array(math.sqrt(2.0), dtype=np.float32), "root2_const")
+
+        e_mul0_node = onnx.helper.make_node("Mul", ["root", "half_const"], ["e_mul0_out"])
+        e_div_node = onnx.helper.make_node("Div", ["root", "root2_const"], ["e_div_out"])
+        e_erf_node = onnx.helper.make_node("Erf", ["e_div_out"], ["e_erf_out"])
+        e_add_node = onnx.helper.make_node("Add", ["e_erf_out", "one_const"], ["e_add_out"])
+        e_mul1_node = onnx.helper.make_node("Mul", ["e_add_out", "e_mul0_out"], ["erf_seq_output"])
+
+        # ReduceL2 sequence
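+        # Computes x / max(||x||_2, eps) along the last axis, which should fuse into LpNormalization.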
+        axes_const = onnx.numpy_helper.from_array(np.array([-1], dtype=np.int64), "axes_const")
+        eps_const = onnx.numpy_helper.from_array(np.array(1e-12, dtype=np.float32), "eps_const")
+        shape_const = onnx.numpy_helper.from_array(np.array(list(shape), dtype=np.int64), "shape_const")
+
+        l2_rl2_node = onnx.helper.make_node("ReduceL2", ["erf_seq_output", "axes_const"], ["l2_rl2_out"], keepdims=1)
+        l2_clip_node = onnx.helper.make_node("Clip", ["l2_rl2_out", "eps_const"], ["l2_clip_out"])
+        l2_expand_node = onnx.helper.make_node("Expand", ["l2_clip_out", "shape_const"], ["l2_expand_out"])
+        l2_div_node = onnx.helper.make_node("Div", ["erf_seq_output", "l2_expand_out"], ["l2_seq_output"])
+
+        # ReduceMean sequence
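+        # Computes (x - mean) / sqrt(var + eps) * scale + bias, which should fuse into LayerNormalization.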
+        scale_const = onnx.numpy_helper.from_array(np.array(scale_val, dtype=np.float32), "scale_const")
+        bias_const = onnx.numpy_helper.from_array(np.array(bias_val, dtype=np.float32), "bias_const")
+        two_const = onnx.numpy_helper.from_array(np.array(2.0, dtype=np.float32), "two_const")
+
+        m_rm0_node = onnx.helper.make_node("ReduceMean", ["l2_seq_output", "axes_const"], ["m_rm0_out"])
+        m_sub_node = onnx.helper.make_node("Sub", ["l2_seq_output", "m_rm0_out"], ["m_sub_out"])
+        m_pow_node = onnx.helper.make_node("Pow", ["m_sub_out", "two_const"], ["m_pow_out"])
+        m_rm1_node = onnx.helper.make_node("ReduceMean", ["m_pow_out", "axes_const"], ["m_rm1_out"])
+        m_add0_node = onnx.helper.make_node("Add", ["m_rm1_out", "eps_const"], ["m_add0_out"])
+        m_sqrt_node = onnx.helper.make_node("Sqrt", ["m_add0_out"], ["m_sqrt_out"])
+        m_div_node = onnx.helper.make_node("Div", ["m_sub_out", "m_sqrt_out"], ["m_div_out"])
+        m_mul_node = onnx.helper.make_node("Mul", ["m_div_out", "scale_const"], ["m_mul_out"])
+        m_add1_node = onnx.helper.make_node("Add", ["m_mul_out", "bias_const"], ["output"])
+
+        graph = onnx.helper.make_graph(
+            [
+                e_mul0_node,
+                e_div_node,
+                e_erf_node,
+                e_add_node,
+                e_mul1_node,
+                l2_rl2_node,
+                l2_clip_node,
+                l2_expand_node,
+                l2_div_node,
+                m_rm0_node,
+                m_sub_node,
+                m_pow_node,
+                m_rm1_node,
+                m_add0_node,
+                m_sqrt_node,
+                m_div_node,
+                m_mul_node,
+                m_add1_node,
+            ],
+            "qnn_f32_model",
+            [root_inp],
+            [output],
+            initializer=[
+                one_const,
+                half_const,
+                root2_const,
+                axes_const,
+                eps_const,
+                shape_const,
+                scale_const,
+                bias_const,
+                two_const,
+            ],
+        )
+        opset_imports = [
+            onnx.helper.make_opsetid("", 18),
+        ]
+        model = onnx.helper.make_model(graph, opset_imports=opset_imports)
+        return onnx.shape_inference.infer_shapes(model)
+
+    def test_all_fusions(self):
+        """
+        Test calling qnn_preprocess_model() with a model that supports all 3 fusions.
+        """
+        model = self.build_model((1, 2, 3), [2.0, 2.0, 2.0], [1.0, 1.0, 1.0])
+        onnx.save_model(model, "model.onnx")
+        modified = qnn_preprocess_model("model.onnx", "model.qnn_pp.onnx", fuse_layernorm=True)
+
+        self.assertTrue(modified)
+
+        fused_model = onnx.load_model("model.qnn_pp.onnx")
+
+        # 3 fused Ops: Gelu, LpNorm, LayerNorm
+        self.assertEqual(len(fused_model.graph.node), 3)
+        expected_op_types = {"Gelu", "LpNormalization", "LayerNormalization"}
+        for node in fused_model.graph.node:
+            self.assertIn(node.op_type, expected_op_types)
+
+        # Should have added "com.microsoft" opset import because we added a Gelu.
+        ms_domain_opset = next((opset for opset in fused_model.opset_import if opset.domain == ms_domain), None)
+        self.assertIsNotNone(ms_domain_opset)
+        self.assertEqual(ms_domain_opset.version, 1)
+
+    def test_external_data(self):
+        """
+        Test calling qnn_preprocess_model() with a model that uses external data.
+        The new preprocessed model should also have external data.
+        """
+        model = self.build_model((1, 2, 3), [2.0, 2.0, 2.0], [1.0, 1.0, 1.0])
+        onnx.save_model(
+            model,
+            "model.onnx",
+            save_as_external_data=True,
+            all_tensors_to_one_file=True,
+            location="weights.bin",
+            size_threshold=0,
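+            # size_threshold=0 forces every initializer into the external data file, even small ones.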
+        )
+        modified = qnn_preprocess_model(
+            "model.onnx",
+            "model.qnn_pp.onnx",
+            fuse_layernorm=True,
+            save_as_external_data=True,
+            all_tensors_to_one_file=True,
+            external_data_location="weights2.bin",
+            external_data_size_threshold=0,
+        )
+
+        self.assertTrue(modified)
+
+        # Model should still have external data.
+        self.assertTrue(model_has_external_data(Path("model.qnn_pp.onnx")))
+
+        fused_model = onnx.load_model("model.qnn_pp.onnx", load_external_data=False)
+
+        # 3 fused Ops: Gelu, LpNorm, LayerNorm
+        self.assertEqual(len(fused_model.graph.node), 3)
+        expected_op_types = {"Gelu", "LpNormalization", "LayerNormalization"}
+        for node in fused_model.graph.node:
+            self.assertIn(node.op_type, expected_op_types)
+
+    def build_multi_input_output_model(self, shape):
+        """
+        Returns the following model.
+                               +----------> [X]
+                               |
+        [A] ---> Add ---> Abs -+-> Mul ---> [Y]
+                  ^                 ^
+                  |                 |
+        [B] ------+-----------------+
+        """
+        input_a = onnx.helper.make_tensor_value_info("A", onnx.TensorProto.FLOAT, shape)
+        input_b = onnx.helper.make_tensor_value_info("B", onnx.TensorProto.FLOAT, shape)
+        output_x = onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape)
+        output_y = onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape)
+
+        add_node = onnx.helper.make_node("Add", ["A", "B"], ["add_out"], name="add_node")
+        abs_node = onnx.helper.make_node("Abs", ["add_out"], ["X"], name="abs_node")
+        mul_node = onnx.helper.make_node("Mul", ["X", "B"], ["Y"], name="mul_node")
+
+        graph = onnx.helper.make_graph(
+            [add_node, abs_node, mul_node],
+            "multi_io_graph",
+            [input_a, input_b],
+            [output_x, output_y],
+        )
+        opset_imports = [
+            onnx.helper.make_opsetid("", 18),
+        ]
+        model = onnx.helper.make_model(graph, opset_imports=opset_imports)
+        return onnx.shape_inference.infer_shapes(model)
+
+    def test_make_io_channel_last(self):
+        """
+        Test making a model's inputs and outputs channel-last.
+        """
+        model = self.build_multi_input_output_model((1, 2, 3, 4))
+        onnx.save_model(model, "model.onnx")
+        modified = qnn_preprocess_model(
+            "model.onnx",
+            "model.qnn_pp.onnx",
+            inputs_to_make_channel_last=["A", "B"],
+            outputs_to_make_channel_last=["X", "Y"],
+        )
+
+        self.assertTrue(modified)
+
+        preproc_model = onnx.load_model("model.qnn_pp.onnx")
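+        # 3 original nodes (Add, Abs, Mul) plus 4 inserted Transpose nodes, one per channel-last input/output.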
+        self.assertEqual(len(preproc_model.graph.node), 7)
+
+        num_transposes = sum(1 for node in preproc_model.graph.node if node.op_type == "Transpose")
+        self.assertEqual(num_transposes, 4)
+
+        # Check that the outputs of the new model are the same, but transposed.
+        input_a = np.arange(0.0, 24.0, 1.0, dtype=np.float32).reshape((1, 2, 3, 4))
+        input_a_t = input_a.transpose(0, 2, 3, 1)
+        input_b = np.arange(1.0, 25.0, 1.0, dtype=np.float32).reshape((1, 2, 3, 4))
+        input_b_t = input_b.transpose(0, 2, 3, 1)
+
+        orig_session = onnxruntime.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
+        orig_results = orig_session.run(None, {"A": input_a, "B": input_b})
+
+        new_session = onnxruntime.InferenceSession(
+            preproc_model.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+        new_results = new_session.run(None, {"A": input_a_t, "B": input_b_t})
+
+        self.assertEqual(len(orig_results), len(new_results))
+        for idx, orig_output in enumerate(orig_results):
+            transposed_output = new_results[idx]
+            np.testing.assert_allclose(
+                orig_output,
+                transposed_output.transpose(0, 3, 1, 2),
+                err_msg=f"Channel-last model output {idx} differs",
+            )
+
+    def test_make_io_channel_last_rank_error(self):
+        """
+        Test making a model's inputs and outputs channel-last with a rank < 3 (error).
+        """
+        model = self.build_multi_input_output_model((1, 2))
+        onnx.save_model(model, "model.onnx")
+
+        with self.assertRaises(ValueError) as context:
+            qnn_preprocess_model(
+                "model.onnx",
+                "model.qnn_pp.onnx",
+                inputs_to_make_channel_last=["A", "B"],
+                outputs_to_make_channel_last=["X", "Y"],
+            )
+
+        self.assertIn("to be of rank >= 3", str(context.exception))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py b/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py
index 765825d4b86e3..97931acf03f42 100644
--- a/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py
+++ b/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py
@@ -122,9 +122,11 @@ def test_quantize_blockwise_4bits(self):
                                     dequantize_blockwise_4bits(
                                         quant_value_ref[c, k],
                                         scales_ref[c, k],
-                                        (zero_point_ref[c, k // 2] >> 4)
-                                        if (k & 1)
-                                        else (zero_point_ref[c, k // 2] & 0x0F),
+                                        (
+                                            (zero_point_ref[c, k // 2] >> 4)
+                                            if (k & 1)
+                                            else (zero_point_ref[c, k // 2] & 0x0F)
+                                        ),
                                         min(block_size, rows - k * block_size),
                                     ),
                                     dequantize_blockwise_4bits(
diff --git a/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py b/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py
new file mode 100644
index 0000000000000..2b5d1f36070e5
--- /dev/null
+++ b/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py
@@ -0,0 +1,92 @@
+#!/usr/bin/env python
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+import unittest
+
+import numpy as np
+import onnx
+import onnx.helper as oh
+import onnx.numpy_helper as onh
+
+from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
+from onnxruntime.quantization.quant_utils import QuantizationMode, QuantType
+
+
+class TestQuantizerShapeInference(unittest.TestCase):
+    def test_com_microsoft(self):
+        model = oh.make_model(
+            oh.make_graph(
+                [
+                    oh.make_node("MatMul", ["X", "W1"], ["T1"]),
+                    oh.make_node("FusedMatMul", ["T1", "W2"], ["T2"], domain="com.microsoft"),
+                    oh.make_node("MatMul", ["T2", "W3"], ["T3"]),
+                    oh.make_node("MatMul", ["T3", "W4"], ["Y"]),
+                ],
+                "name",
+                [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1, 4])],
+                [oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1, 4])],
+                [
+                    onh.from_array(np.random.randn(4, 4).astype(np.float32), "W1"),
+                    onh.from_array(np.random.randn(4, 4).astype(np.float32), "W2"),
+                    onh.from_array(np.random.randn(4, 4).astype(np.float32), "W3"),
+                    onh.from_array(np.random.randn(4, 4).astype(np.float32), "W4"),
+                ],
+            ),
+            opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("com.microsoft", 1)],
+        )
+        model_shaped = onnx.shape_inference.infer_shapes(model)
+        shaped_results = set(t.name for t in model_shaped.graph.value_info)
+        # every result after T1 depends on T2, which comes from a com.microsoft node;
+        # shape_inference cannot go beyond this point
+        self.assertEqual(shaped_results, {"T1"})
+
+        # first try: checks it raises an exception
+        quantizer = ONNXQuantizer(
+            model,
+            False,  # per_channel
+            False,  # reduce_range
+            QuantizationMode.IntegerOps,  # mode
+            False,  # static
+            QuantType.QInt8,  #  weight_type,
+            QuantType.QUInt8,  # dynamic activation only supports uint8
+            None,
+            [],  # nodes_to_quantize,
+            [],  # nodes_to_exclude
+            ["MatMul"],  # op_types_to_quantize,
+            {"MatMulConstBOnly": True},  # extra_options,
+            # {'DefaultTensorType': 1, }
+        )
+
+        with self.assertRaises(RuntimeError) as e:
+            quantizer.quantize_model()
+        # The assertion must run after the context manager exits, on the captured exception.
+        self.assertIn("Unable to find data type for weight_name=", str(e.exception))
+
+        # second try: checks it works
+        quantizer = ONNXQuantizer(
+            model,
+            False,  # per_channel
+            False,  # reduce_range
+            QuantizationMode.IntegerOps,  # mode
+            False,  # static
+            QuantType.QInt8,  #  weight_type,
+            QuantType.QUInt8,  # dynamic activation only supports uint8
+            None,
+            [],  # nodes_to_quantize,
+            [],  # nodes_to_exclude
+            ["MatMul"],  # op_types_to_quantize,
+            {
+                "MatMulConstBOnly": True,
+                "DefaultTensorType": 1,
+            },
+        )
+
+        model = quantizer.quantize_model()
+        ops = {n.op_type for n in model.graph.node}
+        self.assertEqual(ops, {"Cast", "FusedMatMul", "MatMulInteger", "DynamicQuantizeLinear", "Mul"})
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
diff --git a/onnxruntime/test/python/quantization/test_subgraph.py b/onnxruntime/test/python/quantization/test_subgraph.py
new file mode 100644
index 0000000000000..c425bf956f976
--- /dev/null
+++ b/onnxruntime/test/python/quantization/test_subgraph.py
@@ -0,0 +1,64 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import os
+import tempfile
+import unittest
+import urllib.request
+
+import onnx
+
+from onnxruntime.quantization import quantize_dynamic
+
+
+class TestDynamicQuantizationSubgraph(unittest.TestCase):
+    def test_dynamic_quantization_subgraph(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            onnx_path = os.path.join(tmpdir, "decoder_model_merged.onnx")
+            quantized_onnx_path = os.path.join(tmpdir, "decoder_model_merged_quantized.onnx")
+            urllib.request.urlretrieve(
+                "https://huggingface.co/fxmarty/t5-tiny-onnx-testing/resolve/main/decoder_model_merged.onnx", onnx_path
+            )
+
+            quantize_dynamic(
+                model_input=onnx_path,
+                model_output=quantized_onnx_path,
+                per_channel=True,
+                op_types_to_quantize=[
+                    "Conv",
+                    "MatMul",
+                    "Attention",
+                    "LSTM",
+                    "Gather",
+                    "Transpose",
+                    "EmbedLayerNormalization",
+                ],
+                extra_options={"EnableSubgraph": True},
+            )
+            model = onnx.load(quantized_onnx_path)
+
+            # The initializer `shared.weight_merged_0` is attached to the top-level graph and used by a Gather node in each subgraph.
+            # We expect the quantized Gather initializer (to which a DequantizeLinear is attached) to also be attached to the top-level graph.
+            found_gather_quantized = False
+            for initializer in model.graph.initializer:
+                if initializer.name == "shared.weight_merged_0_quantized":
+                    found_gather_quantized = True
+                    break
+            self.assertTrue(found_gather_quantized)
+
+            found_gather_scale = False
+            for initializer in model.graph.initializer:
+                if initializer.name == "shared.weight_merged_0_scale":
+                    found_gather_scale = True
+                    break
+            self.assertTrue(found_gather_scale)
+
+            # No initializers related to the Gather node should be attached to the subgraphs.
+            for node in model.graph.node:
+                for attr in node.attribute:
+                    if attr.type == onnx.AttributeProto.GRAPH:
+                        for initializer in attr.g.initializer:
+                            self.assertTrue("shared.weight" not in initializer.name)
diff --git a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py
index 0470953e385b6..9ea4719f3c595 100644
--- a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py
+++ b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py
@@ -13,7 +13,7 @@
 
 from onnxruntime import quantization
 from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config
-from onnxruntime.quantization.quant_utils import compute_scale_zp, get_qmin_qmax_for_qType
+from onnxruntime.quantization.quant_utils import compute_scale_zp, get_qmin_qmax_for_qType, ms_domain
 
 
 class DummyDataReader(quantization.CalibrationDataReader):
@@ -423,6 +423,36 @@ def test_qdq_overrides_per_channel2(self):
             self.assertEqual(zp, expected_zp)
             self.assertEqual(scale, np.float32(expected_scale))
 
+    def test_16bit_overrides_set_ms_domain(self):
+        """
+        Test that overriding a tensor to 16bit (when default is 8bit) automatically sets the 'com.microsoft'
+        domain on DQ and Q ops.
+        """
+        qdq_model_name = "model_quant_overrides_to_16bit.onnx"
+        inp_zp, _, sig_out_zp, _, _, _, _, _, out_zp, _ = self.perform_qdq_quantization(
+            qdq_model_name,
+            activation_type=onnx.TensorProto.UINT8,  # Default to 8bit activations
+            extra_options={
+                "TensorQuantOverrides": {
+                    "INP": [{"quant_type": quantization.QuantType.QUInt16}],
+                    "SIG_OUT": [{"quant_type": quantization.QuantType.QUInt16}],
+                }
+            },
+        )
+
+        # Input and Sigmoid's output should be overridden to 16bit
+        self.assertEqual(inp_zp.data_type, onnx.TensorProto.UINT16)
+        self.assertEqual(sig_out_zp.data_type, onnx.TensorProto.UINT16)
+
+        # Output should be the default uint8 type
+        self.assertEqual(out_zp.data_type, onnx.TensorProto.UINT8)
+
+        # Q/DQ ops should all have the 'com.microsoft' domain
+        qdq_model = onnx.load_model(qdq_model_name)
+        for node in qdq_model.graph.node:
+            if node.op_type in {"QuantizeLinear", "DequantizeLinear"}:
+                self.assertEqual(node.domain, ms_domain)
+
     def test_override_validation_nonexisting_tensor(self):
         """
         Test that specifying a non-existing tensor should fail.
@@ -555,6 +585,36 @@ def test_get_qnn_qdq_config(self):
         self.assertEqual(sig_out_zp.data_type, onnx.TensorProto.UINT16)
         self.assertEqual(sig_out_sc.float_data[0], np.float32(1.0 / 65536.0))
 
+    def test_get_qnn_qdq_config_ext_data(self):
+        """
+        Test that get_qnn_qdq_config() returns a config that enables external data
+        if the input model has external data.
+        """
+
+        # Create model with a weight large enough (> 1024 bytes) to be stored externally.
+        large_weight = onnx.numpy_helper.from_array(np.random.random((1, 32, 32)).astype(np.float32), "weight")
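+        # 1 x 32 x 32 float32 = 4096 bytes, above the default 1024-byte threshold for storing data externally.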
+        graph = onnx.helper.make_graph(
+            [onnx.helper.make_node("Add", ["input", "weight"], ["output"])],
+            "add_ext_data",
+            [onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, (1, 32, 32))],
+            [onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, (1, 32, 32))],
+            initializer=[large_weight],
+        )
+        model = onnx.helper.make_model(
+            graph,
+            opset_imports=[onnx.helper.make_opsetid("", 18)],
+        )
+        onnx.save_model(
+            model,
+            "add_ext_data.onnx",
+            save_as_external_data=True,
+            all_tensors_to_one_file=True,
+            location="add_ext_data.bin",
+        )
+
+        qnn_config = get_qnn_qdq_config("add_ext_data.onnx", DummyDataReader(self.activations))
+        self.assertTrue(qnn_config.use_external_data_format)
+
 
 if __name__ == "__main__":
     t = TestTensorQuantOverridesOption()
diff --git a/onnxruntime/test/python/transformers/bert_model_generator.py b/onnxruntime/test/python/transformers/bert_model_generator.py
index 9b9409545615b..a84137f092e64 100644
--- a/onnxruntime/test/python/transformers/bert_model_generator.py
+++ b/onnxruntime/test/python/transformers/bert_model_generator.py
@@ -94,12 +94,16 @@ def create_bert_attention(
             perm=[0, 2, 3, 1],
         ),
         # mask nodes
-        helper.make_node("Unsqueeze", ["input_mask", "axes_1"], ["unsqueeze0_out"], "unsqueeze0")
-        if has_unsqueeze_two_inputs
-        else helper.make_node("Unsqueeze", ["input_mask"], ["unsqueeze0_out"], "unsqueeze0", axes=[1]),
-        helper.make_node("Unsqueeze", ["unsqueeze0_out", "axes_2"], ["unsqueeze1_out"], "unsqueeze1")
-        if has_unsqueeze_two_inputs
-        else helper.make_node("Unsqueeze", ["unsqueeze0_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[2]),
+        (
+            helper.make_node("Unsqueeze", ["input_mask", "axes_1"], ["unsqueeze0_out"], "unsqueeze0")
+            if has_unsqueeze_two_inputs
+            else helper.make_node("Unsqueeze", ["input_mask"], ["unsqueeze0_out"], "unsqueeze0", axes=[1])
+        ),
+        (
+            helper.make_node("Unsqueeze", ["unsqueeze0_out", "axes_2"], ["unsqueeze1_out"], "unsqueeze1")
+            if has_unsqueeze_two_inputs
+            else helper.make_node("Unsqueeze", ["unsqueeze0_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[2])
+        ),
         # when attention_mask is float type, no need to cast
         helper.make_node("Cast", ["unsqueeze1_out"], ["cast_out"], "cast", to=1) if not use_float_mask else None,
         helper.make_node(
@@ -291,9 +295,11 @@ def create_tf2onnx_attention_3d(input_hidden_size=16, num_heads=4, head_size=4,
         helper.make_node("Add", ["einsum_k_out", "add_k_weight"], ["add_k_out"], "add_k"),
         helper.make_node("Mul", ["add_k_out", "mul_weight_1"], ["mul_k_out"], "mul_k"),
         # mask nodes
-        helper.make_node("Unsqueeze", ["input_mask", "axes_1"], ["unsqueeze0_out"], "unsqueeze0")
-        if has_unsqueeze_two_inputs
-        else helper.make_node("Unsqueeze", ["input_mask"], ["unsqueeze0_out"], "unsqueeze0", axes=[1, 2]),
+        (
+            helper.make_node("Unsqueeze", ["input_mask", "axes_1"], ["unsqueeze0_out"], "unsqueeze0")
+            if has_unsqueeze_two_inputs
+            else helper.make_node("Unsqueeze", ["input_mask"], ["unsqueeze0_out"], "unsqueeze0", axes=[1, 2])
+        ),
         helper.make_node(
             "Slice",
             ["unsqueeze0_out", "slice_start", "slice_end", "slice_axes", "slice_steps"],
diff --git a/onnxruntime/test/python/transformers/conformer_model_generator.py b/onnxruntime/test/python/transformers/conformer_model_generator.py
index 71e4f2b63cf4f..5b27a46ea0fdc 100644
--- a/onnxruntime/test/python/transformers/conformer_model_generator.py
+++ b/onnxruntime/test/python/transformers/conformer_model_generator.py
@@ -22,9 +22,7 @@ def get_tensor_and_weight(name: str, shape: List[int], random=False, zeros=False
     weights = (
         [np.random.uniform(low, high) for _ in range(total_elements)]
         if random
-        else [0.0] * total_elements
-        if zeros
-        else [1.0] * total_elements
+        else [0.0] * total_elements if zeros else [1.0] * total_elements
     )
     return helper.make_tensor(name, TensorProto.FLOAT, shape, weights), weights
 
diff --git a/onnxruntime/test/python/transformers/gpt2_model_generator.py b/onnxruntime/test/python/transformers/gpt2_model_generator.py
index 4a1b48d4d1b48..0865c87b70da7 100644
--- a/onnxruntime/test/python/transformers/gpt2_model_generator.py
+++ b/onnxruntime/test/python/transformers/gpt2_model_generator.py
@@ -41,15 +41,17 @@ def create_gpt2_attention(hidden_size=64, num_heads=4, max_seq_len=32, switch_ad
             ["fc_out"],
             "add_fc",
         ),
-        helper.make_node("Split", ["fc_out", "split_q_k_v"], ["q", "k", "v"], "split_qkv", axis=2)
-        if is_opset_13_or_newer
-        else helper.make_node(
-            "Split",
-            ["fc_out"],
-            ["q", "k", "v"],
-            "split_qkv",
-            axis=2,
-            split=[hidden_size, hidden_size, hidden_size],
+        (
+            helper.make_node("Split", ["fc_out", "split_q_k_v"], ["q", "k", "v"], "split_qkv", axis=2)
+            if is_opset_13_or_newer
+            else helper.make_node(
+                "Split",
+                ["fc_out"],
+                ["q", "k", "v"],
+                "split_qkv",
+                axis=2,
+                split=[hidden_size, hidden_size, hidden_size],
+            )
         ),
         # q nodes
         helper.make_node("Reshape", ["q", "reshape_x_shape"], ["reshape_q_out"], "reshape_q"),
@@ -79,19 +81,23 @@ def create_gpt2_attention(hidden_size=64, num_heads=4, max_seq_len=32, switch_ad
             perm=[0, 2, 1, 3],
         ),
         # past
-        helper.make_node("Split", ["past", "split_1_1"], ["split_k", "split_v"], "split_past", axis=0)
-        if is_opset_13_or_newer
-        else helper.make_node(
-            "Split",
-            ["past"],
-            ["split_k", "split_v"],
-            "split_past",
-            axis=0,
-            split=[1, 1],
+        (
+            helper.make_node("Split", ["past", "split_1_1"], ["split_k", "split_v"], "split_past", axis=0)
+            if is_opset_13_or_newer
+            else helper.make_node(
+                "Split",
+                ["past"],
+                ["split_k", "split_v"],
+                "split_past",
+                axis=0,
+                split=[1, 1],
+            )
+        ),
+        (
+            helper.make_node("Squeeze", ["split_k", "axes_0"], ["past_k"], "squeeze_past_k")
+            if is_opset_13_or_newer
+            else helper.make_node("Squeeze", ["split_k"], ["past_k"], "squeeze_past_k", axes=[0])
         ),
-        helper.make_node("Squeeze", ["split_k", "axes_0"], ["past_k"], "squeeze_past_k")
-        if is_opset_13_or_newer
-        else helper.make_node("Squeeze", ["split_k"], ["past_k"], "squeeze_past_k", axes=[0]),
         helper.make_node(
             "Concat",
             ["past_k", "transpose_k_out"],
@@ -106,9 +112,11 @@ def create_gpt2_attention(hidden_size=64, num_heads=4, max_seq_len=32, switch_ad
             "transpose_concat_k",
             perm=[0, 1, 3, 2],
         ),
-        helper.make_node("Squeeze", ["split_v", "axes_0"], ["past_v"], "squeeze_past_v")
-        if is_opset_13_or_newer
-        else helper.make_node("Squeeze", ["split_v"], ["past_v"], "squeeze_past_v", axes=[0]),
+        (
+            helper.make_node("Squeeze", ["split_v", "axes_0"], ["past_v"], "squeeze_past_v")
+            if is_opset_13_or_newer
+            else helper.make_node("Squeeze", ["split_v"], ["past_v"], "squeeze_past_v", axes=[0])
+        ),
         helper.make_node(
             "Concat",
             ["past_v", "transpose_v_out"],
@@ -117,33 +125,37 @@ def create_gpt2_attention(hidden_size=64, num_heads=4, max_seq_len=32, switch_ad
             axis=-2,
         ),
         # present
-        helper.make_node(
-            "Unsqueeze",
-            ["concat_k_out", "axes_0"],
-            ["concat_k_unsqueeze_out"],
-            "concat_k_unsqueeze",
-        )
-        if is_opset_13_or_newer
-        else helper.make_node(
-            "Unsqueeze",
-            ["concat_k_out"],
-            ["concat_k_unsqueeze_out"],
-            "concat_k_unsqueeze",
-            axes=[0],
+        (
+            helper.make_node(
+                "Unsqueeze",
+                ["concat_k_out", "axes_0"],
+                ["concat_k_unsqueeze_out"],
+                "concat_k_unsqueeze",
+            )
+            if is_opset_13_or_newer
+            else helper.make_node(
+                "Unsqueeze",
+                ["concat_k_out"],
+                ["concat_k_unsqueeze_out"],
+                "concat_k_unsqueeze",
+                axes=[0],
+            )
         ),
-        helper.make_node(
-            "Unsqueeze",
-            ["concat_v_out", "axes_0"],
-            ["concat_v_unsqueeze_out"],
-            "concat_v_unsqueeze",
-        )
-        if is_opset_13_or_newer
-        else helper.make_node(
-            "Unsqueeze",
-            ["concat_v_out"],
-            ["concat_v_unsqueeze_out"],
-            "concat_v_unsqueeze",
-            axes=[0],
+        (
+            helper.make_node(
+                "Unsqueeze",
+                ["concat_v_out", "axes_0"],
+                ["concat_v_unsqueeze_out"],
+                "concat_v_unsqueeze",
+            )
+            if is_opset_13_or_newer
+            else helper.make_node(
+                "Unsqueeze",
+                ["concat_v_out"],
+                ["concat_v_unsqueeze_out"],
+                "concat_v_unsqueeze",
+                axes=[0],
+            )
         ),
         helper.make_node(
             "Concat",
@@ -159,19 +171,21 @@ def create_gpt2_attention(hidden_size=64, num_heads=4, max_seq_len=32, switch_ad
             ["transpose_q_shape_slice_out"],
             "transpose_q_shape_slice",
         ),
-        helper.make_node(
-            "Squeeze",
-            ["transpose_q_shape_slice_out", "axes_0"],
-            ["transpose_q_shape_slice_squeeze_out"],
-            "transpose_q_shape_slice_squeeze",
-        )
-        if is_opset_13_or_newer
-        else helper.make_node(
-            "Squeeze",
-            ["transpose_q_shape_slice_out"],
-            ["transpose_q_shape_slice_squeeze_out"],
-            "transpose_q_shape_slice_squeeze",
-            axes=[0],
+        (
+            helper.make_node(
+                "Squeeze",
+                ["transpose_q_shape_slice_out", "axes_0"],
+                ["transpose_q_shape_slice_squeeze_out"],
+                "transpose_q_shape_slice_squeeze",
+            )
+            if is_opset_13_or_newer
+            else helper.make_node(
+                "Squeeze",
+                ["transpose_q_shape_slice_out"],
+                ["transpose_q_shape_slice_squeeze_out"],
+                "transpose_q_shape_slice_squeeze",
+                axes=[0],
+            )
         ),
         helper.make_node("Shape", ["concat_k_out"], ["concat_k_shape_out"], "concat_k_shape"),
         helper.make_node(
@@ -180,19 +194,21 @@ def create_gpt2_attention(hidden_size=64, num_heads=4, max_seq_len=32, switch_ad
             ["concat_k_shape_slice_out"],
             "concat_k_shape_slice",
         ),
-        helper.make_node(
-            "Squeeze",
-            ["concat_k_shape_slice_out", "axes_0"],
-            ["concat_k_shape_slice_squeeze_out"],
-            "concat_k_shape_slice_squeeze",
-        )
-        if is_opset_13_or_newer
-        else helper.make_node(
-            "Squeeze",
-            ["concat_k_shape_slice_out"],
-            ["concat_k_shape_slice_squeeze_out"],
-            "concat_k_shape_slice_squeeze",
-            axes=[0],
+        (
+            helper.make_node(
+                "Squeeze",
+                ["concat_k_shape_slice_out", "axes_0"],
+                ["concat_k_shape_slice_squeeze_out"],
+                "concat_k_shape_slice_squeeze",
+            )
+            if is_opset_13_or_newer
+            else helper.make_node(
+                "Squeeze",
+                ["concat_k_shape_slice_out"],
+                ["concat_k_shape_slice_squeeze_out"],
+                "concat_k_shape_slice_squeeze",
+                axes=[0],
+            )
         ),
         helper.make_node(
             "Sub",
@@ -200,22 +216,26 @@ def create_gpt2_attention(hidden_size=64, num_heads=4, max_seq_len=32, switch_ad
             ["sub_out"],
             "sub",
         ),
-        helper.make_node("Unsqueeze", ["sub_out", "axes_0"], ["sub_unsqueeze_out"], "sub_unsqueeze")
-        if is_opset_13_or_newer
-        else helper.make_node("Unsqueeze", ["sub_out"], ["sub_unsqueeze_out"], "sub_unsqueeze", axes=[0]),
-        helper.make_node(
-            "Unsqueeze",
-            ["concat_k_shape_slice_squeeze_out", "axes_0"],
-            ["concat_k_shape_slice_squeeze_unsqueeze_out"],
-            "concat_k_shape_slice_squeeze_unsqueeze",
-        )
-        if is_opset_13_or_newer
-        else helper.make_node(
-            "Unsqueeze",
-            ["concat_k_shape_slice_squeeze_out"],
-            ["concat_k_shape_slice_squeeze_unsqueeze_out"],
-            "concat_k_shape_slice_squeeze_unsqueeze",
-            axes=[0],
+        (
+            helper.make_node("Unsqueeze", ["sub_out", "axes_0"], ["sub_unsqueeze_out"], "sub_unsqueeze")
+            if is_opset_13_or_newer
+            else helper.make_node("Unsqueeze", ["sub_out"], ["sub_unsqueeze_out"], "sub_unsqueeze", axes=[0])
+        ),
+        (
+            helper.make_node(
+                "Unsqueeze",
+                ["concat_k_shape_slice_squeeze_out", "axes_0"],
+                ["concat_k_shape_slice_squeeze_unsqueeze_out"],
+                "concat_k_shape_slice_squeeze_unsqueeze",
+            )
+            if is_opset_13_or_newer
+            else helper.make_node(
+                "Unsqueeze",
+                ["concat_k_shape_slice_squeeze_out"],
+                ["concat_k_shape_slice_squeeze_unsqueeze_out"],
+                "concat_k_shape_slice_squeeze_unsqueeze",
+                axes=[0],
+            )
         ),
         helper.make_node(
             "Slice",
@@ -255,23 +275,27 @@ def create_gpt2_attention(hidden_size=64, num_heads=4, max_seq_len=32, switch_ad
             ["input_mask_reshape_out"],
             "input_mask_reshape",
         ),
-        helper.make_node(
-            "Unsqueeze",
-            ["input_mask_reshape_out", "axes_1"],
-            ["unsqueeze0_out"],
-            "unsqueeze0",
-        )
-        if is_opset_13_or_newer
-        else helper.make_node(
-            "Unsqueeze",
-            ["input_mask_reshape_out"],
-            ["unsqueeze0_out"],
-            "unsqueeze0",
-            axes=[1],
+        (
+            helper.make_node(
+                "Unsqueeze",
+                ["input_mask_reshape_out", "axes_1"],
+                ["unsqueeze0_out"],
+                "unsqueeze0",
+            )
+            if is_opset_13_or_newer
+            else helper.make_node(
+                "Unsqueeze",
+                ["input_mask_reshape_out"],
+                ["unsqueeze0_out"],
+                "unsqueeze0",
+                axes=[1],
+            )
+        ),
+        (
+            helper.make_node("Unsqueeze", ["unsqueeze0_out", "axes_2"], ["unsqueeze1_out"], "unsqueeze1")
+            if is_opset_13_or_newer
+            else helper.make_node("Unsqueeze", ["unsqueeze0_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[2])
         ),
-        helper.make_node("Unsqueeze", ["unsqueeze0_out", "axes_2"], ["unsqueeze1_out"], "unsqueeze1")
-        if is_opset_13_or_newer
-        else helper.make_node("Unsqueeze", ["unsqueeze0_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[2]),
         helper.make_node("Sub", ["sub_weight", "unsqueeze1_out"], ["mask_sub_out"], "sub_mask"),
         helper.make_node("Mul", ["mask_sub_out", "mul_weight"], ["mul_mask_out"], "mul_mask"),
         # qk nodes
@@ -322,33 +346,37 @@ def create_gpt2_attention(hidden_size=64, num_heads=4, max_seq_len=32, switch_ad
             ["qkv_shape_slice_out"],
             "qkv_shape_slice",
         ),
-        helper.make_node(
-            "Squeeze",
-            ["qkv_shape_slice_out", "axes_0"],
-            ["qkv_shape_slice_squeeze_out"],
-            "qkv_shape_slice_squeeze",
-        )
-        if is_opset_13_or_newer
-        else helper.make_node(
-            "Squeeze",
-            ["qkv_shape_slice_out"],
-            ["qkv_shape_slice_squeeze_out"],
-            "qkv_shape_slice_squeeze",
-            axes=[0],
+        (
+            helper.make_node(
+                "Squeeze",
+                ["qkv_shape_slice_out", "axes_0"],
+                ["qkv_shape_slice_squeeze_out"],
+                "qkv_shape_slice_squeeze",
+            )
+            if is_opset_13_or_newer
+            else helper.make_node(
+                "Squeeze",
+                ["qkv_shape_slice_out"],
+                ["qkv_shape_slice_squeeze_out"],
+                "qkv_shape_slice_squeeze",
+                axes=[0],
+            )
         ),
-        helper.make_node(
-            "Unsqueeze",
-            ["qkv_shape_slice_squeeze_out", "axes_0"],
-            ["qkv_shape_slice_squeeze_unsqueeze_out"],
-            "qkv_shape_slice_squeeze_unsqueeze",
-        )
-        if is_opset_13_or_newer
-        else helper.make_node(
-            "Unsqueeze",
-            ["qkv_shape_slice_squeeze_out"],
-            ["qkv_shape_slice_squeeze_unsqueeze_out"],
-            "qkv_shape_slice_squeeze_unsqueeze",
-            axes=[0],
+        (
+            helper.make_node(
+                "Unsqueeze",
+                ["qkv_shape_slice_squeeze_out", "axes_0"],
+                ["qkv_shape_slice_squeeze_unsqueeze_out"],
+                "qkv_shape_slice_squeeze_unsqueeze",
+            )
+            if is_opset_13_or_newer
+            else helper.make_node(
+                "Unsqueeze",
+                ["qkv_shape_slice_squeeze_out"],
+                ["qkv_shape_slice_squeeze_unsqueeze_out"],
+                "qkv_shape_slice_squeeze_unsqueeze",
+                axes=[0],
+            )
         ),
         helper.make_node(
             "Concat",
@@ -387,33 +415,37 @@ def create_gpt2_attention(hidden_size=64, num_heads=4, max_seq_len=32, switch_ad
             "shape_qkv_gather_0",
             axis=0,
         ),
-        helper.make_node(
-            "Unsqueeze",
-            ["qkv_shape_1", "axes_0"],
-            ["qkv_shape_1_unsqueeze_out"],
-            "qkv_shape_1_unsqueeze",
-        )
-        if is_opset_13_or_newer
-        else helper.make_node(
-            "Unsqueeze",
-            ["qkv_shape_1"],
-            ["qkv_shape_1_unsqueeze_out"],
-            "qkv_shape_1_unsqueeze",
-            axes=[0],
+        (
+            helper.make_node(
+                "Unsqueeze",
+                ["qkv_shape_1", "axes_0"],
+                ["qkv_shape_1_unsqueeze_out"],
+                "qkv_shape_1_unsqueeze",
+            )
+            if is_opset_13_or_newer
+            else helper.make_node(
+                "Unsqueeze",
+                ["qkv_shape_1"],
+                ["qkv_shape_1_unsqueeze_out"],
+                "qkv_shape_1_unsqueeze",
+                axes=[0],
+            )
         ),
-        helper.make_node(
-            "Unsqueeze",
-            ["qkv_shape_0", "axes_0"],
-            ["qkv_shape_0_unsqueeze_out"],
-            "qkv_shape_0_unsqueeze",
-        )
-        if is_opset_13_or_newer
-        else helper.make_node(
-            "Unsqueeze",
-            ["qkv_shape_0"],
-            ["qkv_shape_0_unsqueeze_out"],
-            "qkv_shape_0_unsqueeze",
-            axes=[0],
+        (
+            helper.make_node(
+                "Unsqueeze",
+                ["qkv_shape_0", "axes_0"],
+                ["qkv_shape_0_unsqueeze_out"],
+                "qkv_shape_0_unsqueeze",
+            )
+            if is_opset_13_or_newer
+            else helper.make_node(
+                "Unsqueeze",
+                ["qkv_shape_0"],
+                ["qkv_shape_0_unsqueeze_out"],
+                "qkv_shape_0_unsqueeze",
+                axes=[0],
+            )
         ),
         helper.make_node(
             "Concat",
@@ -767,9 +799,11 @@ def create_gpt2_fused_embedlayer(
                 "",
                 "ids",
             ],
-            ["EmbedLayerNormalization_0_output", "EmbedLayerNormalization_0_dummy_mask_index", "embedding_sum"]
-            if output_embedding_sum
-            else ["EmbedLayerNormalization_0_output", "EmbedLayerNormalization_0_dummy_mask_index"],
+            (
+                ["EmbedLayerNormalization_0_output", "EmbedLayerNormalization_0_dummy_mask_index", "embedding_sum"]
+                if output_embedding_sum
+                else ["EmbedLayerNormalization_0_output", "EmbedLayerNormalization_0_dummy_mask_index"]
+            ),
             "EmbedLayerNormalization_0",
             domain="com.microsoft",
             epsilon=epsilon,
diff --git a/onnxruntime/test/python/transformers/sharded_moe/test_sharded_moe.py b/onnxruntime/test/python/transformers/sharded_moe/test_sharded_moe.py
index af835d2906e87..ec64f2359f4be 100644
--- a/onnxruntime/test/python/transformers/sharded_moe/test_sharded_moe.py
+++ b/onnxruntime/test/python/transformers/sharded_moe/test_sharded_moe.py
@@ -24,25 +24,17 @@ def get_size():
     return comm.Get_size()
 
 
-def barrier():
-    comm.Barrier()
-
-
 def print_out(*args):
     if get_rank() == 0:
         print(*args)
 
 
-def broadcast(data):
-    comm = MPI.COMM_WORLD
-    comm.broadcast(data, root=0)
-
-
 local_rank = get_rank()
 
 ORT_DTYPE = TensorProto.FLOAT16
 NP_TYPE = np.float16 if ORT_DTYPE == TensorProto.FLOAT16 else np.float32
-THRESHOLD = 1e-3
+THRESHOLD_TP = 3e-2
+THRESHOLD_EP = 1e-6
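+# Tensor parallelism slices the FFN inner dimension, so fp16 partial sums accumulate in a different order
+# than in the unsharded model; expert parallelism keeps each expert intact and can use a much tighter tolerance.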
 
 
 def create_moe_onnx_graph(
@@ -52,51 +44,64 @@ def create_moe_onnx_graph(
     hidden_size,
     inter_size,
     fc1_experts_weights,
-    fc2_experts_weights,
     fc1_experts_bias,
+    fc2_experts_weights,
     fc2_experts_bias,
-    local_experts_start_index=-1,
+    fc3_experts_weights,
+    local_experts_start_index=0,
+    topk=2,
+    normalize_routing_weights=1,
+    activation_type="gelu",
+    tensor_shards=1,
 ):
-    use_sharded_moe = local_experts_start_index >= 0
+    use_sharded_moe = num_experts > local_num_experts or tensor_shards > 1
     nodes = [
-        helper.make_node(
-            "MoE",
-            [
-                "input",
-                "router_probs",
-                "fc1_experts_weights",
-                "fc2_experts_weights",
-                "fc1_experts_bias",
-                "fc2_experts_bias",
-            ],
-            ["output"],
-            "MoE_0",
-            k=1,
-            activation_type="gelu",
-            domain="com.microsoft",
-        )
-        if not use_sharded_moe
-        else helper.make_node(
-            "ShardedMoE",
-            [
-                "input",
-                "router_probs",
-                "fc1_experts_weights",
-                "fc2_experts_weights",
-                "fc1_experts_bias",
-                "fc2_experts_bias",
-            ],
-            ["output"],
-            "MoE_0",
-            k=1,
-            activation_type="gelu",
-            local_experts_start_index=local_experts_start_index,
-            domain="com.microsoft",
+        (
+            helper.make_node(
+                "MoE",
+                [
+                    "input",
+                    "router_probs",
+                    "fc1_experts_weights",
+                    "fc1_experts_bias",
+                    "fc2_experts_weights",
+                    "fc2_experts_bias",
+                    "fc3_experts_weights",
+                ],
+                ["output"],
+                "MoE_0",
+                k=topk,
+                normalize_routing_weights=normalize_routing_weights,
+                activation_type=activation_type,
+                domain="com.microsoft",
+            )
+            if not use_sharded_moe
+            else helper.make_node(
+                "ShardedMoE",
+                [
+                    "input",
+                    "router_probs",
+                    "fc1_experts_weights",
+                    "fc1_experts_bias",
+                    "fc2_experts_weights",
+                    "fc2_experts_bias",
+                    "fc3_experts_weights",
+                ],
+                ["output"],
+                "MoE_0",
+                k=topk,
+                normalize_routing_weights=normalize_routing_weights,
+                activation_type=activation_type,
+                local_experts_start_index=local_experts_start_index,
+                tensor_shards=tensor_shards,
+                domain="com.microsoft",
+            )
         ),
     ]
 
     fc1_shape = [local_num_experts, hidden_size, inter_size]
     fc2_shape = [local_num_experts, inter_size, hidden_size]
+    fc3_shape = fc1_shape
 
     initializers = [
         helper.make_tensor(
@@ -113,6 +118,13 @@ def create_moe_onnx_graph(
             fc2_experts_weights.flatten(),
             raw=False,
         ),
+        helper.make_tensor(
+            "fc3_experts_weights",
+            ORT_DTYPE,
+            fc3_shape,
+            fc3_experts_weights.flatten(),
+            raw=False,
+        ),
     ]
 
     fc1_bias_shape = [local_num_experts, inter_size]
@@ -164,18 +176,18 @@ def create_moe_onnx_graph(
     return model.SerializeToString()
 
 
-def test_moe_with_expert_slicing(
+def generate_weights_and_initial_model(
+    num_rows,
+    num_experts,
     hidden_size,
     inter_size,
-    num_experts,
-    num_rows,
 ):
-    local_experts_start_index = local_rank * num_experts // get_size()
-
-    fc1_experts_weights_all = np.random.rand(num_experts, hidden_size, inter_size).astype(NP_TYPE)
-    fc2_experts_weights_all = np.random.rand(num_experts, inter_size, hidden_size).astype(NP_TYPE)
-    fc1_experts_bias_all = np.random.rand(num_experts, inter_size).astype(NP_TYPE)
-    fc2_experts_bias_all = np.random.rand(num_experts, hidden_size).astype(NP_TYPE)
+    s = 0.1
+    fc1_experts_weights_all = np.random.normal(scale=s, size=(num_experts, hidden_size, inter_size)).astype(NP_TYPE)
+    fc2_experts_weights_all = np.random.normal(scale=s, size=(num_experts, inter_size, hidden_size)).astype(NP_TYPE)
+    fc3_experts_weights_all = np.random.normal(scale=s, size=(num_experts, hidden_size, inter_size)).astype(NP_TYPE)
+    fc1_experts_bias_all = np.random.normal(scale=s, size=(num_experts, inter_size)).astype(NP_TYPE)
+    fc2_experts_bias_all = np.random.normal(scale=s, size=(num_experts, hidden_size)).astype(NP_TYPE)
 
     onnx_model_full = create_moe_onnx_graph(
         num_rows,
@@ -184,34 +196,31 @@ def test_moe_with_expert_slicing(
         hidden_size,
         inter_size,
         fc1_experts_weights_all,
-        fc2_experts_weights_all,
         fc1_experts_bias_all,
+        fc2_experts_weights_all,
         fc2_experts_bias_all,
+        fc3_experts_weights_all,
     )
 
-    fc1_experts_weights = fc1_experts_weights_all[
-        local_experts_start_index : local_experts_start_index + num_experts // get_size(), :, :
-    ]
-    fc2_experts_weights = fc2_experts_weights_all[
-        local_experts_start_index : local_experts_start_index + num_experts // get_size(), :, :
-    ]
-    fc1_experts_bias = fc1_experts_bias_all[
-        local_experts_start_index : local_experts_start_index + num_experts // get_size(), :
-    ]
-
-    onnx_model_local = create_moe_onnx_graph(
-        num_rows,
-        num_experts,
-        num_experts // get_size(),
-        hidden_size,
-        inter_size,
-        fc1_experts_weights,
-        fc2_experts_weights,
-        fc1_experts_bias,
+    return (
+        onnx_model_full,
+        fc1_experts_weights_all,
+        fc1_experts_bias_all,
+        fc2_experts_weights_all,
         fc2_experts_bias_all,
-        local_experts_start_index,
+        fc3_experts_weights_all,
     )
 
+
+def run_ort_with_parity_check(
+    onnx_model_full,
+    onnx_model_local,
+    num_rows,
+    hidden_size,
+    num_experts,
+    inter_size,
+    threshold,
+):
     sess_options = onnxruntime.SessionOptions()
     cuda_provider_options = {"device_id": local_rank}
     execution_providers = [("CUDAExecutionProvider", cuda_provider_options)]
@@ -227,30 +236,161 @@ def test_moe_with_expert_slicing(
     output = ort_session.run(None, ort_inputs)
     sharded_output = ort_session_local.run(None, ort_inputs)
 
-    assert np.allclose(output[0], sharded_output[0], atol=THRESHOLD, rtol=THRESHOLD)
+    print_out("max diff:", np.max(np.abs(output[0] - sharded_output[0])))
+    assert np.allclose(output[0], sharded_output[0], atol=threshold, rtol=threshold)
 
     print_out(
-        "hidden_size: ",
+        "hidden_size:",
         hidden_size,
-        " inter_size: ",
+        " inter_size:",
         inter_size,
-        " num_experts: ",
+        " num_experts:",
         num_experts,
-        " num_rows: ",
+        " num_rows:",
         num_rows,
-        " world_size: ",
+        " world_size:",
         get_size(),
         " Parity: OK",
     )
 
 
+def test_moe_with_tensor_parallelism(
+    hidden_size,
+    inter_size,
+    num_experts,
+    num_rows,
+    threshold=THRESHOLD_TP,
+):
+    assert inter_size % get_size() == 0
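+    # Tensor parallelism: each rank takes a contiguous slice of the inter (FFN) dimension of fc1/fc2/fc3.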
+
+    (
+        onnx_model_full,
+        fc1_experts_weights_all,
+        fc1_experts_bias_all,
+        fc2_experts_weights_all,
+        fc2_experts_bias_all,
+        fc3_experts_weights_all,
+    ) = generate_weights_and_initial_model(
+        num_rows,
+        num_experts,
+        hidden_size,
+        inter_size,
+    )
+
+    fc1_experts_weights = fc1_experts_weights_all[
+        :, :, local_rank * inter_size // get_size() : (local_rank + 1) * inter_size // get_size()
+    ]
+    fc2_experts_weights = fc2_experts_weights_all[
+        :, local_rank * inter_size // get_size() : (local_rank + 1) * inter_size // get_size(), :
+    ]
+    fc3_experts_weights = fc3_experts_weights_all[
+        :, :, local_rank * inter_size // get_size() : (local_rank + 1) * inter_size // get_size()
+    ]
+    fc1_experts_bias = fc1_experts_bias_all[
+        :, local_rank * inter_size // get_size() : (local_rank + 1) * inter_size // get_size()
+    ]
+
+    onnx_model_local = create_moe_onnx_graph(
+        num_rows,
+        num_experts,
+        num_experts,
+        hidden_size,
+        inter_size // get_size(),
+        fc1_experts_weights,
+        fc1_experts_bias,
+        fc2_experts_weights,
+        fc2_experts_bias_all,
+        fc3_experts_weights,
+        tensor_shards=get_size(),
+    )
+
+    run_ort_with_parity_check(
+        onnx_model_full,
+        onnx_model_local,
+        num_rows,
+        hidden_size,
+        num_experts,
+        inter_size,
+        threshold,
+    )
+
+
+def test_moe_with_expert_parallelism(
+    hidden_size,
+    inter_size,
+    num_experts,
+    num_rows,
+    threshold=THRESHOLD_EP,
+):
+    local_experts_start_index = local_rank * num_experts // get_size()
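+    # Expert parallelism: each rank owns a contiguous block of num_experts // world_size experts.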
+
+    (
+        onnx_model_full,
+        fc1_experts_weights_all,
+        fc1_experts_bias_all,
+        fc2_experts_weights_all,
+        fc2_experts_bias_all,
+        fc3_experts_weights_all,
+    ) = generate_weights_and_initial_model(
+        num_rows,
+        num_experts,
+        hidden_size,
+        inter_size,
+    )
+
+    fc1_experts_weights = fc1_experts_weights_all[
+        local_experts_start_index : local_experts_start_index + num_experts // get_size(), :, :
+    ]
+    fc2_experts_weights = fc2_experts_weights_all[
+        local_experts_start_index : local_experts_start_index + num_experts // get_size(), :, :
+    ]
+    fc3_experts_weights = fc3_experts_weights_all[
+        local_experts_start_index : local_experts_start_index + num_experts // get_size(), :, :
+    ]
+    fc1_experts_bias = fc1_experts_bias_all[
+        local_experts_start_index : local_experts_start_index + num_experts // get_size(), :
+    ]
+
+    onnx_model_local = create_moe_onnx_graph(
+        num_rows,
+        num_experts,
+        num_experts // get_size(),
+        hidden_size,
+        inter_size,
+        fc1_experts_weights,
+        fc1_experts_bias,
+        fc2_experts_weights,
+        fc2_experts_bias_all,
+        fc3_experts_weights,
+        local_experts_start_index,
+    )
+
+    run_ort_with_parity_check(
+        onnx_model_full,
+        onnx_model_local,
+        num_rows,
+        hidden_size,
+        num_experts,
+        inter_size,
+        threshold,
+    )
+
+
 class TestMoE(unittest.TestCase):
-    def test_moe_expert_slicing(self):
-        for hidden_size in [16, 128]:
-            for inter_size in [512, 1024]:
-                for num_experts in [8, 16, 32]:
-                    for num_rows in [16, 128, 512]:
-                        test_moe_with_expert_slicing(
+    def test_moe_parallelism(self):
+        for hidden_size in [128, 1024]:
+            for inter_size in [512, 2048]:
+                for num_experts in [64]:
+                    for num_rows in [1024]:
+                        print_out("EP")
+                        test_moe_with_expert_parallelism(
+                            hidden_size,
+                            inter_size,
+                            num_experts,
+                            num_rows,
+                        )
+                        print_out("TP")
+                        test_moe_with_tensor_parallelism(
                             hidden_size,
                             inter_size,
                             num_experts,
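The two tests above shard one set of MoE weights in two different ways: tensor parallelism splits the fc1/fc3 tensors (and the fc1 bias) along inter_size, with fc2 split along its matching input dimension, while expert parallelism gives each rank a contiguous block of whole experts. A minimal NumPy sketch of that slicing arithmetic, with made-up sizes and plain integers standing in for the test's MPI rank/size helpers:

# Standalone illustration of the weight sharding used above (hypothetical sizes, no MPI).
import numpy as np

num_experts, hidden_size, inter_size = 8, 16, 32
world_size, rank = 4, 1  # stand-ins for get_size() and local_rank

fc1 = np.random.rand(num_experts, hidden_size, inter_size)
fc2 = np.random.rand(num_experts, inter_size, hidden_size)

# Tensor parallelism: every rank keeps a slice of the intermediate dimension.
shard = inter_size // world_size
fc1_tp = fc1[:, :, rank * shard : (rank + 1) * shard]  # (8, 16, 8)
fc2_tp = fc2[:, rank * shard : (rank + 1) * shard, :]  # (8, 8, 16)

# Expert parallelism: every rank keeps a contiguous block of whole experts.
experts_per_rank = num_experts // world_size
start = rank * experts_per_rank
fc1_ep = fc1[start : start + experts_per_rank, :, :]  # (2, 16, 32)

print(fc1_tp.shape, fc2_tp.shape, fc1_ep.shape)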
diff --git a/onnxruntime/test/python/transformers/test_data/bert_squad_tensorflow2.1_keras2onnx_opset11/generate_tiny_keras2onnx_bert_models.py b/onnxruntime/test/python/transformers/test_data/bert_squad_tensorflow2.1_keras2onnx_opset11/generate_tiny_keras2onnx_bert_models.py
index c42c42c3ca170..0086ce0d289c7 100644
--- a/onnxruntime/test/python/transformers/test_data/bert_squad_tensorflow2.1_keras2onnx_opset11/generate_tiny_keras2onnx_bert_models.py
+++ b/onnxruntime/test/python/transformers/test_data/bert_squad_tensorflow2.1_keras2onnx_opset11/generate_tiny_keras2onnx_bert_models.py
@@ -403,9 +403,7 @@ def generate_test_data(
         evalTime = timeit.default_timer() - start_time  # noqa: N806
         if outputs[0].tolist() != result[0].tolist():
             print(
-                "Error: not same result after optimization. use_cpu={}, no_opt_output={}, opt_output={}".format(
-                    use_cpu, result[0].tolist(), outputs[1].tolist()
-                )
+                f"Error: not same result after optimization. use_cpu={use_cpu}, no_opt_output={result[0].tolist()}, opt_output={outputs[1].tolist()}"
             )
         print(f"** Evaluation done in total {evalTime} secs")
 
diff --git a/onnxruntime/test/python/transformers/test_flash_attn.py b/onnxruntime/test/python/transformers/test_flash_attn.py
index 90d28872d3cc8..b784c83329c76 100644
--- a/onnxruntime/test/python/transformers/test_flash_attn.py
+++ b/onnxruntime/test/python/transformers/test_flash_attn.py
@@ -229,9 +229,11 @@ def create_group_query_attention_graph_prompt(
             [
                 config.batch_size,
                 config.q_sequence_length,
-                (config.num_heads * config.head_size)
-                if not packed
-                else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size),
+                (
+                    (config.num_heads * config.head_size)
+                    if not packed
+                    else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size)
+                ),
             ],
         ),
         helper.make_tensor_value_info(
@@ -415,9 +417,11 @@ def create_group_query_attention_graph_past(
             [
                 config.batch_size,
                 config.sequence_length,
-                (config.num_heads * config.head_size)
-                if not packed
-                else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size),
+                (
+                    (config.num_heads * config.head_size)
+                    if not packed
+                    else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size)
+                ),
             ],
         ),
         helper.make_tensor_value_info(
diff --git a/onnxruntime/test/python/transformers/test_generation.py b/onnxruntime/test/python/transformers/test_generation.py
index 40ea8cf774918..33ec1bd7728fe 100644
--- a/onnxruntime/test/python/transformers/test_generation.py
+++ b/onnxruntime/test/python/transformers/test_generation.py
@@ -381,22 +381,23 @@ def test_logits_processor(self):
 
     @pytest.mark.slow
     def test_cross_qk_overall(self):
-        decoder_input_ids = [
-            "--chain_model",
-            "--collect_cross_qk",
-            "--output_cross_qk",
-            "--use_forced_decoder_ids",
-            "--extra_decoding_ids",
-            "--output_no_speech_probs",
+        cross_qk_input_args = [
             "--use_vocab_mask",
             "--use_prefix_vocab_mask",
+            "--use_forced_decoder_ids",
             "--use_logits_processor",
+            "--collect_cross_qk",
+            "--extra_decoding_ids",
         ]
-        self.run_configs(decoder_input_ids)
+        cross_qk_output_args = [
+            "--output_cross_qk",
+            "--output_no_speech_probs",
+        ]
+        self.run_configs(cross_qk_input_args + cross_qk_output_args)
 
     @pytest.mark.slow
     def test_openai_impl_whisper(self):
-        optional_args = ["--model_impl", "openai", "--chain_model", "--use_whisper_beamsearch"]
+        optional_args = ["--model_impl", "openai"]
         self.run_configs(optional_args)
 
 
diff --git a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py
new file mode 100644
index 0000000000000..90b7da255081a
--- /dev/null
+++ b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py
@@ -0,0 +1,365 @@
+# --------------------------------------------------------------------------
+# Copyright 2020 The HuggingFace Inc. team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+# --------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation.  All rights reserved.
+# Licensed under the MIT License.  See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+import unittest
+from collections import OrderedDict
+
+import numpy
+import torch
+import torch.nn.functional as F
+from onnx import TensorProto, helper
+from torch import nn
+
+import onnxruntime
+
+torch.manual_seed(42)
+numpy.random.seed(42)
+
+ORT_DTYPE = TensorProto.FLOAT
+NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32
+THRESHOLD = 3e-2
+
+
+def value_string_of(numpy_array):
+    arr = numpy_array.flatten()
+    lines = ["f, ".join([str(v) for v in arr[i : min(i + 8, arr.size)]]) for i in range(0, arr.size, 8)]
+    return "{\n    " + "f,\n    ".join(lines) + "f}"
+
+
+def print_tensor(name, numpy_array):
+    print(f"const std::vector<float> {name} = {value_string_of(numpy_array)};")
+
+
+def create_moe_onnx_graph(
+    num_rows,
+    num_experts,
+    hidden_size,
+    inter_size,
+    fc1_experts_weights,
+    fc2_experts_weights,
+    fc3_experts_weights,
+    topk,
+):
+    nodes = [
+        helper.make_node(
+            "MoE",
+            [
+                "input",
+                "router_probs",
+                "fc1_experts_weights",
+                "",
+                "fc2_experts_weights",
+                "",
+                "fc3_experts_weights",
+            ],
+            ["output"],
+            "MoE_0",
+            k=topk,
+            normalize_routing_weights=1,
+            activation_type="silu",
+            domain="com.microsoft",
+        ),
+    ]
+
+    fc1_shape = [num_experts, hidden_size, inter_size]
+    fc2_shape = [num_experts, inter_size, hidden_size]
+    fc3_shape = [num_experts, hidden_size, inter_size]
+
+    torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32
+
+    initializers = [
+        helper.make_tensor(
+            "fc1_experts_weights",
+            ORT_DTYPE,
+            fc1_shape,
+            fc1_experts_weights.to(torch_type).flatten().tolist(),
+            raw=False,
+        ),
+        helper.make_tensor(
+            "fc2_experts_weights",
+            ORT_DTYPE,
+            fc2_shape,
+            fc2_experts_weights.to(torch_type).flatten().tolist(),
+            raw=False,
+        ),
+        helper.make_tensor(
+            "fc3_experts_weights",
+            ORT_DTYPE,
+            fc3_shape,
+            fc3_experts_weights.to(torch_type).flatten().tolist(),
+            raw=False,
+        ),
+    ]
+
+    graph_inputs = [
+        helper.make_tensor_value_info("input", ORT_DTYPE, [num_rows, hidden_size]),
+    ]
+
+    graph_inputs.append(
+        helper.make_tensor_value_info(
+            "router_probs",
+            ORT_DTYPE,
+            [num_rows, num_experts],
+        )
+    )
+
+    graph_outputs = [
+        helper.make_tensor_value_info("output", ORT_DTYPE, [num_rows, hidden_size]),
+    ]
+
+    graph = helper.make_graph(
+        nodes,
+        "MoE_Graph",
+        graph_inputs,
+        graph_outputs,
+        initializers,
+    )
+
+    model = helper.make_model(graph)
+    return model.SerializeToString()
+
+
+class ClassInstantier(OrderedDict):
+    def __getitem__(self, key):
+        content = super().__getitem__(key)
+        cls, kwargs = content if isinstance(content, tuple) else (content, {})
+        return cls(**kwargs)
+
+
+ACT2CLS = {
+    "silu": nn.SiLU,
+}
+ACT2FN = ClassInstantier(ACT2CLS)
+
+
+class MixtralConfig:
+    def __init__(
+        self,
+        hidden_size=4096,
+        intermediate_size=14336,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=8,
+        hidden_act="silu",
+        initializer_range=0.02,
+        rms_norm_eps=1e-5,
+        use_cache=True,
+        rope_theta=1e6,
+        attention_dropout=0.0,
+        num_experts_per_tok=2,
+        num_local_experts=8,
+        output_router_logits=False,
+        router_aux_loss_coef=0.001,
+    ):
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.attention_dropout = attention_dropout
+        self.num_experts_per_tok = num_experts_per_tok
+        self.num_local_experts = num_local_experts
+        self.output_router_logits = output_router_logits
+        self.router_aux_loss_coef = router_aux_loss_coef
+
+
+class MixtralBlockSparseTop2MLP(nn.Module):
+    def __init__(self, config: MixtralConfig):
+        super().__init__()
+        self.ffn_dim = config.intermediate_size
+        self.hidden_dim = config.hidden_size
+
+        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
+        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
+        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
+
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, hidden_states):
+        current_hidden_states_1 = self.act_fn(self.w1(hidden_states))
+        current_hidden_states_3 = self.w3(hidden_states)
+        current_hidden_states = current_hidden_states_1 * current_hidden_states_3
+        current_hidden_states = self.w2(current_hidden_states)
+        return current_hidden_states
+
+
+class MixtralSparseMoeBlock(nn.Module):
+    """
+    This implementation is strictly equivalent to standard MoE with full
+    capacity (no dropped tokens). It is faster since it formulates MoE
+    operations in terms of block-sparse operations to accommodate imbalanced
+    assignments of tokens to experts, whereas standard MoE either (1) drops
+    tokens at the cost of reduced performance or (2) sets the capacity factor
+    to the number of experts and thus wastes computation and memory on padding.
+    """
+
+    def __init__(self, config, batch_size, sequence_length):
+        super().__init__()
+        self.hidden_dim = config.hidden_size
+        self.ffn_dim = config.intermediate_size
+        self.num_experts = config.num_local_experts
+        self.top_k = config.num_experts_per_tok
+
+        # gating
+        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+
+        self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
+
+        w1_list = []
+        w2_list = []
+        w3_list = []
+        for i in range(self.num_experts):
+            w1_list.append(self.experts[i].w1.weight.transpose(0, 1))
+            w2_list.append(self.experts[i].w2.weight.transpose(0, 1))
+            w3_list.append(self.experts[i].w3.weight.transpose(0, 1))
+
+        self.moe_experts_weight1 = torch.stack(w1_list, dim=0)
+        self.moe_experts_weight2 = torch.stack(w2_list, dim=0)
+        self.moe_experts_weight3 = torch.stack(w3_list, dim=0)
+
+        self.batch_size = batch_size
+        self.sequence_length = sequence_length
+        self.moe_onnx_graph = create_moe_onnx_graph(
+            self.batch_size * self.sequence_length,
+            self.num_experts,
+            self.hidden_dim,
+            self.ffn_dim,
+            self.moe_experts_weight1,
+            self.moe_experts_weight2,
+            self.moe_experts_weight3,
+            self.top_k,
+        )
+
+        self.ort_sess = self.create_ort_session()
+
+    def create_ort_session(self):
+        from onnxruntime import InferenceSession, SessionOptions
+
+        sess_options = SessionOptions()
+
+        cuda_providers = ["CUDAExecutionProvider"]
+        if cuda_providers[0] not in onnxruntime.get_available_providers():
+            return None
+
+        sess_options.log_severity_level = 2
+        ort_session = InferenceSession(self.moe_onnx_graph, sess_options, providers=["CUDAExecutionProvider"])
+
+        return ort_session
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """ """
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+        )
+
+        # One-hot encode the selected experts to create an expert mask;
+        # this will be used to easily index which expert is going to be solicited.
+        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+        # Loop over all available experts in the model and perform the computation on each expert
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx])
+
+            if top_x.shape[0] == 0:
+                continue
+
+            # in torch it is faster to index using lists than torch tensors
+            top_x_list = top_x.tolist()
+            idx_list = idx.tolist()
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
+
+            # However, `index_add_` only supports torch tensors for indexing, so we use
+            # the `top_x` tensor here.
+            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states  # , router_logits
+
+    def ort_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        ort_inputs = {
+            "input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(NP_TYPE)),
+            "router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(NP_TYPE)),
+        }
+
+        ort_output = None
+        if self.ort_sess is not None:
+            ort_output = self.ort_sess.run(None, ort_inputs)
+            return torch.tensor(ort_output).reshape(batch_size, sequence_length, -1)  # , router_logits
+
+        # print_tensor("input", ort_inputs["input"])
+        # print_tensor("router_probs", ort_inputs["router_probs"])
+        # print_tensor("fc1_experts_weights", self.moe_experts_weight1.detach().numpy())
+        # print_tensor("fc2_experts_weights", self.moe_experts_weight2.detach().numpy())
+        # print_tensor("fc3_experts_weights", self.moe_experts_weight3.detach().numpy())
+        # print_tensor("output", ort_output[0])
+
+        return None
+
+    def parity_check(self):
+        hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim)
+        torch_output = self.forward(hidden_state)
+        ort_output = self.ort_forward(hidden_state)
+        if ort_output is not None:
+            assert torch.allclose(torch_output, ort_output, rtol=1e-04, atol=1e-04)
+            print(
+                "batch_size:",
+                self.batch_size,
+                " sequence_length:",
+                self.sequence_length,
+                " max_diff:",
+                (torch_output - ort_output).abs().max(),
+                " parity: OK",
+            )
+
+
+class TestMixtralMoE(unittest.TestCase):
+    def test_mixtral_moe_parity(self):
+        for batch_size in [1, 16]:
+            for sequence_length in [128, 1024]:
+                # use small sizes to speed up the test
+                config = MixtralConfig(hidden_size=256, intermediate_size=1024)
+                mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length)
+                mixtral_moe.parity_check()
+
+
+if __name__ == "__main__":
+    unittest.main()
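For reference, the gating that MixtralSparseMoeBlock.forward implements above (softmax over the router logits, top-2 selection, renormalization of the kept weights, one-hot expert mask) can be reproduced on toy data in a few lines of PyTorch; the sizes below are hypothetical and unrelated to the test configuration:

# Minimal sketch of the top-2 routing used by the sparse MoE block (toy sizes).
import torch
import torch.nn.functional as F

num_tokens, num_experts, top_k = 4, 8, 2
router_logits = torch.randn(num_tokens, num_experts)

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)  # kept weights sum to 1 per token

# expert_mask[e, k, t] == 1 when expert e is the k-th choice for token t.
expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
print(selected_experts)
print(routing_weights.sum(dim=-1))  # each row is 1.0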
diff --git a/onnxruntime/test/python/transformers/test_parity_moe.py b/onnxruntime/test/python/transformers/test_parity_moe.py
index 72ca5d9975c05..dbf6ee7dabb0e 100644
--- a/onnxruntime/test/python/transformers/test_parity_moe.py
+++ b/onnxruntime/test/python/transformers/test_parity_moe.py
@@ -47,8 +47,8 @@ def create_moe_onnx_graph(
     hidden_size,
     inter_size,
     fc1_experts_weights,
-    fc2_experts_weights,
     fc1_experts_bias,
+    fc2_experts_weights,
     fc2_experts_bias,
 ):
     nodes = [
@@ -58,8 +58,8 @@ def create_moe_onnx_graph(
                 "input",
                 "router_probs",
                 "fc1_experts_weights",
-                "fc2_experts_weights",
                 "fc1_experts_bias",
+                "fc2_experts_weights",
                 "fc2_experts_bias",
             ],
             ["output"],
@@ -250,8 +250,8 @@ def __init__(
             in_features,
             hidden_features,
             self.moe_experts.weight1,
-            self.moe_experts.weight2,
             self.moe_experts.bias1,
+            self.moe_experts.weight2,
             self.moe_experts.bias2,
         )
 
@@ -296,8 +296,6 @@ def ort_run_with_iobinding(self, ort_inputs, repeat=1000):
             ).data_ptr(),
         )
 
-        iobinding.synchronize_inputs()
-
         iobinding.bind_output(
             name="output",
             device_type="cuda",
@@ -308,11 +306,12 @@ def ort_run_with_iobinding(self, ort_inputs, repeat=1000):
                 numpy.zeros(ort_inputs["input"].shape), "cuda", device_id
             ).data_ptr(),
         )
-        iobinding.synchronize_outputs()
 
         s = time.time()
         for _ in range(repeat):
+            iobinding.synchronize_inputs()
             self.ort_sess.run_with_iobinding(iobinding)
+            iobinding.synchronize_outputs()
         e = time.time()
         print(f"MoE cuda kernel time: {(e - s) / repeat * 1000} ms")
 
@@ -356,8 +355,8 @@ def onnx_forward(self, iobinding=False):
         # print_tensor("input", ort_inputs["input"])
         # print_tensor("router_probs", ort_inputs["router_probs"])
         # print_tensor("fc1_experts_weights", self.moe_experts.weight1.detach().numpy())
-        # print_tensor("fc2_experts_weights", self.moe_experts.weight2.detach().numpy())
         # print_tensor("fc1_experts_bias", self.moe_experts.bias1.detach().numpy())
+        # print_tensor("fc2_experts_weights", self.moe_experts.weight2.detach().numpy())
         # print_tensor("fc2_experts_bias", self.moe_experts.bias2.detach().numpy())
         # print_tensor("output", ort_output[0])
 
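The hunk above moves the input/output synchronization into the timed loop so each iteration measures a complete transfer-run-transfer cycle. A self-contained sketch of that timing pattern, using a throwaway Identity model on the CPU provider so it runs without CUDA (the synchronize calls are then no-ops but the structure is the same):

# Sketch: time run_with_iobinding with per-iteration synchronization (CPU, toy model).
import time
import numpy as np
from onnx import TensorProto, helper
import onnxruntime

graph = helper.make_graph(
    [helper.make_node("Identity", ["X"], ["Y"])],
    "identity_graph",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 2])],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
sess = onnxruntime.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])

binding = sess.io_binding()
x = np.ones((2, 2), dtype=np.float32)
binding.bind_ortvalue_input("X", onnxruntime.OrtValue.ortvalue_from_numpy(x, "cpu", 0))
binding.bind_output("Y", "cpu")

start = time.time()
for _ in range(10):
    binding.synchronize_inputs()   # flush pending input copies (no-op on CPU)
    sess.run_with_iobinding(binding)
    binding.synchronize_outputs()  # wait for the outputs before reading the clock
print(f"avg latency: {(time.time() - start) / 10 * 1000:.3f} ms")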
diff --git a/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py b/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py
index 77ce09d7e793b..7892000ae45a0 100644
--- a/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py
+++ b/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py
@@ -50,7 +50,7 @@ def run_timestamp(self, provider: str):
         ort_out = sess.run(None, ort_inputs)
         ort_out_tensor = torch.from_numpy(ort_out[0])
         ort_transcription = processor.batch_decode(
-            ort_out_tensor[0][0].view(1, -1), skip_special_tokens=True, output_offsets=True
+            ort_out_tensor[0][0].view(1, -1), skip_special_tokens=True, output_offsets=True, decode_with_timestamps=True
         )
         print(ort_transcription)
         expected_transcription = [
@@ -58,7 +58,7 @@ def run_timestamp(self, provider: str):
                 "text": "<|0.00|> Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|5.44|>",
                 "offsets": [
                     {
-                        "text": "<|0.00|> Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|5.44|>",
+                        "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
                         "timestamp": (0.0, 5.44),
                     }
                 ],
diff --git a/onnxruntime/test/python/transformers/whisper_model_generator.py b/onnxruntime/test/python/transformers/whisper_model_generator.py
index 71d1a4cbdceeb..a57b45cbc5ea3 100644
--- a/onnxruntime/test/python/transformers/whisper_model_generator.py
+++ b/onnxruntime/test/python/transformers/whisper_model_generator.py
@@ -22,9 +22,7 @@ def get_tensor_and_weight(name: str, shape: List[int], random=False, zeros=False
     weights = (
         [np.random.uniform(low, high) for _ in range(total_elements)]
         if random
-        else [0.0] * total_elements
-        if zeros
-        else [1.0] * total_elements
+        else [0.0] * total_elements if zeros else [1.0] * total_elements
     )
     return helper.make_tensor(name, TensorProto.FLOAT, shape, weights), weights
 
diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc
index 8dad2c8e2d10d..52dd2a84e383b 100644
--- a/onnxruntime/test/shared_lib/test_inference.cc
+++ b/onnxruntime/test/shared_lib/test_inference.cc
@@ -180,6 +180,9 @@ static void TestInference(Ort::Env& env, const std::basic_string<ORTCHAR_T>& mod
 }
 
 static constexpr PATH_TYPE MODEL_URI = TSTR("testdata/mul_1.onnx");
+#if defined(USE_CUDA)
+static constexpr PATH_TYPE CUDA_GRAPH_ANNOTATION_MODEL_URI = TSTR("testdata/mul_1_dynamic.onnx");
+#endif
 static constexpr PATH_TYPE MATMUL_MODEL_URI = TSTR("testdata/matmul_1.onnx");
 #ifndef ORT_NO_RTTI
 static constexpr PATH_TYPE SEQUENCE_MODEL_URI = TSTR("testdata/sequence_length.onnx");
@@ -205,7 +208,7 @@ static constexpr PATH_TYPE MODEL_WITH_CUSTOM_MODEL_METADATA = TSTR("testdata/mod
 static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/VariedInputCustomOp.onnx");
 static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_3.onnx");
 static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_bar_1.onnx");
-static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_bar_2.onnx");
+static constexpr PATH_TYPE OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_bar_2.onnx");
 static constexpr PATH_TYPE VARIADIC_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/custom_op_variadic_io.onnx");
 static constexpr PATH_TYPE VARIADIC_UNDEF_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR(
     "testdata/custom_op_variadic_undef_io.onnx");
@@ -1079,7 +1082,7 @@ TEST(CApiTest, invalid_variadic_input_homogeneity_custom_op) {
   }
 }
 
-TEST(CApiTest, optional_input_output_custom_op_handler) {
+TEST(CApiTest, optional_input_custom_op_handler) {
   MyCustomOpWithOptionalInput custom_op{onnxruntime::kCpuExecutionProvider};
 
   // `MyCustomOpFooBar` defines a custom op with atmost 3 inputs and the second input is optional.
@@ -1144,7 +1147,7 @@ TEST(CApiTest, optional_input_output_custom_op_handler) {
   {
     std::vector<const char*> input_names = {"X1", "X2"};
     ort_inputs.erase(ort_inputs.begin() + 2);  // remove the last input in the container
-    Ort::Session session(*ort_env, OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI_2, session_options);
+    Ort::Session session(*ort_env, OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2, session_options);
     auto ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
                                    &output_name, 1);
     ASSERT_EQ(ort_outputs.size(), 1u);
@@ -1163,6 +1166,7 @@ TEST(CApiTest, optional_input_output_custom_op_handler) {
     }
   }
 }
+
 TEST(CApiTest, custom_op_with_attributes_handler) {
   MyCustomOpWithAttributes custom_op{onnxruntime::kCpuExecutionProvider};
 
@@ -2082,6 +2086,152 @@ TEST(CApiTest, basic_cuda_graph) {
 #endif
 }
 
+#if defined(USE_CUDA)
+struct CudaGraphInputOutputData_0 {
+  const std::array<int64_t, 2> x_shape = {3, 2};
+  std::array<float, 3 * 2> x_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  const std::array<int64_t, 2> expected_y_shape = {3, 2};
+  std::array<float, 3 * 2> expected_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f};
+
+  std::array<float, 3 * 2> y_values;
+  std::array<float, 3 * 2> new_x_values = {10.0f, 20.0f, 30.0f, 40.0f, 50.0f, 60.0f};
+  std::array<float, 3 * 2> new_expected_y = {10.0f, 40.0f, 90.0f, 160.0f, 250.0f, 360.0f};
+} cg_data_0;
+
+struct CudaGraphInputOutputData_1 {
+  const std::array<int64_t, 2> x_shape = {3, 1};
+  std::array<float, 3> x_values = {1.0f, 3.0f, 5.0f};
+  const std::array<int64_t, 2> expected_y_shape = {3, 2};
+  std::array<float, 3 * 2> expected_y = {1.0f, 2.0f, 9.0f, 12.0f, 25.0f, 30.0f};
+
+  std::array<float, 3 * 2> y_values;
+  std::array<float, 3> new_x_values = {10.0f, 30.0f, 50.0f};
+  std::array<float, 3 * 2> new_expected_y = {10.0f, 20.0f, 90.0f, 120.0f, 250.0f, 300.0f};
+} cg_data_1;
+
+struct CudaGraphInputOutputData_2 {
+  const std::array<int64_t, 2> x_shape = {1, 2};
+  std::array<float, 3 * 2> x_values = {1.0f, 2.0f};
+  const std::array<int64_t, 2> expected_y_shape = {3, 2};
+  std::array<float, 3 * 2> expected_y = {1.0f, 4.0f, 3.0f, 8.0f, 5.0f, 12.0f};
+
+  std::array<float, 3 * 2> y_values;
+  std::array<float, 3 * 2> new_x_values = {10.0f, 20.0f};
+  std::array<float, 3 * 2> new_expected_y = {10.0f, 40.0f, 30.0f, 80.0f, 50.0f, 120.0f};
+} cg_data_2;
+
+template <typename T>
+static void RunWithCudaGraphAnnotation(T& cg_data,
+                                       Ort::Session& session,
+                                       Ort::MemoryInfo& info_mem,
+                                       Ort::MemoryAllocation& input_data,
+                                       Ort::MemoryAllocation& output_data,
+                                       const char* cuda_graph_annotation) {
+  (void)cudaMemcpy(input_data.get(),
+                   cg_data.x_values.data(),
+                   sizeof(float) * cg_data.x_values.size(),
+                   cudaMemcpyHostToDevice);
+
+  // Create an OrtValue tensor backed by data on CUDA memory
+  Ort::Value bound_x = Ort::Value::CreateTensor(info_mem,
+                                                reinterpret_cast<float*>(input_data.get()),
+                                                cg_data.x_values.size(),
+                                                cg_data.x_shape.data(),
+                                                cg_data.x_shape.size());
+
+  // Create an OrtValue tensor backed by data on CUDA memory
+  Ort::Value bound_y = Ort::Value::CreateTensor(info_mem,
+                                                reinterpret_cast<float*>(output_data.get()),
+                                                cg_data.expected_y.size(),
+                                                cg_data.expected_y_shape.data(),
+                                                cg_data.expected_y_shape.size());
+
+  // Create IoBinding for inputs and outputs.
+  Ort::IoBinding binding(session);
+  binding.BindInput("X", bound_x);
+  binding.BindOutput("Y", bound_y);
+
+  Ort::RunOptions run_option;
+  if (cuda_graph_annotation != nullptr) {
+    run_option.AddConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation, cuda_graph_annotation);
+  }
+
+  // One regular run for necessary memory allocation and graph capturing
+  session.Run(run_option, binding);
+
+  // Check the values against the bound raw memory (needs copying from device to host first)
+  (void)cudaMemcpy(cg_data.y_values.data(),
+                   output_data.get(),
+                   sizeof(float) * cg_data.y_values.size(),
+                   cudaMemcpyDeviceToHost);
+  ASSERT_THAT(cg_data.y_values, ::testing::ContainerEq(cg_data.expected_y));
+
+  // Replay the captured CUDA graph
+  session.Run(run_option, binding);
+  (void)cudaMemcpy(cg_data.y_values.data(),
+                   output_data.get(),
+                   sizeof(float) * cg_data.y_values.size(),
+                   cudaMemcpyDeviceToHost);
+  ASSERT_THAT(cg_data.y_values, ::testing::ContainerEq(cg_data.expected_y));
+
+  // Change the input and replay the CUDA graph again.
+  (void)cudaMemcpy(input_data.get(),
+                   cg_data.new_x_values.data(),
+                   sizeof(float) * cg_data.new_x_values.size(),
+                   cudaMemcpyHostToDevice);
+  binding.SynchronizeInputs();
+
+  session.Run(run_option, binding);
+  (void)cudaMemcpy(cg_data.y_values.data(),
+                   output_data.get(),
+                   sizeof(float) * cg_data.y_values.size(),
+                   cudaMemcpyDeviceToHost);
+  ASSERT_THAT(cg_data.y_values, ::testing::ContainerEq(cg_data.new_expected_y));
+
+  // Clean up
+  binding.ClearBoundInputs();
+  binding.ClearBoundOutputs();
+}
+
+TEST(CApiTest, basic_cuda_graph_with_annotation) {
+  const auto& api = Ort::GetApi();
+  Ort::SessionOptions session_options;
+
+  // Enable cuda graph in cuda provider option.
+  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
+  ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr);
+  std::unique_ptr<OrtCUDAProviderOptionsV2, decltype(api.ReleaseCUDAProviderOptions)>
+      rel_cuda_options(cuda_options, api.ReleaseCUDAProviderOptions);
+  std::vector<const char*> keys{"enable_cuda_graph"};
+  std::vector<const char*> values{"1"};
+  ASSERT_TRUE(api.UpdateCUDAProviderOptions(rel_cuda_options.get(), keys.data(), values.data(), 1) == nullptr);
+
+  ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_CUDA_V2(
+                  static_cast<OrtSessionOptions*>(session_options),
+                  rel_cuda_options.get()) == nullptr);
+
+  Ort::Session session(*ort_env, CUDA_GRAPH_ANNOTATION_MODEL_URI, session_options);
+  Ort::MemoryInfo info_mem("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault);
+
+  Ort::Allocator allocator(session, info_mem);
+  auto allocator_info = allocator.GetInfo();
+  ASSERT_TRUE(info_mem == allocator_info);
+
+  size_t max_input_size = 6;
+  size_t max_output_size = 6;
+
+  auto input_data = allocator.GetAllocation(max_input_size * sizeof(float));
+  auto output_data = allocator.GetAllocation(max_output_size * sizeof(float));
+
+  ASSERT_NE(input_data.get(), nullptr);
+  ASSERT_NE(output_data.get(), nullptr);
+
+  RunWithCudaGraphAnnotation(cg_data_0, session, info_mem, input_data, output_data, nullptr);
+  RunWithCudaGraphAnnotation(cg_data_1, session, info_mem, input_data, output_data, "1");
+  RunWithCudaGraphAnnotation(cg_data_2, session, info_mem, input_data, output_data, "2");
+}
+#endif
+
 // The following test uses some ops not supported in the reduced ops build
 #ifndef REDUCED_OPS_BUILD
 #if defined(USE_CUDA) || defined(USE_TENSORRT)
@@ -3858,3 +4008,34 @@ TEST(CApiTest, RunAsyncFail) {
   Ort::RunOptions run_options;
   EXPECT_THROW(session.RunAsync(run_options, input_names, input_tensors, 1, output_names, output_values, 1, CallbackFail, nullptr), std::exception);
 }
+
+struct MockGQA : public OrtCustomOp {
+  MockGQA() {
+    OrtCustomOp::GetMayInplace = [](int** input_index, int** output_index) {
+      size_t ret = 2;
+      *input_index = static_cast<int*>(malloc(ret * sizeof(int)));
+      (*input_index)[0] = 3;
+      (*input_index)[1] = 4;
+      *output_index = static_cast<int*>(malloc(ret * sizeof(int)));
+      (*output_index)[0] = 1;
+      (*output_index)[1] = 2;
+      return ret;
+    };
+  }
+};
+
+TEST(CApiTest, OrtCustomOp_GetInPlace) {
+  MockGQA mock_gqa;
+  int* input_index = nullptr;
+  int* output_index = nullptr;
+  size_t len = mock_gqa.GetMayInplace(&input_index, &output_index);
+  ASSERT_NE(input_index, nullptr);
+  ASSERT_NE(output_index, nullptr);
+  ASSERT_EQ(input_index[0], 3);
+  ASSERT_EQ(input_index[1], 4);
+  ASSERT_EQ(output_index[0], 1);
+  ASSERT_EQ(output_index[1], 2);
+  ASSERT_EQ(len, static_cast<size_t>(2));
+  free(input_index);
+  free(output_index);
+}
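A rough Python-side counterpart of the annotated CUDA graph test above, shown only as a sketch: it assumes a CUDA-enabled build, that testdata/mul_1_dynamic.onnx (added by this change) is on disk, and that kOrtRunOptionsConfigCudaGraphAnnotation corresponds to the run-option key "gpu_graph_id".

# Hedged sketch: CUDA graph capture/replay with an annotation id from Python.
import numpy as np
import onnxruntime

if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
    providers = [("CUDAExecutionProvider", {"enable_cuda_graph": "1"})]
    sess = onnxruntime.InferenceSession("testdata/mul_1_dynamic.onnx", providers=providers)

    x = np.arange(1, 7, dtype=np.float32).reshape(3, 2)
    x_ort = onnxruntime.OrtValue.ortvalue_from_numpy(x, "cuda", 0)
    y_ort = onnxruntime.OrtValue.ortvalue_from_numpy(np.zeros_like(x), "cuda", 0)

    binding = sess.io_binding()
    binding.bind_ortvalue_input("X", x_ort)
    binding.bind_ortvalue_output("Y", y_ort)

    run_options = onnxruntime.RunOptions()
    run_options.add_run_config_entry("gpu_graph_id", "1")  # assumed key for the annotation id

    sess.run_with_iobinding(binding, run_options)  # first run allocates and captures the graph
    sess.run_with_iobinding(binding, run_options)  # second run replays it
    print(y_ort.numpy())  # should match cg_data_0.expected_y in the C++ test above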
diff --git a/onnxruntime/test/testdata/custom_op_library/custom_op_test_float8.py b/onnxruntime/test/testdata/custom_op_library/custom_op_test_float8.py
index 84cf71455f84a..6db8e8fe660f8 100644
--- a/onnxruntime/test/testdata/custom_op_library/custom_op_test_float8.py
+++ b/onnxruntime/test/testdata/custom_op_library/custom_op_test_float8.py
@@ -1,6 +1,7 @@
 """
 This file was used to generate model `custom_op_test_float8.py`.
 """
+
 from onnx import TensorProto
 from onnx.checker import check_model
 from onnx.helper import make_graph, make_model, make_node, make_opsetid, make_tensor_value_info
diff --git a/onnxruntime/test/testdata/matmul_integer_to_float.py b/onnxruntime/test/testdata/matmul_integer_to_float.py
index b898390044cf4..e6c51009018f9 100644
--- a/onnxruntime/test/testdata/matmul_integer_to_float.py
+++ b/onnxruntime/test/testdata/matmul_integer_to_float.py
@@ -4,7 +4,7 @@
 from onnx import TensorProto, helper
 
 
-def GenerateModel(model_name, sign_i, sign_w, has_zp=True, bias=False):  # noqa: N802
+def GenerateModel(model_name, sign_i, sign_w, output_type_fp16, has_zp=True, bias=False):  # noqa: N802
     nodes = [  # subgraph
         helper.make_node(
             "MatMulInteger",
@@ -13,7 +13,13 @@ def GenerateModel(model_name, sign_i, sign_w, has_zp=True, bias=False):  # noqa:
             "MatMulInteger",
         ),
         helper.make_node("Mul", ["a_scale", "b_scale"], ["multiplier"], "mul_right"),
-        helper.make_node("Cast", ["matmul_output_int32"], ["matmul_output_float"], "cast", to=1),
+        helper.make_node(
+            "Cast",
+            ["matmul_output_int32"],
+            ["matmul_output_float"],
+            "cast",
+            to=TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT,
+        ),
         helper.make_node(
             "Mul",
             ["matmul_output_float", "multiplier"],
@@ -25,8 +31,8 @@ def GenerateModel(model_name, sign_i, sign_w, has_zp=True, bias=False):  # noqa:
     inputs = [  # inputs
         helper.make_tensor_value_info("A", TensorProto.INT8 if sign_i else TensorProto.UINT8, ["M", "K"]),
         helper.make_tensor_value_info("B", TensorProto.INT8 if sign_w else TensorProto.UINT8, ["K", "N"]),
-        helper.make_tensor_value_info("a_scale", TensorProto.FLOAT, [1]),
-        helper.make_tensor_value_info("b_scale", TensorProto.FLOAT, ["C"]),
+        helper.make_tensor_value_info("a_scale", TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT, [1]),
+        helper.make_tensor_value_info("b_scale", TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT, ["C"]),
     ]
 
     if has_zp:
@@ -48,14 +54,22 @@ def GenerateModel(model_name, sign_i, sign_w, has_zp=True, bias=False):  # noqa:
     if bias:
         nodes.extend([helper.make_node("Add", ["mul_bottom_output", "bias"], ["Y"], "add")])
 
-        inputs.extend([helper.make_tensor_value_info("bias", TensorProto.FLOAT, ["N"])])
+        inputs.extend(
+            [
+                helper.make_tensor_value_info(
+                    "bias", TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT, ["N"]
+                )
+            ]
+        )
 
     graph = helper.make_graph(
         nodes,
         "DynamicQuantizeMatMul_fusion",  # name
         inputs,
         [  # outputs
-            helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["M", "N"]),
+            helper.make_tensor_value_info(
+                "Y", TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT, ["M", "N"]
+            ),
         ],
     )
 
@@ -64,10 +78,32 @@ def GenerateModel(model_name, sign_i, sign_w, has_zp=True, bias=False):  # noqa:
 
 
 if __name__ == "__main__":
-    GenerateModel("matmul_integer_to_float_int8.onnx", False, True)
-    GenerateModel("matmul_integer_to_float_uint8.onnx", False, False)
-    GenerateModel("matmul_integer_to_float_int8_bias.onnx", False, True, False, True)
-    GenerateModel("matmul_integer_to_float_uint8_bias.onnx", False, False, False, True)
+    GenerateModel("matmul_integer_to_float16_int8.onnx", sign_i=False, sign_w=True, output_type_fp16=True)
+    GenerateModel("matmul_integer_to_float_int8.onnx", sign_i=False, sign_w=True, output_type_fp16=False)
+    GenerateModel("matmul_integer_to_float_uint8.onnx", sign_i=False, sign_w=False, output_type_fp16=False)
+    GenerateModel(
+        "matmul_integer_to_float_int8_bias.onnx",
+        sign_i=False,
+        sign_w=True,
+        output_type_fp16=False,
+        has_zp=False,
+        bias=True,
+    )
+    GenerateModel(
+        "matmul_integer_to_float_uint8_bias.onnx",
+        sign_i=False,
+        sign_w=False,
+        output_type_fp16=False,
+        has_zp=False,
+        bias=True,
+    )
 
-    GenerateModel("matmul_integer_to_float_int8_int8.onnx", True, True)
-    GenerateModel("matmul_integer_to_float_int8_int8_bias.onnx", True, True, False, True)
+    GenerateModel("matmul_integer_to_float_int8_int8.onnx", sign_i=True, sign_w=True, output_type_fp16=False)
+    GenerateModel(
+        "matmul_integer_to_float_int8_int8_bias.onnx",
+        sign_i=True,
+        sign_w=True,
+        output_type_fp16=False,
+        has_zp=False,
+        bias=True,
+    )
diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_int8.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_int8.onnx
index 9f4465a914963..906dec542a4fa 100644
--- a/onnxruntime/test/testdata/matmul_integer_to_float_int8.onnx
+++ b/onnxruntime/test/testdata/matmul_integer_to_float_int8.onnx
@@ -1,4 +1,4 @@
-:Ì
+	:Ì
 U
 A
 B
@@ -44,4 +44,4 @@ mul_bottom"MulDynamicQuantizeMatMul_fusionZ
 
 
 M
-NB
\ No newline at end of file
+NB
\ No newline at end of file
diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_int8_bias.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_int8_bias.onnx
index 01b7e15aa4a1f..16cdf03c7ae59 100644
--- a/onnxruntime/test/testdata/matmul_integer_to_float_int8_bias.onnx
+++ b/onnxruntime/test/testdata/matmul_integer_to_float_int8_bias.onnx
@@ -1,4 +1,4 @@
-:Ä
+	:Ä
 9
 A
 Bmatmul_output_int32
MatMulInteger"
MatMulInteger
@@ -41,4 +41,4 @@ mul_bottom"Mul
 
 
 M
-NB
\ No newline at end of file
+NB
\ No newline at end of file
diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8.onnx
index 9d38828e25d6a..55102757a0b57 100644
--- a/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8.onnx
+++ b/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8.onnx
@@ -1,4 +1,4 @@
-:Ì
+	:Ì
 U
 A
 B
@@ -44,4 +44,4 @@ mul_bottom"MulDynamicQuantizeMatMul_fusionZ
 
 
 M
-NB
\ No newline at end of file
+NB
\ No newline at end of file
diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8_bias.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8_bias.onnx
index 4d9a55af50a87..d9d7222a1acaa 100644
--- a/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8_bias.onnx
+++ b/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8_bias.onnx
@@ -1,4 +1,4 @@
-:Ä
+	:Ä
 9
 A
 Bmatmul_output_int32
MatMulInteger"
MatMulInteger
@@ -41,4 +41,4 @@ mul_bottom"Mul
 
 
 M
-NB
\ No newline at end of file
+NB
\ No newline at end of file
diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_uint8.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_uint8.onnx
index a4c6d20d59be8..5373ce145688e 100644
--- a/onnxruntime/test/testdata/matmul_integer_to_float_uint8.onnx
+++ b/onnxruntime/test/testdata/matmul_integer_to_float_uint8.onnx
@@ -1,4 +1,4 @@
-:Ì
+	:Ì
 U
 A
 B
@@ -44,4 +44,4 @@ mul_bottom"MulDynamicQuantizeMatMul_fusionZ
 
 
 M
-NB
\ No newline at end of file
+NB
\ No newline at end of file
diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_uint8_bias.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_uint8_bias.onnx
index a5be0c63f4dcb..e407414b23b24 100644
--- a/onnxruntime/test/testdata/matmul_integer_to_float_uint8_bias.onnx
+++ b/onnxruntime/test/testdata/matmul_integer_to_float_uint8_bias.onnx
@@ -1,4 +1,4 @@
-:Ä
+	:Ä
 9
 A
 Bmatmul_output_int32
MatMulInteger"
MatMulInteger
@@ -41,4 +41,4 @@ mul_bottom"Mul
 
 
 M
-NB
\ No newline at end of file
+NB
\ No newline at end of file
diff --git a/onnxruntime/test/testdata/mul_1_dynamic.onnx b/onnxruntime/test/testdata/mul_1_dynamic.onnx
new file mode 100644
index 0000000000000..fb7822498b004
Binary files /dev/null and b/onnxruntime/test/testdata/mul_1_dynamic.onnx differ
diff --git a/onnxruntime/test/testdata/multi_stream_models/issue_19480.onnx b/onnxruntime/test/testdata/multi_stream_models/issue_19480.onnx
new file mode 100644
index 0000000000000..dc7d39206dd49
Binary files /dev/null and b/onnxruntime/test/testdata/multi_stream_models/issue_19480.onnx differ
diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
index ca089c42032b1..f120bf9968558 100644
--- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
+++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
@@ -493,9 +493,11 @@
         "test_range_float_type_positive_delta_expanded_cpu", // Error but not a failure.
         "test_scan_sum_cpu", // Disabled due to output mismatch with tolerance.
         "test_scan9_sum_cpu", // Disabled due to output mismatch with tolerance.
-        "test_reduce_max_bool_inputs_cpu"
+        "test_reduce_max_bool_inputs_cpu",
+        "test_gelu_default_1_cpu", // Disabled due to accuracy mismatch
+        "test_gelu_default_2_cpu"
     ],
-    "current_failing_tests_OPENVINO_NPU_FP16": [
+    "current_failing_tests_OPENVINO_NPU": [
         "^test_prelu_broadcast",
         "test_loop11_cpu"
     ],
diff --git a/onnxruntime/test/testdata/ort_github_issue_19590.onnx b/onnxruntime/test/testdata/ort_github_issue_19590.onnx
new file mode 100644
index 0000000000000..fa07b624780bb
Binary files /dev/null and b/onnxruntime/test/testdata/ort_github_issue_19590.onnx differ
diff --git a/onnxruntime/test/testdata/ort_github_issue_19590.py b/onnxruntime/test/testdata/ort_github_issue_19590.py
new file mode 100644
index 0000000000000..9be07134fd8ad
--- /dev/null
+++ b/onnxruntime/test/testdata/ort_github_issue_19590.py
@@ -0,0 +1,77 @@
+import onnx
+from onnx import TensorProto, helper
+
+# Graph with a QDQ MatMul node unit where one input is an initializer -> DQ and the other is on a path that
+# contains a supported node followed by an unsupported node followed by the DQ -> MatMul.
+# The DQ of the initializer is prior to the unsupported node. If the partitioning utils do not process the QDQ node
+# unit together, the DQ for the initializer and the first supported node will be in the first partition, which
+# incorrectly breaks up the QDQ node unit.
+graph_proto = helper.make_graph(
+    [
+        # DQ of initializer for MatMul B input
+        helper.make_node(
+            "DequantizeLinear",
+            inputs=["matmul_b_uint8", "scale0"],
+            outputs=["dq_matmul_b"],
+            name="dq_matmul_b",
+        ),
+        # Treat as supported
+        helper.make_node(
+            "Mul",
+            inputs=["input:0", "scale_input"],
+            outputs=["mul:0"],
+            name="mul0",
+        ),
+        # Treat as unsupported
+        helper.make_node("Cast", inputs=["mul:0"], outputs=["mul_uint8"], name="cast0", to=2),
+        # DQ of MatMul A input
+        helper.make_node(
+            "DequantizeLinear",
+            inputs=["mul_uint8", "scale1"],
+            outputs=["dq_matmul_a"],
+            name="dq_matmul_a",
+        ),
+        # MatMul
+        helper.make_node(
+            "MatMul",
+            inputs=[
+                "dq_matmul_a",
+                "dq_matmul_b",
+            ],
+            outputs=["matmul_ab"],
+            name="matmul_ab",
+        ),
+        # Q
+        helper.make_node(
+            "QuantizeLinear",
+            inputs=["matmul_ab", "scale2"],
+            outputs=["q_matmul_ab"],
+            name="q_matmul_ab",
+        ),
+        # DQ for model output
+        helper.make_node(
+            "DequantizeLinear",
+            inputs=["q_matmul_ab", "scale2"],
+            outputs=["out:0"],
+            name="dq_graph_output",
+        ),
+    ],
+    "Main_graph",
+    [
+        helper.make_tensor_value_info("input:0", TensorProto.FLOAT, [3, 2]),
+    ],
+    [
+        helper.make_tensor_value_info("out:0", TensorProto.FLOAT, [3, 2]),
+    ],
+    [
+        helper.make_tensor("scale0", TensorProto.FLOAT, [1], [20.0]),
+        helper.make_tensor("scale1", TensorProto.FLOAT, [1], [30.0]),
+        helper.make_tensor("scale2", TensorProto.FLOAT, [1], [40.0]),
+        helper.make_tensor("matmul_b_uint8", TensorProto.UINT8, [2, 2], [1, 2, 3, 4]),
+        helper.make_tensor("scale_input", TensorProto.FLOAT, [2], [3.0, 4.0]),
+    ],
+)
+
+model = helper.make_model(graph_proto)
+onnx.checker.check_model(model, True)
+onnx.save(model, "ort_github_issue_19590.onnx")
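A quick way to sanity-check the generated model end to end is to run it on the default CPU provider (which does not partition the graph, so the node-unit issue described in the comment does not arise there); this assumes the script above has already written ort_github_issue_19590.onnx:

# Sketch: run the generated QDQ MatMul model on CPU and check the output shape.
import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession("ort_github_issue_19590.onnx", providers=["CPUExecutionProvider"])
x = np.ones((3, 2), dtype=np.float32)
(out,) = sess.run(None, {"input:0": x})
print(out.shape)  # (3, 2)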
diff --git a/onnxruntime/test/testdata/test_data_generation/adamw_test/adamw_test_data_generator.py b/onnxruntime/test/testdata/test_data_generation/adamw_test/adamw_test_data_generator.py
index 4c1e3a70de1c7..443444044bb8d 100644
--- a/onnxruntime/test/testdata/test_data_generation/adamw_test/adamw_test_data_generator.py
+++ b/onnxruntime/test/testdata/test_data_generation/adamw_test/adamw_test_data_generator.py
@@ -190,7 +190,7 @@ def main():
     device_candidates = ["cuda", "cpu"]
     test_data_step_count = 11
     for device in device_candidates:
-        for adam_mode in range(0, 2):
+        for adam_mode in range(2):
             generate_adamw_single_weight_tests(adam_mode, test_data_step_count, device)
             generate_adamw_multiple_weights_tests(adam_mode, test_data_step_count, device)
 
diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py
index ed06495b42beb..54fe7b808bf12 100644
--- a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py
+++ b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py
@@ -21,19 +21,21 @@ def GenerateNodes(model_name, has_cast, suffix=""):  # noqa: N802
             ["gather0_out" + suffix],
             "gather0" + suffix,
         ),
-        helper.make_node(
-            "Unsqueeze",
-            ["gather0_out" + suffix, "axes_0"],
-            ["unsqueeze0_out" + suffix],
-            "unsqueeze0" + suffix,
-        )
-        if opset_version == 13
-        else helper.make_node(
-            "Unsqueeze",
-            ["gather0_out" + suffix],
-            ["unsqueeze0_out" + suffix],
-            "unsqueeze0" + suffix,
-            axes=[0],
+        (
+            helper.make_node(
+                "Unsqueeze",
+                ["gather0_out" + suffix, "axes_0"],
+                ["unsqueeze0_out" + suffix],
+                "unsqueeze0" + suffix,
+            )
+            if opset_version == 13
+            else helper.make_node(
+                "Unsqueeze",
+                ["gather0_out" + suffix],
+                ["unsqueeze0_out" + suffix],
+                "unsqueeze0" + suffix,
+                axes=[0],
+            )
         ),
         helper.make_node("Shape", ["input_ids" + suffix], ["shape2_out" + suffix], "shape2" + suffix),
         helper.make_node(
@@ -42,19 +44,21 @@ def GenerateNodes(model_name, has_cast, suffix=""):  # noqa: N802
             ["gather1_out" + suffix],
             "gather1" + suffix,
         ),
-        helper.make_node(
-            "Unsqueeze",
-            ["gather1_out" + suffix, "axes_0"],
-            ["unsqueeze1_out" + suffix],
-            "unsqueeze1" + suffix,
-        )
-        if opset_version == 13
-        else helper.make_node(
-            "Unsqueeze",
-            ["gather1_out" + suffix],
-            ["unsqueeze1_out" + suffix],
-            "unsqueeze1" + suffix,
-            axes=[0],
+        (
+            helper.make_node(
+                "Unsqueeze",
+                ["gather1_out" + suffix, "axes_0"],
+                ["unsqueeze1_out" + suffix],
+                "unsqueeze1" + suffix,
+            )
+            if opset_version == 13
+            else helper.make_node(
+                "Unsqueeze",
+                ["gather1_out" + suffix],
+                ["unsqueeze1_out" + suffix],
+                "unsqueeze1" + suffix,
+                axes=[0],
+            )
         ),
         helper.make_node(
             "Concat",
@@ -80,19 +84,21 @@ def GenerateNodes(model_name, has_cast, suffix=""):  # noqa: N802
             ["range_out" + suffix],
             "range" + suffix,
         ),
-        helper.make_node(
-            "Unsqueeze",
-            ["range_out" + suffix, "axes_0"],
-            ["unsqueeze2_out" + suffix],
-            "unsqueeze2" + suffix,
-        )
-        if opset_version == 13
-        else helper.make_node(
-            "Unsqueeze",
-            ["range_out" + suffix],
-            ["unsqueeze2_out" + suffix],
-            "unsqueeze2" + suffix,
-            axes=[0],
+        (
+            helper.make_node(
+                "Unsqueeze",
+                ["range_out" + suffix, "axes_0"],
+                ["unsqueeze2_out" + suffix],
+                "unsqueeze2" + suffix,
+            )
+            if opset_version == 13
+            else helper.make_node(
+                "Unsqueeze",
+                ["range_out" + suffix],
+                ["unsqueeze2_out" + suffix],
+                "unsqueeze2" + suffix,
+                axes=[0],
+            )
         ),
         helper.make_node(
             "Expand",
@@ -145,21 +151,23 @@ def GenerateNodes(model_name, has_cast, suffix=""):  # noqa: N802
             "mask_cast" + suffix,
             to=6,
         ),
-        helper.make_node(
-            "ReduceSum",
-            ["mask_cast_out" + suffix, "axes_1"],
-            ["mask_index_out" + suffix],
-            "mask_index" + suffix,
-            keepdims=0,
-        )
-        if opset_version == 13
-        else helper.make_node(
-            "ReduceSum",
-            ["mask_cast_out" + suffix],
-            ["mask_index_out" + suffix],
-            "mask_index" + suffix,
-            axes=[1],
-            keepdims=0,
+        (
+            helper.make_node(
+                "ReduceSum",
+                ["mask_cast_out" + suffix, "axes_1"],
+                ["mask_index_out" + suffix],
+                "mask_index" + suffix,
+                keepdims=0,
+            )
+            if opset_version == 13
+            else helper.make_node(
+                "ReduceSum",
+                ["mask_cast_out" + suffix],
+                ["mask_index_out" + suffix],
+                "mask_index" + suffix,
+                axes=[1],
+                keepdims=0,
+            )
         ),
         helper.make_node(
             "Attention",
@@ -372,21 +380,23 @@ def GenerateModel5(model_name):  # noqa: N802
             epsion=0.000009999999747378752,
         ),
         helper.make_node("Cast", ["input_mask"], ["mask_cast_out"], "mask_cast", to=6),
-        helper.make_node(
-            "ReduceSum",
-            ["mask_cast_out", "axes_1"],
-            ["mask_index_out"],
-            "mask_index",
-            keepdims=0,
-        )
-        if opset_version == 13
-        else helper.make_node(
-            "ReduceSum",
-            ["mask_cast_out"],
-            ["mask_index_out"],
-            "mask_index",
-            axes=[1],
-            keepdims=0,
+        (
+            helper.make_node(
+                "ReduceSum",
+                ["mask_cast_out", "axes_1"],
+                ["mask_index_out"],
+                "mask_index",
+                keepdims=0,
+            )
+            if opset_version == 13
+            else helper.make_node(
+                "ReduceSum",
+                ["mask_cast_out"],
+                ["mask_index_out"],
+                "mask_index",
+                axes=[1],
+                keepdims=0,
+            )
         ),
         helper.make_node(
             "Attention",
@@ -514,14 +524,18 @@ def GenerateModel6(model_name):  # noqa: N802
     nodes = [  # LayerNorm subgraph
         helper.make_node("Shape", ["input_ids"], ["shape1_out"], "shape1"),
         helper.make_node("Gather", ["shape1_out", "indices_0"], ["gather0_out"], "gather0"),
-        helper.make_node("Unsqueeze", ["gather0_out", "axes_0"], ["unsqueeze0_out"], "unsqueeze0")
-        if opset_version == 13
-        else helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]),
+        (
+            helper.make_node("Unsqueeze", ["gather0_out", "axes_0"], ["unsqueeze0_out"], "unsqueeze0")
+            if opset_version == 13
+            else helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0])
+        ),
         helper.make_node("Shape", ["input_ids"], ["shape2_out"], "shape2"),
         helper.make_node("Gather", ["shape2_out", "indices_1"], ["gather1_out"], "gather1"),
-        helper.make_node("Unsqueeze", ["gather1_out", "axes_0"], ["unsqueeze1_out"], "unsqueeze1")
-        if opset_version == 13
-        else helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]),
+        (
+            helper.make_node("Unsqueeze", ["gather1_out", "axes_0"], ["unsqueeze1_out"], "unsqueeze1")
+            if opset_version == 13
+            else helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0])
+        ),
         helper.make_node(
             "Concat",
             ["unsqueeze0_out", "unsqueeze1_out"],
@@ -533,9 +547,11 @@ def GenerateModel6(model_name):  # noqa: N802
         helper.make_node("Equal", ["reshape_out", "equal_init"], ["equal_out"], "equal"),
         helper.make_node("Where", ["equal_out", "where_init", "reshape_out"], ["where_out"], "where"),
         helper.make_node("Range", ["start_0", "gather1_out", "delta_1"], ["range_out"], "range"),
-        helper.make_node("Unsqueeze", ["range_out", "axes_0"], ["unsqueeze2_out"], "unsqueeze2")
-        if opset_version == 13
-        else helper.make_node("Unsqueeze", ["range_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]),
+        (
+            helper.make_node("Unsqueeze", ["range_out", "axes_0"], ["unsqueeze2_out"], "unsqueeze2")
+            if opset_version == 13
+            else helper.make_node("Unsqueeze", ["range_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0])
+        ),
         helper.make_node("Expand", ["unsqueeze2_out", "where_out"], ["expand_out"], "expand"),
         helper.make_node("Gather", ["pos_embed", "expand_out"], ["pos_gather_out"], "pos_gather"),
         helper.make_node("Gather", ["word_embed", "input_ids"], ["word_gather_out"], "word_gather"),
@@ -556,21 +572,23 @@ def GenerateModel6(model_name):  # noqa: N802
             epsion=0.000009999999747378752,
         ),
         helper.make_node("Cast", ["input_mask"], ["mask_cast_out"], "mask_cast", to=6),
-        helper.make_node(
-            "ReduceSum",
-            ["mask_cast_out", "axes_1"],
-            ["mask_index_out"],
-            "mask_index",
-            keepdims=0,
-        )
-        if opset_version == 13
-        else helper.make_node(
-            "ReduceSum",
-            ["mask_cast_out"],
-            ["mask_index_out"],
-            "mask_index",
-            axes=[1],
-            keepdims=0,
+        (
+            helper.make_node(
+                "ReduceSum",
+                ["mask_cast_out", "axes_1"],
+                ["mask_index_out"],
+                "mask_index",
+                keepdims=0,
+            )
+            if opset_version == 13
+            else helper.make_node(
+                "ReduceSum",
+                ["mask_cast_out"],
+                ["mask_index_out"],
+                "mask_index",
+                axes=[1],
+                keepdims=0,
+            )
         ),
         helper.make_node(
             "Attention",
@@ -756,9 +774,11 @@ def GenerateNodes2(attention_heads):  # noqa: N802
         helper.make_node("Shape", ["input_ids"], ["shape0_out"], "shape0"),
         helper.make_node("Gather", ["shape0_out", "indices_1"], ["gather0_out"], "gather0"),
         helper.make_node("Range", ["start", "gather0_out", "delta"], ["range0_out"], "range0"),
-        helper.make_node("Unsqueeze", ["range0_out", "axes_0"], ["unsqueeze0_out"], "unsqueeze0")
-        if opset_version == 13
-        else helper.make_node("Unsqueeze", ["range0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]),
+        (
+            helper.make_node("Unsqueeze", ["range0_out", "axes_0"], ["unsqueeze0_out"], "unsqueeze0")
+            if opset_version == 13
+            else helper.make_node("Unsqueeze", ["range0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0])
+        ),
         helper.make_node("Shape", ["input_ids"], ["shape1_out"], "shape1"),
         helper.make_node("Expand", ["unsqueeze0_out", "shape1_out"], ["expand_out"], "expand"),
         helper.make_node(
@@ -778,21 +798,23 @@ def GenerateNodes2(attention_heads):  # noqa: N802
             epsion=0.000009999999747378752,
         ),
         helper.make_node("Cast", ["input_mask"], ["mask_cast_out"], "mask_cast", to=6),
-        helper.make_node(
-            "ReduceSum",
-            ["mask_cast_out", "axes_1"],
-            ["mask_index_out"],
-            "mask_index",
-            keepdims=0,
-        )
-        if opset_version == 13
-        else helper.make_node(
-            "ReduceSum",
-            ["mask_cast_out"],
-            ["mask_index_out"],
-            "mask_index",
-            axes=[1],
-            keepdims=0,
+        (
+            helper.make_node(
+                "ReduceSum",
+                ["mask_cast_out", "axes_1"],
+                ["mask_index_out"],
+                "mask_index",
+                keepdims=0,
+            )
+            if opset_version == 13
+            else helper.make_node(
+                "ReduceSum",
+                ["mask_cast_out"],
+                ["mask_index_out"],
+                "mask_index",
+                axes=[1],
+                keepdims=0,
+            )
         ),
         helper.make_node(
             "Attention",
@@ -898,12 +920,16 @@ def GenerateModel9(model_name):  # noqa: N802
         helper.make_node("Expand", ["unsqueeze0_out", "shape_out"], ["expand_out"], "expand"),
         helper.make_node("Gather", ["shape_out", "indices_0"], ["gather1_out"], "gather1"),
         helper.make_node("Gather", ["shape_out", "indices_1"], ["gather2_out"], "gather2"),
-        helper.make_node("Unsqueeze", ["gather1_out", "axes_0"], ["unsqueeze1_out"], "unsqueeze1")
-        if opset_version == 13
-        else helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]),
-        helper.make_node("Unsqueeze", ["gather2_out", "axes_0"], ["unsqueeze2_out"], "unsqueeze2")
-        if opset_version == 13
-        else helper.make_node("Unsqueeze", ["gather2_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]),
+        (
+            helper.make_node("Unsqueeze", ["gather1_out", "axes_0"], ["unsqueeze1_out"], "unsqueeze1")
+            if opset_version == 13
+            else helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0])
+        ),
+        (
+            helper.make_node("Unsqueeze", ["gather2_out", "axes_0"], ["unsqueeze2_out"], "unsqueeze2")
+            if opset_version == 13
+            else helper.make_node("Unsqueeze", ["gather2_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0])
+        ),
         helper.make_node(
             "Concat",
             ["unsqueeze1_out", "unsqueeze2_out"],
diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.onnx
index 7ea69c580ee43..aa8e67bcbc59e 100644
Binary files a/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.onnx and b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.onnx differ
diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float16_int8.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float16_int8.onnx
new file mode 100644
index 0000000000000..22293b0d10756
Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float16_int8.onnx differ
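The new test model is a binary protobuf and is not human-readable in the diff. From the strings that survive in the raw dump (MatMulInteger, Cast, Mul nodes named mul_right/mul_bottom, graph "DynamicQuantizeMatMul_fusion", inputs A/B/a_scale/b_scale/a_zero_point/b_zero_point, output Y), it appears to encode the usual MatMulInteger -> Cast -> Mul pattern targeted by the MatMul-integer-to-float fusion tests, here with fp16 scales and int8 weights. A rough onnx.helper sketch of a similar graph follows; exact shapes, dtypes and initializers in the checked-in model may differ.

from onnx import TensorProto, helper

nodes = [
    helper.make_node("MatMulInteger", ["A", "B", "a_zero_point", "b_zero_point"],
                     ["matmul_output_int32"], "MatMulInteger"),
    helper.make_node("Mul", ["a_scale", "b_scale"], ["multiplier"], "mul_right"),
    helper.make_node("Cast", ["matmul_output_int32"], ["matmul_output_float"], "cast",
                     to=TensorProto.FLOAT16),
    helper.make_node("Mul", ["matmul_output_float", "multiplier"], ["Y"], "mul_bottom"),
]
graph = helper.make_graph(
    nodes,
    "DynamicQuantizeMatMul_fusion",
    [
        helper.make_tensor_value_info("A", TensorProto.UINT8, ["M", "K"]),
        helper.make_tensor_value_info("B", TensorProto.INT8, ["K", "N"]),
        helper.make_tensor_value_info("a_scale", TensorProto.FLOAT16, [1]),
        helper.make_tensor_value_info("b_scale", TensorProto.FLOAT16, [1]),
        helper.make_tensor_value_info("a_zero_point", TensorProto.UINT8, [1]),
        helper.make_tensor_value_info("b_zero_point", TensorProto.INT8, [1]),
    ],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT16, ["M", "N"])],
)
model = helper.make_model(graph)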
diff --git a/onnxruntime/test/unittest_main/test_main.cc b/onnxruntime/test/unittest_main/test_main.cc
index 4c38c90c2b418..d7e8bf9063645 100644
--- a/onnxruntime/test/unittest_main/test_main.cc
+++ b/onnxruntime/test/unittest_main/test_main.cc
@@ -32,17 +32,30 @@ void ortenv_setup() {
 }
 
 #ifdef USE_TENSORRT
+
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4100)  // Ignore warning C4100: unreferenced format parameter.
+#endif
+
 // TensorRT will load/unload libraries as builder objects are created and torn down. This will happen for
 // every single unit test, which leads to excessive test execution time due to that overhead.
 // Nvidia suggests to keep a placeholder builder object around to avoid this.
 #include "NvInfer.h"
+
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+
 class DummyLogger : public nvinfer1::ILogger {
  public:
-  DummyLogger(Severity verbosity) {}
-  void log(Severity severity, const char* msg) noexcept override {}
+  DummyLogger(Severity /*verbosity*/) {}
+  void log(Severity /*severity*/, const char* /*msg*/) noexcept override {}
 };
 DummyLogger trt_logger(nvinfer1::ILogger::Severity::kWARNING);
+
 auto const placeholder = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
+
 #endif
 
 #define TEST_MAIN main
diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc
index a94f7b5b707c7..6ad2d41edb562 100644
--- a/onnxruntime/test/util/default_providers.cc
+++ b/onnxruntime/test/util/default_providers.cc
@@ -8,6 +8,9 @@
 #ifdef USE_COREML
 #include "core/providers/coreml/coreml_provider_factory.h"
 #endif
+#ifdef USE_CUDA
+#include <core/providers/cuda/cuda_provider_options.h>
+#endif
 #include "core/session/onnxruntime_cxx_api.h"
 #include "core/framework/session_options.h"
 
@@ -110,14 +113,29 @@ std::unique_ptr<IExecutionProvider> DefaultOpenVINOExecutionProvider() {
 
 std::unique_ptr<IExecutionProvider> DefaultCudaExecutionProvider() {
 #ifdef USE_CUDA
-  OrtCUDAProviderOptions provider_options{};
+  OrtCUDAProviderOptionsV2 provider_options{};
   provider_options.do_copy_in_default_stream = true;
+  provider_options.use_tf32 = false;
   if (auto factory = CudaProviderFactoryCreator::Create(&provider_options))
     return factory->CreateProvider();
 #endif
   return nullptr;
 }
 
+#ifdef ENABLE_CUDA_NHWC_OPS
+std::unique_ptr<IExecutionProvider> DefaultCudaNHWCExecutionProvider() {
+#if defined(USE_CUDA)
+  OrtCUDAProviderOptionsV2 provider_options{};
+  provider_options.do_copy_in_default_stream = true;
+  provider_options.use_tf32 = false;
+  provider_options.prefer_nhwc = true;
+  if (auto factory = CudaProviderFactoryCreator::Create(&provider_options))
+    return factory->CreateProvider();
+#endif
+  return nullptr;
+}
+#endif
+
 std::unique_ptr<IExecutionProvider> CudaExecutionProviderWithOptions(const OrtCUDAProviderOptionsV2* provider_options) {
 #ifdef USE_CUDA
   if (auto factory = CudaProviderFactoryCreator::Create(provider_options))
@@ -207,15 +225,21 @@ std::unique_ptr<IExecutionProvider> DefaultRocmExecutionProvider(bool test_tunab
   return nullptr;
 }
 
-std::unique_ptr<IExecutionProvider> DefaultCoreMLExecutionProvider() {
-// For any non - macOS system, CoreML will only be used for ort model converter
-// Make it unavailable here, you can still manually append CoreML EP to session for model conversion
+std::unique_ptr<IExecutionProvider> DefaultCoreMLExecutionProvider(bool use_mlprogram) {
+  // To manually test CoreML model generation on a non-macOS platform, comment out the `&& defined(__APPLE__)` below.
+  // The test will create a model but execution of it will obviously fail.
 #if defined(USE_COREML) && defined(__APPLE__)
   // We want to run UT on CPU only to get output value without losing precision
   uint32_t coreml_flags = 0;
   coreml_flags |= COREML_FLAG_USE_CPU_ONLY;
+
+  if (use_mlprogram) {
+    coreml_flags |= COREML_FLAG_CREATE_MLPROGRAM;
+  }
+
   return CoreMLProviderFactoryCreator::Create(coreml_flags)->CreateProvider();
 #else
+  ORT_UNUSED_PARAMETER(use_mlprogram);
   return nullptr;
 #endif
 }
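As a rough cross-reference (not part of this change), the same two CUDA options can be exercised from the Python API, where string-keyed provider options map onto OrtCUDAProviderOptionsV2. "model.onnx" is a placeholder, and the "prefer_nhwc"/"use_tf32" keys assume a build with NHWC ops enabled and a recent ORT version.

import onnxruntime as ort

cuda_options = {
    "do_copy_in_default_stream": "1",
    "use_tf32": "0",     # keep FP32 matmuls exact, matching the test providers above
    "prefer_nhwc": "1",  # route to NHWC kernels where available (needs ENABLE_CUDA_NHWC_OPS)
}
sess = ort.InferenceSession(
    "model.onnx",  # placeholder path
    providers=[("CUDAExecutionProvider", cuda_options), "CPUExecutionProvider"],
)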
diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h
index 9f78e0a0d4eb2..ae8e89c386994 100644
--- a/onnxruntime/test/util/include/default_providers.h
+++ b/onnxruntime/test/util/include/default_providers.h
@@ -35,6 +35,9 @@ namespace test {
 // unique_ptr providers with default values for session registration
 std::unique_ptr<IExecutionProvider> DefaultCpuExecutionProvider(bool enable_arena = true);
 std::unique_ptr<IExecutionProvider> DefaultCudaExecutionProvider();
+#ifdef ENABLE_CUDA_NHWC_OPS
+std::unique_ptr<IExecutionProvider> DefaultCudaNHWCExecutionProvider();
+#endif
 std::unique_ptr<IExecutionProvider> CudaExecutionProviderWithOptions(const OrtCUDAProviderOptionsV2* provider_options);
 std::unique_ptr<IExecutionProvider> DefaultDnnlExecutionProvider();
 std::unique_ptr<IExecutionProvider> DnnlExecutionProviderWithOptions(const OrtDnnlProviderOptions* provider_options);
@@ -51,7 +54,7 @@ std::unique_ptr<IExecutionProvider> DefaultRknpuExecutionProvider();
 std::unique_ptr<IExecutionProvider> DefaultAclExecutionProvider(bool enable_arena = true);
 std::unique_ptr<IExecutionProvider> DefaultArmNNExecutionProvider(bool enable_arena = true);
 std::unique_ptr<IExecutionProvider> DefaultRocmExecutionProvider(bool test_tunable_op = false);
-std::unique_ptr<IExecutionProvider> DefaultCoreMLExecutionProvider();
+std::unique_ptr<IExecutionProvider> DefaultCoreMLExecutionProvider(bool use_mlprogram = false);
 std::unique_ptr<IExecutionProvider> DefaultSnpeExecutionProvider();
 std::unique_ptr<IExecutionProvider> DefaultQnnExecutionProvider();
 std::unique_ptr<IExecutionProvider> QnnExecutionProviderWithOptions(const ProviderOptions& options,
diff --git a/onnxruntime/test/wasm/package-lock.json b/onnxruntime/test/wasm/package-lock.json
index bfa000fda440a..1beaf3b83ca28 100644
--- a/onnxruntime/test/wasm/package-lock.json
+++ b/onnxruntime/test/wasm/package-lock.json
@@ -520,9 +520,9 @@
       "dev": true
     },
     "node_modules/follow-redirects": {
-      "version": "1.15.4",
-      "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.4.tgz",
-      "integrity": "sha512-Cr4D/5wlrb0z9dgERpUL3LrmPKVDsETIJhaCMeDfuFYcqa5bldGV6wBsAN6X/vxlXQtFBMrXdXxdL8CbDTGniw==",
+      "version": "1.15.6",
+      "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz",
+      "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==",
       "dev": true,
       "funding": [
         {
@@ -1972,9 +1972,9 @@
       "dev": true
     },
     "follow-redirects": {
-      "version": "1.15.4",
-      "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.4.tgz",
-      "integrity": "sha512-Cr4D/5wlrb0z9dgERpUL3LrmPKVDsETIJhaCMeDfuFYcqa5bldGV6wBsAN6X/vxlXQtFBMrXdXxdL8CbDTGniw==",
+      "version": "1.15.6",
+      "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz",
+      "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==",
       "dev": true
     },
     "fs-extra": {
diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js
index cbc60c70b57aa..90d8b737252e5 100644
--- a/onnxruntime/wasm/js_internal_api.js
+++ b/onnxruntime/wasm/js_internal_api.js
@@ -4,39 +4,27 @@
 'use strict';
 
 /**
- * Mount external data files of a model to the virtual file system (MEMFS).
+ * Mount external data files of a model to an internal map, which will be used during session initialization.
  *
  * @param {string} externalDataFilesPath
  * @param {Uint8Array} externalDataFilesData
  */
 Module['mountExternalData'] = (externalDataFilePath, externalDataFileData) => {
   const files = Module.MountedFiles || (Module.MountedFiles = new Map());
-    files.set(externalDataFilePath, externalDataFileData);
+  files.set(externalDataFilePath, externalDataFileData);
 };
 
 /**
- * Unmount external data files of a model from the virtual file system (MEMFS).
+ * Unmount external data files of a model.
  */
 Module['unmountExternalData'] = () => {
   delete Module.MountedFiles;
 };
 
 /**
- * init JSEP
+ * Initialize JSEP for asyncify support.
  */
-Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, runKernel, captureBegin, captureEnd, replay) => {
-  Module.jsepBackend = backend;
-  Module.jsepAlloc = alloc;
-  Module.jsepFree = free;
-  Module.jsepCopy = copy;
-  Module.jsepCopyAsync = copyAsync;
-  Module.jsepCreateKernel = createKernel;
-  Module.jsepReleaseKernel = releaseKernel;
-  Module.jsepRunKernel = runKernel;
-  Module.jsepCaptureBegin = captureBegin;
-  Module.jsepCaptureEnd = captureEnd;
-  Module.jsepReplay = replay;
-
+let jsepInitAsync = () => {
   // This is a simplified version of cwrap() with options.async === true (-sASYNCIFY=1)
   // It removes some overhead in cwarp() and ccall() that we don't need.
   //
@@ -143,7 +131,7 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea
         }
 
         // Flush the backend. This will submit all pending commands to the GPU.
-        backend['flush']();
+        Module.jsepBackend?.['flush']();
 
         // Await all pending promises. This includes GPU validation promises for diagnostic purposes.
         const errorPromises = state.errors;
@@ -180,20 +168,46 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea
       () => Module['_OrtBindInput'],
       v => Module['_OrtBindInput'] = v);
 
-  // expose webgpu backend functions
-  Module['jsepRegisterBuffer'] = (sessionId, index, buffer, size) => {
-    return backend['registerBuffer'](sessionId, index, buffer, size);
-  };
-  Module['jsepGetBuffer'] = (dataId) => {
-    return backend['getBuffer'](dataId);
-  };
-  Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => {
-    return backend['createDownloader'](gpuBuffer, size, type);
-  };
-  Module['jsepOnReleaseSession'] = sessionId => {
-    backend['onReleaseSession'](sessionId);
-  };
-  Module['jsepOnRunStart'] = sessionId => {
-    return backend['onRunStart'](sessionId);
-  };
+  // remove this function to make sure it is called only once.
+  jsepInitAsync = undefined;
+};
+
+
+/**
+ * Initialize JSEP for WebGPU.
+ */
+Module['jsepInit'] = (name, params) => {
+  jsepInitAsync?.();
+
+  if (name === 'webgpu') {
+    [Module.jsepBackend,
+     Module.jsepAlloc,
+     Module.jsepFree,
+     Module.jsepCopy,
+     Module.jsepCopyAsync,
+     Module.jsepCreateKernel,
+     Module.jsepReleaseKernel,
+     Module.jsepRunKernel,
+     Module.jsepCaptureBegin,
+     Module.jsepCaptureEnd,
+     Module.jsepReplay] = params;
+
+    // expose webgpu backend functions
+    const backend = Module.jsepBackend;
+    Module['jsepRegisterBuffer'] = (sessionId, index, buffer, size) => {
+      return backend['registerBuffer'](sessionId, index, buffer, size);
+    };
+    Module['jsepGetBuffer'] = (dataId) => {
+      return backend['getBuffer'](dataId);
+    };
+    Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => {
+      return backend['createDownloader'](gpuBuffer, size, type);
+    };
+    Module['jsepOnReleaseSession'] = sessionId => {
+      backend['onReleaseSession'](sessionId);
+    };
+    Module['jsepOnRunStart'] = sessionId => {
+      return backend['onRunStart'](sessionId);
+    };
+  }
 };
diff --git a/orttraining/orttraining/core/framework/triton/triton_op_executor.cc b/orttraining/orttraining/core/framework/triton/triton_op_executor.cc
index 092ab89d5d760..f30d6ddee253a 100644
--- a/orttraining/orttraining/core/framework/triton/triton_op_executor.cc
+++ b/orttraining/orttraining/core/framework/triton/triton_op_executor.cc
@@ -106,6 +106,8 @@ void TritonOpExecutor::ExecuteByFuncName(const std::string& func_name, const Inl
       PyDict_SetItemString(python_kwargs.get(), kv.first.c_str(), PyLong_FromLongLong(std::stoll(kv.second.first)));
     } else if (kv.second.second == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
       PyDict_SetItemString(python_kwargs.get(), kv.first.c_str(), PyFloat_FromDouble(std::stod(kv.second.first)));
+    } else if (kv.second.second == ONNX_NAMESPACE::TensorProto_DataType_STRING) {
+      PyDict_SetItemString(python_kwargs.get(), kv.first.c_str(), PyUnicode_FromString(kv.second.first.c_str()));
     } else {
       ORT_THROW("Unsupported kwargs data type: ", kv.second.second);
     }
diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc
index e675b55c8af8f..22dcf4eb92411 100755
--- a/orttraining/orttraining/core/graph/gradient_builder.cc
+++ b/orttraining/orttraining/core/graph/gradient_builder.cc
@@ -1112,6 +1112,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) {
 
   ArgDef grad = GO(0);
   if (!keepdims) {
+    size_t numInputs = GetSrcNodeInputSize();
     if (attributes.find("axes") != attributes.end()) {
       std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
       grad = IA("Unqueezed_Grad");
@@ -1122,6 +1123,9 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) {
         result.push_back(axes_values_node);
         result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {GO(0), axes_values_node.output_args[0]}, {grad}));
       }
+    } else if (numInputs == 2) {  // optional input 'axes' is available as input I(1)
+      grad = IA("Unqueezed_Grad");
+      result.push_back(NodeDef("Unsqueeze", {GO(0), I(1)}, {grad}));
     }
   }
 
@@ -1152,12 +1156,21 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceLogSumExpGradient) {
   }
 
   ArgDef grad = GO(0);
-  if (!keepdims && attributes.find("axes") != attributes.end()) {
-    std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
-    grad = IA("Unsqueezed_Grad");
-    result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
+  if (!keepdims) {
+    size_t numInputs = GetSrcNodeInputSize();
+    if (attributes.find("axes") != attributes.end()) {
+      std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
+      grad = IA("Unsqueezed_Grad");
 
-    result.push_back(NodeDef("Unsqueeze", {O(0)}, {IA("Unsqueezed_Output")}, {MakeAttribute("axes", axes_values)}));
+      result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
+
+      result.push_back(NodeDef("Unsqueeze", {O(0)}, {IA("Unsqueezed_Output")}, {MakeAttribute("axes", axes_values)}));
+    } else if (numInputs == 2) {  // optional input 'axes' is available as input I(1)
+      grad = IA("Unsqueezed_Grad");
+      result.push_back(NodeDef("Unsqueeze", {GO(0), I(1)}, {grad}));
+
+      result.push_back(NodeDef("Unsqueeze", {O(0), I(1)}, {IA("Unsqueezed_Output")}));
+    }
     result.push_back(NodeDef("Sub", {I(0), IA("Unsqueezed_Output")}, {IA("Self_Sub_Result")}));
   } else {
     result.push_back(NodeDef("Sub", {I(0), O(0)}, {IA("Self_Sub_Result")}));
@@ -1188,11 +1201,17 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceL2Gradient) {
   ArgDef scaled_dy_arg_def = IA("Masked_Scaled_dY");
   result.emplace_back(NodeDef("Where", {IA("Masked_Y"), ZERO, IA("Scaled_dY")}, {scaled_dy_arg_def}));
 
-  if (!keepdims && attributes.find("axes") != attributes.end()) {
-    std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
+  if (!keepdims) {
+    size_t numInputs = GetSrcNodeInputSize();
     scaled_dy_arg_def = IA("Unsqueezed_Masked_Scaled_dY");
-    result.emplace_back(
-        NodeDef("Unsqueeze", {IA("Masked_Scaled_dY")}, {scaled_dy_arg_def}, {MakeAttribute("axes", axes_values)}));
+    if (attributes.find("axes") != attributes.end()) {
+      std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
+      result.emplace_back(
+          NodeDef("Unsqueeze", {IA("Masked_Scaled_dY")}, {scaled_dy_arg_def}, {MakeAttribute("axes", axes_values)}));
+    } else if (numInputs == 2) {  // optional input 'axes' is available as input I(1)
+      result.emplace_back(
+          NodeDef("Unsqueeze", {IA("Masked_Scaled_dY"), I(1)}, {scaled_dy_arg_def}));
+    }
   }
 
   result.emplace_back(NodeDef("Mul", {I(0), scaled_dy_arg_def}, {GI(0)}));
diff --git a/orttraining/orttraining/core/optimizer/conv1d_replacement.cc b/orttraining/orttraining/core/optimizer/conv1d_replacement.cc
index 0412000e04e1b..ff220fcb067b8 100644
--- a/orttraining/orttraining/core/optimizer/conv1d_replacement.cc
+++ b/orttraining/orttraining/core/optimizer/conv1d_replacement.cc
@@ -42,30 +42,45 @@
 */
 namespace onnxruntime {
 bool NodeCanBeReplacedByMatmul(const Node& node) {
-  // If node type is Conv, and attr "dilations" is 1, "kernel_shape" is 1, "stride" is 1, group is 1 or 2,
-  // then it can be replaced by MatMul
-  // Kernel_shape is 1 means it is conv1d
+  /*
+  If the node type is Conv and it satisfies all of the following conditions, it can be replaced by MatMul:
+  - no bias input, i.e. the node has only 2 inputs: input and weight
+  - "dilations" should be [1]
+  - "strides" should be [1]
+  - "pads" should be [0,0]
+  - "auto_pad" should be "NOTSET"
+  - "kernel_shape" should be [1]
+    (kernel size 1 means it is a conv1d)
+  */
   if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11})) {
     return false;
   }
-  const auto* dilations = graph_utils::GetNodeAttribute(node, "dilations");
-  const auto* kernel_shape = graph_utils::GetNodeAttribute(node, "kernel_shape");
-  const auto* stride = graph_utils::GetNodeAttribute(node, "strides");
-  const auto* group = graph_utils::GetNodeAttribute(node, "group");
-  if (dilations == nullptr || kernel_shape == nullptr || stride == nullptr || group == nullptr) {
+
+  // TODO: bias input can also be supported if needed
+  if (node.InputDefs().size() != 2) {
     return false;
   }
-  if ((dilations->ints_size() && dilations->ints(0) != 1) ||
-      (kernel_shape->ints_size() && kernel_shape->ints(0) != 1) ||
-      (stride->ints_size() && stride->ints(0) != 1) ||
-      group->i() >= 3) {
+
+  const auto* dilations = graph_utils::GetNodeAttribute(node, "dilations");
+  const auto* strides = graph_utils::GetNodeAttribute(node, "strides");
+  const auto* pads = graph_utils::GetNodeAttribute(node, "pads");
+  const auto* autopad = graph_utils::GetNodeAttribute(node, "auto_pad");
+  const auto* kernel_shape = graph_utils::GetNodeAttribute(node, "kernel_shape");
+  if (dilations == nullptr || strides == nullptr || pads == nullptr || autopad == nullptr || kernel_shape == nullptr) {
     return false;
   }
 
-  return true;
+  if ((dilations->ints_size() == 1 && dilations->ints(0) == 1) &&
+      (strides->ints_size() == 1 && strides->ints(0) == 1) &&
+      (autopad->s() == "NOTSET") &&
+      (pads->ints_size() == 2 && pads->ints(0) == 0 && pads->ints(1) == 0) &&
+      (kernel_shape->ints_size() == 1 && kernel_shape->ints(0) == 1)) {
+    return true;
+  }
+  return false;
 }
 
-void Conv1dToMatmul(Graph& graph, Node& conv) {
+void Conv1dToMatmul(Graph& graph, Node& conv, const std::string transformer_name) {
   // Shape of conv1d input: [batch_size, in_channels, in_length]
   // Shape of conv1d weight:[output_channels, input_channels/group, kernel_shape], kernel_shape is 1
   // We need to split the input into "group", and squeeze&split the weight, and then do MatMul
@@ -83,7 +98,7 @@ void Conv1dToMatmul(Graph& graph, Node& conv) {
     conv1d_input_splitted_outputs.push_back(&graph.GetOrCreateNodeArg(
         graph.GenerateNodeArgName("input_split_output"), nullptr));
   }
-  auto& input_split = graph.AddNode(graph.GenerateNodeName("Split"), "Split", node_description, {conv1d_input},
+  auto& input_split = graph.AddNode(graph.GenerateNodeName(transformer_name + "Split"), "Split", node_description, {conv1d_input},
                                     {conv1d_input_splitted_outputs});
   input_split.SetExecutionProviderType(execution_provider_type);
   input_split.AddAttribute("axis", int64_t(1));
@@ -93,23 +108,25 @@ void Conv1dToMatmul(Graph& graph, Node& conv) {
   }
   // 2. Squeeze conv weight
   auto conv1d_weight = conv.MutableInputDefs()[1];
+  // auto con1d_bias = xx;
   auto weight_squeeze_output = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("weight_squeeze_output"), nullptr);
-  auto& weight_squeeze = graph.AddNode(graph.GenerateNodeName("WeightSqueeze"), "Squeeze",
+  auto& weight_squeeze = graph.AddNode(graph.GenerateNodeName(transformer_name + "WeightSqueeze"), "Squeeze",
                                        node_description, {conv1d_weight}, {weight_squeeze_output});
+  int64_t weight_squeeze_axis = 2;
   if (onnx_opset_version > 12) {
     // After onnx version 12, squeeze node has axes as input instead of attribute
     ONNX_NAMESPACE::TensorProto initializer_proto;
-    initializer_proto.set_name(graph.GenerateNodeName("ConstAsInitializer"));
+    initializer_proto.set_name(graph.GenerateNodeName(transformer_name + "ConstAsInitializer"));
     initializer_proto.add_dims(static_cast<int64_t>(1));
     initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
-    InlinedVector<int64_t> initializer_proto_value{2};
+    InlinedVector<int64_t> initializer_proto_value{weight_squeeze_axis};
     initializer_proto.set_raw_data(initializer_proto_value.data(), initializer_proto_value.size() * sizeof(int64_t));
     auto& axes_input = graph_utils::AddInitializer(graph, initializer_proto);
     // Squeeze node doesn't have opschema here, so we need to set input args count manually
     weight_squeeze.MutableInputArgsCount().resize(2);
     graph_utils::AddNodeInput(weight_squeeze, 1, axes_input);
   } else {
-    weight_squeeze.AddAttribute("axes", std::vector<int64_t>{2});
+    weight_squeeze.AddAttribute("axes", std::vector<int64_t>{weight_squeeze_axis});
   }
   weight_squeeze.SetExecutionProviderType(execution_provider_type);
   // 3. Split conv weight
@@ -118,7 +135,7 @@ void Conv1dToMatmul(Graph& graph, Node& conv) {
     conv1d_weight_splitted_outputs.push_back(&graph.GetOrCreateNodeArg(
         graph.GenerateNodeArgName("weight_split_output"), nullptr));
   }
-  auto& weight_split = graph.AddNode(graph.GenerateNodeName("Split"), "Split", node_description,
+  auto& weight_split = graph.AddNode(graph.GenerateNodeName(transformer_name + "Split"), "Split", node_description,
                                      {weight_squeeze_output}, {conv1d_weight_splitted_outputs});
   weight_split.AddAttribute("axis", int64_t(0));
   weight_split.SetExecutionProviderType(execution_provider_type);
@@ -130,13 +147,13 @@ void Conv1dToMatmul(Graph& graph, Node& conv) {
   for (int i = 0; i < group_num; i++) {
     auto matmul_output = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("matmul_output"), nullptr);
     matmul_outputs.push_back(matmul_output);
-    auto& matmul = graph.AddNode(graph.GenerateNodeName("Matmul"), "MatMul", node_description,
+    auto& matmul = graph.AddNode(graph.GenerateNodeName(transformer_name + "Matmul"), "MatMul", node_description,
                                  {conv1d_weight_splitted_outputs[i], conv1d_input_splitted_outputs[i]},
                                  {matmul_output});
     matmul.SetExecutionProviderType(execution_provider_type);
   }
   // 5. Concat matmul outputs
-  auto& concat_node = graph.AddNode(graph.GenerateNodeName("Concat"), "Concat", node_description,
+  auto& concat_node = graph.AddNode(graph.GenerateNodeName(transformer_name + "Concat"), "Concat", node_description,
                                     matmul_outputs, {});
   concat_node.SetExecutionProviderType(execution_provider_type);
   concat_node.AddAttribute("axis", int64_t(1));
@@ -155,7 +172,7 @@ Status Conv1dReplacement::ApplyImpl(Graph& graph, bool& modified, int graph_leve
     ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
     if (NodeCanBeReplacedByMatmul(node)) {
       LOGS(logger, VERBOSE) << "lora conv1d replacement, node name: " + node.Name();
-      Conv1dToMatmul(graph, node);
+      Conv1dToMatmul(graph, node, Name());
       modified = true;
     }
   }
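A quick numeric sanity check (a sketch, not ORT code) of the rewrite above: for a Conv with kernel_shape=[1], strides=[1], pads=[0,0], dilations=[1] and `g` groups, the Split/Squeeze/MatMul/Concat graph produces the same result as the convolution.

import numpy as np

batch, in_ch, length, out_ch, g = 2, 8, 5, 6, 2
x = np.random.randn(batch, in_ch, length).astype(np.float32)   # [batch, in_channels, in_length]
w = np.random.randn(out_ch, in_ch // g, 1).astype(np.float32)  # [out_channels, in_channels/g, 1]

def conv1d_k1(x, w, groups):
    # Reference grouped conv1d with kernel size 1 (no padding, stride 1, dilation 1).
    xg = np.split(x, groups, axis=1)
    wg = np.split(w, groups, axis=0)
    outs = [np.einsum('oik,bil->bol', wg[i], xg[i]) for i in range(groups)]
    return np.concatenate(outs, axis=1)

def matmul_form(x, w, groups):
    # The replacement: Split input, Squeeze+Split weight, per-group MatMul, Concat.
    xg = np.split(x, groups, axis=1)
    wg = np.split(np.squeeze(w, axis=2), groups, axis=0)
    outs = [wg[i] @ xg[i] for i in range(groups)]  # [out/g, in/g] @ [b, in/g, L] -> [b, out/g, L]
    return np.concatenate(outs, axis=1)

assert np.allclose(conv1d_k1(x, w, g), matmul_form(x, w, g), atol=1e-5)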
diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
index 894fe3b052fb2..9ce88e549eed2 100644
--- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
+++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
@@ -44,6 +44,7 @@
 #include "core/optimizer/relu_clip_fusion.h"
 #include "core/optimizer/reshape_fusion.h"
 #include "core/optimizer/rule_based_graph_transformer.h"
+#include "core/optimizer/shape_input_merge.h"
 #include "core/optimizer/skip_layer_norm_fusion.h"
 #include "core/optimizer/slice_elimination.h"
 #include "core/optimizer/unsqueeze_elimination.h"
@@ -115,10 +116,11 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
       ORT_THROW_IF_ERROR(rule_transformer->Register(std::make_unique<PythonOpRewriter>()));
 #endif
 
-      // Put ConstantSharing before CommonSubexpressionElimination by intention as it can create more opportunities for
-      // CSE. For example, if A and B nodes both do Add operation with a same value but different initializers, by
-      // default, CSE will not merge them, because the different initializers are represented by different NodeArg.
+      // Intentionally put ConstantSharing and ShapeInputMerge before CommonSubexpressionElimination as they can create
+      // more opportunities for CSE. For example, if nodes A and B consume different args but produce the same output,
+      // or consume different initializers with the same value, by default CSE will not merge them.
       transformers.emplace_back(std::make_unique<ConstantSharing>(compatible_eps));
+      transformers.emplace_back(std::make_unique<ShapeInputMerge>(compatible_eps));
       // LayerNormFusion must be applied before CommonSubexpressionElimination as the latter will break the pattern when 2 LayerNormFusion share the same input.
       transformers.emplace_back(std::make_unique<LayerNormFusion>(compatible_eps));
       // Remove duplicate nodes. Must be applied before any recompute transformations.
@@ -138,7 +140,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
       transformers.emplace_back(std::make_unique<FastGeluFusion>(compatible_eps));
       transformers.emplace_back(std::make_unique<QuickGeluFusion>(compatible_eps));
       transformers.emplace_back(std::make_unique<SoftmaxCrossEntropyLossInternalFusion>(compatible_eps));
-      transformers.emplace_back(std::make_unique<GatherToSplitFusion>(compatible_eps));
+      transformers.emplace_back(std::make_unique<GatherSliceToSplitFusion>(compatible_eps));
       transformers.emplace_back(std::make_unique<GatherToSliceFusion>(compatible_eps));
       // If a model with Q, DQ nodes is being used for the purpose of training, it must be for
       // Quantization Aware Training. So, replace QDQ nodes with FakeQuant.
diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc
index 3fbdd5da7b768..54c49db0597c7 100644
--- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc
+++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc
@@ -9,6 +9,8 @@
 #include <utility>
 #include <vector>
 
+#include "core/common/string_utils.h"
+#include "core/framework/random_seed.h"
 #include "core/graph/graph_utils.h"
 #include "core/graph/graph_viewer.h"
 #include "orttraining/core/optimizer/memory_optimizer/common.h"
@@ -256,7 +258,8 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer,
                                                      logger));
 
   InlinedHashSet<const Node*> layer_boundary_ln_nodes;
-  FindLayerBoundaryLayerNormNodes(graph_viewer, logger, layer_boundary_ln_nodes);
+  FindLayerBoundaryLayerNormNodes(graph_viewer, logger, node_index_to_its_order_in_topological_sort_map,
+                                  yield_op_order_in_topological_sort, layer_boundary_ln_nodes);
 
   // The first pass - find the candidate subgraphs.
   for (int i = static_cast<int>(node_ids.size()) - 1; i >= 0; --i) {
@@ -284,7 +287,9 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer,
       memory_opt_planner.AddNodeOptimizationPlan(p_node, std::move(recompute_plan));
     }
 
-    if (can_compromise_stashed_activation) {
+    // Only detect compromised recompute when no full recompute plan is found; otherwise there could be multiple
+    // recompute plans for the same named activations, and users might enable conflicting plans by mistake.
+    if (recompute_plan == nullptr && can_compromise_stashed_activation) {
       MO_LOG_DEBUG_INFO(logger, "Searching Node " + p_node->Name() + "(" + p_node->OpType() +
                                     ") for compromised recompute");
       // If the subgraph recompute can save memory by comprising the assumption - recompute graphs' input must exist
diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc
index 49e026ca86bd3..ac619bdc390d3 100644
--- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc
+++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc
@@ -28,6 +28,28 @@ constexpr bool IsForwardPassOperator(ptrdiff_t op_order_in_topological_sort,
   return op_order_in_topological_sort <= boundary_op_order_in_topological_sort;
 }
 
+// Set the seed attribute for the dropout node if it is not already set.
+bool SetSeedForDropoutNode(Node& node) {
+  // ONNX Dropout 1, 6, 7 and 10 do not have a seed attribute, so we exclude them from recompute support.
+  if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Dropout", {12, 13}, kOnnxDomain) ||
+      graph_utils::IsSupportedOptypeVersionAndDomain(node, "BitmaskDropout", {1}, kMSDomain) ||
+      graph_utils::IsSupportedOptypeVersionAndDomain(node, "BiasDropout", {1}, kMSDomain) ||
+      graph_utils::IsSupportedOptypeVersionAndDomain(node, "BitmaskBiasDropout", {1}, kMSDomain) ||
+      graph_utils::IsSupportedOptypeVersionAndDomain(node, "BiasSoftmaxDropout", {1}, kMSDomain)) {
+    auto& attrs = node.GetAttributes();
+    if (attrs.count("seed")) {
+      return false;
+    }
+
+    int64_t seed = static_cast<int64_t>(utils::GetHashFromString(node.OutputDefs()[0]->Name())) +
+                   utils::GetRandomSeed();
+    node.AddAttribute("seed", seed);
+    return true;
+  }
+
+  return false;
+}
+
 }  // namespace
 
 Status MemoryOptimizer::ParseOptimizationConfigFromString(const std::string& memory_optimizer_config,
@@ -74,7 +96,7 @@ bool MemoryOptimizer::ModifyGraph(Graph& graph,
       optimizer::memory_optimizer::NodeRecomputePlan* recompute_plan =
           dynamic_cast<optimizer::memory_optimizer::NodeRecomputePlan*>(node_plan.get());
       ORT_ENFORCE(recompute_plan != nullptr);
-      ORT_ENFORCE(CreateRecomputeGraph(graph, recompute_plan->GetNodesInTopoOrder(), replacement_node_ptr).IsOK());
+      ORT_ENFORCE(CreateRecomputeGraph(graph, recompute_plan->GetNodesInTopoOrder(), logger, replacement_node_ptr).IsOK());
     } else {
       ORT_THROW("unsupported optimization type found.");
     }
@@ -93,7 +115,7 @@ bool MemoryOptimizer::ModifyGraph(Graph& graph,
 
         auto tid = node_index_to_its_order_in_topological_sort_map.find(it->GetNode().Index());
         // It is possible the consumer node is newly added as the recompute node, so we need a check here.
-        // For those kind of ops, we can treat them as backward ops.
+        // For those kinds of ops, we can treat them as backward ops.
         if (tid == node_index_to_its_order_in_topological_sort_map.end() ||
             !IsForwardPassOperator(node_index_to_its_order_in_topological_sort_map.at(tid->first),
                                    boundary_op_order_in_topological_sort)) {
@@ -167,11 +189,44 @@ Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve
                   .IsOK());
 
   // The second pass - apply the transformation.
-  // Iterate through the nodes in reversed topological order and find the subgraph that can be alleviated.
+  // Note 1: Iterate through the nodes in reversed topological order and find the subgraph that can be alleviated.
   // The reason we do reversed topological order is that we want the later layers' recompute nodes can be appended
   // earlier than the earlier layers, in this way, the execution order of later layers will be in front of the earlier
   // layers.
-  const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED);
+  //
+  // Note 2: Here we use the default topological order (which tries to BFS from the outputs,
+  // so the node nearest to the graph output will be visited last). So in reversed default topological order,
+  // the node nearest to the graph output will be visited first.
+  // Imagine there is such a subgraph:
+  //         input1 input2 input3
+  //             \    |     /
+  //         multiple layers
+  //             |
+  //            node M
+  // labels-------|-----
+  //    \         |     |
+  //    node1     |     |
+  //      \       |     |
+  //      node2  /      |
+  //        \   /       |
+  //      node loss     /
+  //          |        /
+  //       YieldOp  node1_recompute
+  //         |      /
+  //         \   node2 recompute
+  //          \ /
+  //     node loss_grad
+  //           |
+  //     critical grad path
+  //
+  // In PriorityBased order, node1 will be visited first, so its recompute node node1_recompute will be added
+  // last because we follow reversed topological order. Then node1_recompute will have the lowest
+  // priority to execute; as a result, if at that time the queue to visit contains only recompute nodes,
+  // node1_recompute will run last, delaying the backward critical path, which is not what we want.
+  // The current workaround is to use the default order, which will execute node1_recompute earlier than the other
+  // recompute nodes in this case.
+
+  const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::DEFAULT);
   for (int i = static_cast<int>(node_ids.size()) - 1; i >= 0; --i) {
     Node* p_node = graph.GetNode(node_ids[i]);
     if (p_node == nullptr) {
@@ -223,6 +278,7 @@ void MemoryOptimizer::PrintSummary(const optimizer::memory_optimizer::MemoryOpti
 
 Status MemoryOptimizer::CreateRecomputeGraph(Graph& graph,
                                              const InlinedVector<const Node*>& nodes_in_topological_order,
+                                             const logging::Logger& logger,
                                              Node*& new_output_node_ptr) const {
   InlinedHashMap<NodeArg*, NodeArg*> self_contained_outputs_map;
   for (size_t i = 0; i < nodes_in_topological_order.size(); ++i) {
@@ -236,6 +292,12 @@ Status MemoryOptimizer::CreateRecomputeGraph(Graph& graph,
       continue;
     }
 
+    bool seed_reset = SetSeedForDropoutNode(*node_to_duplicate);
+    if (seed_reset) {
+      LOGS(logger, VERBOSE) << "Set seed for Node " << node_to_duplicate->Name() << "(" << node_to_duplicate->OpType()
+                            << ").";
+    }
+
     InlinedVector<NodeArg*> new_input_args;
     new_input_args.reserve(node_to_duplicate->MutableInputDefs().size());
     for (NodeArg* input_arg : node_to_duplicate->MutableInputDefs()) {
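An illustrative onnx.helper sketch (not ORT code) of what SetSeedForDropoutNode effectively does before a Dropout is duplicated for recompute: pin an explicit `seed` attribute so the recomputed node reproduces the same mask as the original. The seed value here is arbitrary; the optimizer derives it from a hash of the output name plus the global random seed.

from onnx import helper

node = helper.make_node("Dropout", ["x", "ratio", "training_mode"], ["y", "mask"], name="dropout0")
if not any(attr.name == "seed" for attr in node.attribute):
    node.attribute.append(helper.make_attribute("seed", 1234))  # arbitrary placeholder value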
diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.h b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.h
index b3e05fd334e48..1d837038e76c1 100644
--- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.h
+++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.h
@@ -94,6 +94,7 @@ class MemoryOptimizer : public GraphTransformer {
    */
   Status CreateRecomputeGraph(Graph& graph,
                               const InlinedVector<const Node*>& nodes_in_topological_order,
+                              const logging::Logger& logger,
                               Node*& recompute_subgraph_output_node) const;
 
   /**************************************************
diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc
index 12c83591c0036..b421eb2ab32da 100644
--- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc
+++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc
@@ -19,7 +19,7 @@ namespace onnxruntime::optimizer::memory_optimizer {
 
 namespace {
 
-constexpr int32_t MAXIMUM_RECOMPUTE_NODE_COUNT = 15;
+constexpr int32_t MAXIMUM_RECOMPUTE_NODE_COUNT = 50;
 
 static size_t GetElementSize(const ONNX_NAMESPACE::DataType& tensor_type) {
   const ONNX_NAMESPACE::TypeProto& type_proto = ONNX_NAMESPACE::Utils::DataTypeUtils::ToTypeProto(tensor_type);
@@ -48,75 +48,352 @@ float InputOutputSizeRatio(const Node* node) {
   return 1.0f;
 }
 
+using IgnorableInputIndices = InlinedVector<int>;
+using OpsetToIgnorableIndicesMap = InlinedHashMap<int, IgnorableInputIndices>;
+
 /**
- * @brief Used to define per-op recompute config.
+ * @brief Get the Allowed Recompute Ops object
+ *
+ * The supported op types are predefined.
+ * Most recently revisited for the ONNX v1.15.0 release - https://github.com/onnx/onnx/blob/b86cc54efce19530fb953e4b21f57e6b3888534c/docs/Operators.md
+ *
+ * We define the supported list explicitly instead of using an exclusion list, for the following reasons:
+ * 1. Some ops generate nondeterministic results (for example, those using a random number generator). We need to
+ *   evaluate whether this is a problem for recompute before adding support, instead of fixing it after we find and
+ *   try to fix convergence issues (which would be very hard if multiple nondeterministic operators were supported by default).
+ * 2. Some op schemas change in new opsets, so we also need to check manually whether each opset version is applicable
+ *   to recompute or not.
+ * 3. Some ops are not supported in older opsets; we need to check whether those versions are applicable to recompute or not.
  */
-struct AllowedRecomputeNodeConfig {
-  InlinedVector<int> input_arg_indices;  // input index to iterate further (bottom up)
-};
-
-// The supported op types are predefined.
-
-const InlinedHashMap<std::string, AllowedRecomputeNodeConfig>& GetAllowedRecomputeOps(int probe_op_level) {
-  static InlinedHashMap<int, InlinedHashMap<std::string, AllowedRecomputeNodeConfig>> recomputable_op_table_map;
+const InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>& GetAllowedRecomputeOps(int probe_op_level) {
+  static InlinedHashMap<int, InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>> recomputable_op_table_map;
   if (recomputable_op_table_map.find(probe_op_level) != recomputable_op_table_map.end()) {
     return recomputable_op_table_map.at(probe_op_level);
   }
 
-  recomputable_op_table_map.insert({probe_op_level, InlinedHashMap<std::string, AllowedRecomputeNodeConfig>()});
+  recomputable_op_table_map.insert({probe_op_level, InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>()});
   auto& recomputable_op_table = recomputable_op_table_map.at(probe_op_level);
   if (probe_op_level >= static_cast<int>(ProbeLevel::Basic)) {
     recomputable_op_table.insert({
-        // Binary elementwise
-        {"Add", AllowedRecomputeNodeConfig{{0, 1}}},
-        {"BiasGelu", AllowedRecomputeNodeConfig{{0, 1}}},
-        {"Div", AllowedRecomputeNodeConfig{{0, 1}}},
-        {"Equal", AllowedRecomputeNodeConfig{{0, 1}}},
-        {"Mul", AllowedRecomputeNodeConfig{{0, 1}}},
-        {"Sub", AllowedRecomputeNodeConfig{{0, 1}}},
-
-        // Data layout
-        /// The shape input is trivial whether it exists or not in backward.
-        {"Reshape", AllowedRecomputeNodeConfig{{0}}},
-        {"Shape", AllowedRecomputeNodeConfig{{0}}},
-        {"Squeeze", AllowedRecomputeNodeConfig{{0}}},
-        {"Transpose", AllowedRecomputeNodeConfig{{0}}},
-        {"Unsqueeze", AllowedRecomputeNodeConfig{{0}}},
-
-        // Unary elementwise
-        {"Dropout", AllowedRecomputeNodeConfig{{0}}},
-        {"BiasGelu", AllowedRecomputeNodeConfig{{0, 1}}},
-        /// The ratio and mode input are trivial whether they exist or not in backward
-        {"BitmaskDropout", AllowedRecomputeNodeConfig{{0}}},
-        /// The axis input is trivial whether it exists or not in backward
-        {"CumSum", AllowedRecomputeNodeConfig{{0}}},
-        {"Expand", AllowedRecomputeNodeConfig{{0}}},
-        {"FastGelu", AllowedRecomputeNodeConfig{{0}}},
-        {"Gelu", AllowedRecomputeNodeConfig{{0}}},
-        {"QuickGelu", AllowedRecomputeNodeConfig{{0}}},
-
-        // Ternary elementwise
-        {"Where", AllowedRecomputeNodeConfig{{0, 1, 2}}},
-
-        // Data copy
-        {"Tile", AllowedRecomputeNodeConfig{{0}}},
-        {"Cast", AllowedRecomputeNodeConfig{{0}}},
-        {"ConcatTraining", AllowedRecomputeNodeConfig{{0, 1}}},  // Input could be more than 2. But mostly 2.
-        {"Slice", AllowedRecomputeNodeConfig{{0}}},
-        {"Split", AllowedRecomputeNodeConfig{{0}}},
-        {"Gather", AllowedRecomputeNodeConfig{{0}}},
+        {
+            utils::GetFullQualifiedOpName("Add", kOnnxDomain),
+            {
+                {1, {}},
+                {6, {}},
+                {7, {}},
+                {13, {}},
+                {14, {}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("BatchNormalization", kOnnxDomain),
+            {
+                {1, {}},
+                {6, {}},
+                {7, {}},
+                {9, {}},
+                {14, {}},
+                {15, {}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("BiasGelu", kMSDomain),
+            {
+                {1, {}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("BiasDropout", kMSDomain),
+            {
+                {1, {3, 4}},  // ignore ratio (optional) and training mode (optional)
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("BitmaskBiasDropout", kMSDomain),
+            {
+                {1, {3, 4}},  // ignore ratio (optional) and training mode (optional)
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("BitmaskDropout", kMSDomain),
+            {
+                {1, {1, 2}},  // ignore ratio (optional) and training mode (optional)
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("Cast", kOnnxDomain),
+            {
+                {1, {}},
+                {6, {}},
+                {9, {}},
+                {13, {}},
+                {19, {}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("ConcatTraining", kMSDomain),
+            {
+                {1, {}},
+
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("ConstantOfShape", kOnnxDomain),
+            {
+                {9, {0}},  // ignore the `input`, e.g. the shape of the expected output tensor
+                {20, {0}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("Dropout", kOnnxDomain),
+            {
+                // ONNX Dropout 1, 6, 7, 10 do not have seed attribute, so we remove them from the recompute support.
+                {12, {1, 2}},  // ignore ratio and training_mode
+                {13, {1, 2}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("Div", kOnnxDomain),
+            {
+                {1, {}},
+                {6, {}},
+                {7, {}},
+                {13, {}},
+                {14, {}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("Expand", kOnnxDomain),
+            {
+                {8, {1}},  // Ignore the shape.
+                {13, {1}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("Cos", kOnnxDomain),
+            {
+                {7, {}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("CumSum", kOnnxDomain),
+            {
+                // The axis input is trivial
+                {11, {1}},
+                {14, {1}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("Einsum", kOnnxDomain),
+            {
+                {12, {}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("Equal", kOnnxDomain),
+            {
+                {1, {}},
+                {7, {}},
+                {11, {}},
+                {13, {}},
+                {19, {}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("FastGelu", kMSDomain),
+            {
+                {1, {}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("Gather", kOnnxDomain),
+            {
+                {1, {1}},  // ignore the indices
+                {11, {1}},
+                {13, {1}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("Gelu", kOnnxDomain),
+            {
+                {20, {}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("Gelu", kMSDomain),
+            {
+                {1, {}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("Less", kOnnxDomain),
+            {
+                {1, {}},
+                {7, {}},
+                {9, {}},
+                {13, {}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("Mul", kOnnxDomain),
+            {
+                {1, {}},
+                {6, {}},
+                {7, {}},
+                {13, {}},
+                {14, {}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("Range", kOnnxDomain),
+            {
+                {11, {0, 1, 2}},  // ignore start, end, delta, because they are scalars.
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("Reshape", kOnnxDomain),
+            {
+                {1, {}},
+                {5, {}},  // ignore the shape.
+                {13, {}},
+                {14, {}},
+                {19, {}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("Sin", kOnnxDomain),
+            {
+                {7, {}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("Slice", kOnnxDomain),
+            {
+                {1, {}},
+                {10, {1, 2, 3, 4}},  // ignore starts, ends, axes (optional) and steps (optional)
+                {11, {1, 2, 3, 4}},
+                {13, {1, 2, 3, 4}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("Split", kOnnxDomain),
+            {
+                {1, {1}},  // ignore split (optional)
+                {2, {}},
+                {11, {}},
+                {13, {1}},  // ignore the split (optional)
+                {18, {1}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("Squeeze", kOnnxDomain),
+            {
+                {1, {}},
+                {11, {}},
+                {13, {1}},  // ignore the axes (optional)
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("Sub", kOnnxDomain),
+            {
+                {1, {}},
+                {6, {}},
+                {7, {}},
+                {13, {}},
+                {14, {}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("Tile", kOnnxDomain),
+            {
+                {1, {1, 2}},
+                {6, {1}},
+                {13, {1}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("Transpose", kOnnxDomain),
+            {
+                {1, {}},
+                {13, {}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("Trilu", kOnnxDomain),
+            {
+                {14, {1}},  // ignore k (optional)
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("QuickGelu", kMSDomain),
+            {
+                {1, {}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("Unsqueeze", kOnnxDomain),
+            {
+                {1, {}},
+                {11, {}},
+                {13, {1}},  // ignore the axes (optional)
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("Where", kOnnxDomain),
+            {
+                {9, {}},
+                {16, {}},
+            },
+        },
+
     });
   }
 
   if (probe_op_level >= static_cast<int>(ProbeLevel::Advanced)) {
     recomputable_op_table.insert({
-        {"LayerNormalization", AllowedRecomputeNodeConfig{{0, 1, 2}}},
-        {"MatMul", AllowedRecomputeNodeConfig{{0, 1}}},
-        {"FusedMatMul", AllowedRecomputeNodeConfig{{0, 1}}},
-        {"Softmax", AllowedRecomputeNodeConfig{{0}}},
-        {"BiasSoftmax", AllowedRecomputeNodeConfig{{0, 1}}},
-        {"BiasSoftmaxDropout", AllowedRecomputeNodeConfig{{0, 1}}},
+        {
+            utils::GetFullQualifiedOpName("BiasSoftmax", kMSDomain),
+            {
+                {1, {}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("BiasSoftmaxDropout", kMSDomain),
+            {
+                {1, {2}},  // ignore ratio (optional)
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("LayerNormalization", kOnnxDomain),
+            {
+                // The official ONNX opset 1 does not define LayerNormalization,
+                // but our contrib op registers LayerNormalization as opset 1 in the ONNX domain.
+                {1, {}},
+                {17, {}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("MatMul", kOnnxDomain),
+            {
+                {1, {}},
+                {9, {}},
+                {13, {}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("FusedMatMul", kMSDomain),
+            {
+                {1, {}},
+            },
+        },
+        {
+            utils::GetFullQualifiedOpName("Softmax", kOnnxDomain),
+            {
+                {1, {}},
+                {11, {}},
+                {13, {}},
+            },
+        },
     });
   }
 
@@ -127,8 +404,20 @@ const InlinedHashMap<std::string, AllowedRecomputeNodeConfig>& GetAllowedRecompu
  * @brief Check whether a node is a recomputable node at given probe level.
  */
 bool IsRecomputable(const Node& node, ProbeLevel probe_level) {
-  const auto& op_table = GetAllowedRecomputeOps(static_cast<int>(probe_level));
-  return op_table.find(node.OpType()) != op_table.end();
+  const InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>& op_table = GetAllowedRecomputeOps(static_cast<int>(probe_level));
+  auto it = op_table.find(utils::GetFullQualifiedOpName(node.OpType(), node.Domain()));
+  if (it == op_table.end()) {
+    return false;
+  }
+  return it->second.count(node.SinceVersion());
+}
+
+const InlinedVector<int>& GetIgnorableInputIndices(const Node& node, ProbeLevel probe_level) {
+  const InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>& op_table = GetAllowedRecomputeOps(static_cast<int>(probe_level));
+  auto it = op_table.find(utils::GetFullQualifiedOpName(node.OpType(), node.Domain()));
+  ORT_ENFORCE(it != op_table.end(), "Cannot get ignorable indices since the node type is not supported in the list.");
+  ORT_ENFORCE(it->second.count(node.SinceVersion()) > 0, "Cannot get ignorable indices since the opset is not supported.");
+  return it->second.at(node.SinceVersion());
 }
 
 /**
@@ -163,7 +452,6 @@ Status SelectRecomputeSubgraph(const Node& entry_node,
                                bool& can_compromise_stashed_activation,
                                float& save_ratio) {
   const ProbeLevel probe_level = probe_config.probe_level;
-  const auto& recomputable_op_table = GetAllowedRecomputeOps(static_cast<int>(probe_level));
 
   can_compromise_stashed_activation = false;
 
@@ -213,7 +501,7 @@ Status SelectRecomputeSubgraph(const Node& entry_node,
       //  If current op is NOT in allowed list:
       //    1). the output does not exist in backward, we cannot find a good solution for so, the search terminates.
       //    2). the output is used in backward, we don't need to trace back further, so continue searching.
-      auto op_recompute_config_it = recomputable_op_table.find(curr_node->OpType());
+      bool is_recomputable = IsRecomputable(*curr_node, probe_level);
       auto cur_output_arg_name = curr_node->OutputDefs()[p.second]->Name();
       if (is_first_queue_scan) {
         // We handle the entry node outputs differently because, we don't want this case falls into and succeed one of
@@ -221,14 +509,14 @@ Status SelectRecomputeSubgraph(const Node& entry_node,
         // 1. "op is not in recompute op list, but its output is used in backward"
         // 2. "op is in recompute op list, but its output is used in backward"
         // (either of the above checks is true for entry node outputs)
-        if (op_recompute_config_it == recomputable_op_table.end()) {
+        if (!is_recomputable) {
           early_stop = true;
           MO_LOG_DEBUG_INFO(logger, "Entry Node " + curr_node->Name() + "(" + curr_node->OpType() +
                                         ") is **NOT** in recompute op list, search terminates.");
           break;
         }
       } else {
-        if (op_recompute_config_it == recomputable_op_table.end()) {
+        if (!is_recomputable) {
           if (fw_op_output_arg_used_map.at(cur_output_arg_name).second) {
             MO_LOG_DEBUG_INFO(logger, "Node " + curr_node->Name() + "(" + curr_node->OpType() +
                                           ") is **NOT** in recompute op list, but its output [" +
@@ -283,14 +571,30 @@ Status SelectRecomputeSubgraph(const Node& entry_node,
       }
 
       // Iterate all input nodes according to allowed input arg index of the entry node.
-      const auto& input_arg_indices = op_recompute_config_it->second.input_arg_indices;
+      const auto& ignorable_input_arg_indices = GetIgnorableInputIndices(*curr_node, probe_level);
       for (auto it = curr_node->InputEdgesBegin(), end = curr_node->InputEdgesEnd(); it != end; ++it) {
         const Node::EdgeEnd& input_edge = *it;
         const auto& parent_node = input_edge.GetNode();
         const auto parent_node_output_index = input_edge.GetSrcArgIndex();
         const auto current_node_input_index = input_edge.GetDstArgIndex();
-        if (std::find(input_arg_indices.begin(), input_arg_indices.end(), current_node_input_index) !=
-            input_arg_indices.end()) {
+        if (std::find(ignorable_input_arg_indices.begin(), ignorable_input_arg_indices.end(), current_node_input_index) ==
+            ignorable_input_arg_indices.end()) {
+          // If the tensor size is constant and very small (currently < 1M elements), stop adding the input edge to the queue.
+          auto output_shape = parent_node.OutputDefs()[parent_node_output_index]->Shape();
+          if (output_shape) {
+            bool all_constant_dim = true;
+            int64_t num_elem = 1;
+            for (int k = 0, dim_size = output_shape->dim_size(); k < dim_size; ++k) {
+              if (!output_shape->dim(k).has_dim_value()) {
+                // Symbolic dimension: the size is not constant, so do not skip this input.
+                all_constant_dim = false;
+                break;
+              }
+              num_elem *= output_shape->dim(k).dim_value();
+            }
+            if (all_constant_dim && num_elem < 1 * 1024 * 1024) {
+              // Skip this input index.
+              continue;
+            }
+          }
           NodeOutputPort next_p = std::make_pair(&parent_node, parent_node_output_index);
 
           MO_LOG_DEBUG_INFO(logger, "Node " + parent_node.Name() + "(" + parent_node.OpType() + ")'s " +
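
For readers following the C++ change above: the recompute table is now keyed by the fully qualified op name and, per supported opset, lists the input indices that the subgraph search may ignore. A minimal Python sketch of that lookup, with illustrative names only (this is not the ORT C++ API):

# Illustrative sketch of the opset-aware recompute lookup; names are hypothetical.
RECOMPUTE_TABLE = {
    "Slice": {1: [], 10: [1, 2, 3, 4], 11: [1, 2, 3, 4], 13: [1, 2, 3, 4]},
    "com.microsoft.Gelu": {1: []},
}

def full_qualified_op_name(op_type: str, domain: str) -> str:
    # The standard ONNX domain is the empty string; contrib ops carry an explicit domain.
    return f"{domain}.{op_type}" if domain else op_type

def is_recomputable(op_type: str, domain: str, since_version: int) -> bool:
    opsets = RECOMPUTE_TABLE.get(full_qualified_op_name(op_type, domain))
    return opsets is not None and since_version in opsets

def ignorable_input_indices(op_type: str, domain: str, since_version: int):
    # Mirrors GetIgnorableInputIndices: only valid after is_recomputable() passes.
    return RECOMPUTE_TABLE[full_qualified_op_name(op_type, domain)][since_version]

assert is_recomputable("Slice", "", 13)
assert ignorable_input_indices("Slice", "", 13) == [1, 2, 3, 4]
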
diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc
index 04f2679ac774f..c88a0f05d36b8 100644
--- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc
+++ b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc
@@ -19,6 +19,9 @@ namespace onnxruntime::optimizer::memory_optimizer {
 void FindLayerBoundaryLayerNormNodes(
     const GraphViewer& graph_viewer,
     const logging::Logger&,
+    const InlinedHashMap<NodeIndex, ptrdiff_t>&
+        node_index_to_its_order_in_topological_sort_map,
+    const ptrdiff_t& yield_op_order_in_topological_sort,
     InlinedHashSet<const Node*>& layer_boundary_ln_nodes) {
   // Loop all nodes to find LayerNormalization nodes.
   // For each LayerNormalization node, keep checking its output nodes,
@@ -40,9 +43,16 @@ void FindLayerBoundaryLayerNormNodes(
     std::deque<const Node*> nodes_to_check;
     std::set<const Node*> visited_nodes;
     for (auto node_it = node.OutputNodesBegin(); node_it != node.OutputNodesEnd(); ++node_it) {
-      nodes_to_check.push_back(&(*node_it));
+      // Ignore those nodes after YieldOp.
+      if (node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) < yield_op_order_in_topological_sort) {
+        nodes_to_check.push_back(&(*node_it));
+      }
     }
 
+    bool unexpected_failure = false;
+    bool found_softmax = false;
+    bool found_layernorm = false;
+    ptrdiff_t next_layernorm_execution_order = -1;
     while (!nodes_to_check.empty()) {
       const Node* next_node = nodes_to_check.front();
       nodes_to_check.pop_front();
@@ -53,16 +63,40 @@ void FindLayerBoundaryLayerNormNodes(
 
       visited_nodes.insert(next_node);
       if (softmax_ops.find(next_node->OpType()) != softmax_ops.end()) {
-        layer_boundary_ln_nodes.insert(&node);
-        break;
+        found_softmax = true;
       } else if (layernorm_ops.find(next_node->OpType()) != layernorm_ops.end()) {
-        break;
+        if (found_layernorm) {
+          // If we find a second LayerNormalization node downstream, treat it as unexpected and skip layer boundary detection.
+          unexpected_failure = true;
+          break;
+        }
+        found_layernorm = true;  // don't trace further
+        next_layernorm_execution_order = node_index_to_its_order_in_topological_sort_map.at(next_node->Index());
+        continue;
       } else {
         for (auto node_it = next_node->OutputNodesBegin(); node_it != next_node->OutputNodesEnd(); ++node_it) {
+          // Stop if the node is after next Layernorm node in execution order.
+          if (found_layernorm &&
+              node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) >= next_layernorm_execution_order) {
+            continue;
+          }
           nodes_to_check.push_back(&(*node_it));
         }
       }
     }
+
+    if (unexpected_failure) {
+      layer_boundary_ln_nodes.clear();
+      break;
+    }
+
+    if (found_softmax) {
+      layer_boundary_ln_nodes.insert(&node);
+    } else if (!found_layernorm) {
+      // If neither a Softmax nor another LayerNormalization is found, this should be the last
+      // LayerNormalization node, so we also consider it a boundary node.
+      layer_boundary_ln_nodes.insert(&node);
+    }
   }
 }
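
The boundary detection above can be summarized as: starting from each LayerNormalization, walk its consumers while ignoring anything scheduled at or after YieldOp; a LayerNormalization whose traversal reaches a Softmax, or reaches no other LayerNormalization at all, is a layer boundary. A hedged Python sketch of that traversal, assuming a plain dict-based graph (consumer lists, topological order, op types) rather than the ORT GraphViewer:

from collections import deque

def find_layer_boundary_layernorms(layernorm_nodes, consumers, order, op_types,
                                   yield_op_order, softmax_ops, layernorm_ops):
    """Sketch only: node ids are hashable, consumers[n] lists downstream node ids."""
    boundary = set()
    for ln in layernorm_nodes:
        to_check = deque(n for n in consumers[ln] if order[n] < yield_op_order)
        visited = set()
        found_softmax = found_layernorm = unexpected = False
        next_ln_order = None
        while to_check:
            node = to_check.popleft()
            if node in visited:
                continue
            visited.add(node)
            if op_types[node] in softmax_ops:
                found_softmax = True
            elif op_types[node] in layernorm_ops:
                if found_layernorm:
                    unexpected = True  # a second downstream LayerNorm: give up entirely
                    break
                found_layernorm, next_ln_order = True, order[node]
            else:
                for nxt in consumers[node]:
                    # Do not trace past the next LayerNorm in execution order.
                    if found_layernorm and order[nxt] >= next_ln_order:
                        continue
                    to_check.append(nxt)
        if unexpected:
            return set()
        if found_softmax or not found_layernorm:
            boundary.add(ln)
    return boundary
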
 
diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h
index f2cfd640b0840..b58d822124f43 100644
--- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h
+++ b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h
@@ -20,6 +20,9 @@ namespace onnxruntime::optimizer::memory_optimizer {
 
 void FindLayerBoundaryLayerNormNodes(const GraphViewer& graph_viewer,
                                      const logging::Logger& logger,
+                                     const InlinedHashMap<NodeIndex, ptrdiff_t>&
+                                         node_index_to_its_order_in_topological_sort_map,
+                                     const ptrdiff_t& yield_op_order_in_topological_sort,
                                      InlinedHashSet<const Node*>& layer_boundary_ln_nodes);
 
 }  // namespace onnxruntime::optimizer::memory_optimizer
diff --git a/orttraining/orttraining/python/training/__init__.py b/orttraining/orttraining/python/training/__init__.py
index a3c22686a1039..1da95dff94f9f 100644
--- a/orttraining/orttraining/python/training/__init__.py
+++ b/orttraining/orttraining/python/training/__init__.py
@@ -23,9 +23,9 @@
 
 try:
     if is_ortmodule_available():
-        from .ortmodule import ORTModule  # noqa: F401
+        from .ortmodule import ORTModule
 
-        __all__.append("ORTModule")
+        __all__ += ["ORTModule"]
 except ImportError:
     # That is OK iff this is not a ORTModule training package
     pass
diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py
index 7a4eb251bc5bc..624b30ffdab3b 100644
--- a/orttraining/orttraining/python/training/artifacts.py
+++ b/orttraining/orttraining/python/training/artifacts.py
@@ -41,13 +41,14 @@ def generate_artifacts(
     requires_grad: Optional[List[str]] = None,
     frozen_params: Optional[List[str]] = None,
     loss: Optional[Union[LossType, onnxblock.Block]] = None,
-    optimizer: Optional[OptimType] = None,
+    optimizer: Optional[Union[OptimType, onnxblock.Block]] = None,
     artifact_directory: Optional[Union[str, bytes, os.PathLike]] = None,
     prefix: str = "",
     ort_format: bool = False,
     custom_op_library: Optional[Union[str, bytes, os.PathLike]] = None,
     additional_output_names: Optional[List[str]] = None,
     nominal_checkpoint: bool = False,
+    loss_input_names: Optional[List[str]] = None,
 ) -> None:
     """Generates artifacts required for training with ORT training api.
 
@@ -63,8 +64,8 @@ def generate_artifacts(
         model: The base model to be used for gradient graph generation.
         requires_grad: List of names of model parameters that require gradient computation
         frozen_params: List of names of model parameters that should be frozen.
-        loss: The loss function enum to be used for training. If None, no loss node is added to the graph.
-        optimizer: The optimizer enum to be used for training. If None, no optimizer model is generated.
+        loss: The loss function enum or onnxblock to be used for training. If None, no loss node is added to the graph.
+        optimizer: The optimizer enum or onnxblock to be used for training. If None, no optimizer model is generated.
         artifact_directory: The directory to save the generated artifacts.
             If None, the current working directory is used.
         prefix: The prefix to be used for the generated artifacts. If not specified, no prefix is used.
@@ -77,7 +78,9 @@ def generate_artifacts(
             Default is False. Nominal checkpoint is a checkpoint that contains nominal information about the model
             parameters. It can be used on the device to reduce overhead while constructing the training model
             as well as to reduce the size of the checkpoint packaged with the on-device application.
-
+        loss_input_names: Specifies a list of input names to be used specifically for the loss computation. When provided,
+            only these inputs will be passed to the loss function. If `None`, all graph outputs are passed to
+            the loss function.
     Raises:
         RuntimeError: If the loss provided is neither one of the supported losses nor an instance of `onnxblock.Block`
         RuntimeError: If the optimizer provided is not one of the supported optimizers.
@@ -111,11 +114,16 @@ def generate_artifacts(
         logging.info("Custom loss block provided: %s", loss.__class__.__name__)
 
     class _TrainingBlock(onnxblock.TrainingBlock):
-        def __init__(self, _loss):
+        def __init__(self, _loss, _loss_input_names=None):
             super().__init__()
             self._loss = _loss
+            self._loss_input_names = _loss_input_names
 
         def build(self, *inputs_to_loss):
+            # If loss_input_names is passed, only pass the specified input names to the loss function.
+            if self._loss_input_names:
+                inputs_to_loss = self._loss_input_names
+
             if additional_output_names:
                 # If additional output names is not a list, raise an error
                 if not isinstance(additional_output_names, list):
@@ -132,7 +140,7 @@ def build(self, *inputs_to_loss):
 
             return self._loss(*inputs_to_loss)
 
-    training_block = _TrainingBlock(loss_block)
+    training_block = _TrainingBlock(loss_block, loss_input_names)
 
     if requires_grad is not None and frozen_params is not None and set(requires_grad).intersection(set(frozen_params)):
         raise RuntimeError(
@@ -157,9 +165,11 @@ def build(self, *inputs_to_loss):
         logging.info("Custom op library provided: %s", custom_op_library)
         custom_op_library_path = pathlib.Path(custom_op_library)
 
-    with onnxblock.base(model), onnxblock.custom_op_library(
-        custom_op_library_path
-    ) if custom_op_library is not None else contextlib.nullcontext():
+    with onnxblock.base(model), (
+        onnxblock.custom_op_library(custom_op_library_path)
+        if custom_op_library is not None
+        else contextlib.nullcontext()
+    ):
         _ = training_block(*[output.name for output in model.graph.output])
         training_model, eval_model = training_block.to_model_proto()
         model_params = training_block.parameters()
@@ -209,14 +219,6 @@ def _export_to_ort_format(model_path, output_dir, ort_format, custom_op_library_
         logging.info("No optimizer enum provided. Skipping optimizer model generation.")
         return
 
-    if not isinstance(optimizer, OptimType):
-        raise RuntimeError(
-            f"Unknown optimizer provided {type(optimizer)}. Expected optimizer to be of type "
-            "onnxruntime.training.artifacts.OptimType."
-        )
-
-    logging.info("Optimizer enum provided: %s", optimizer.name)
-
     opset_version = None
     for domain in model.opset_import:
         if domain.domain == "" or domain.domain == "ai.onnx":
@@ -225,8 +227,19 @@ def _export_to_ort_format(model_path, output_dir, ort_format, custom_op_library_
 
     optim_model = None
     optim_blocks = {OptimType.AdamW: onnxblock.optim.AdamW, OptimType.SGD: onnxblock.optim.SGD}
+    optim_block = None
+    if isinstance(optimizer, OptimType):
+        logging.info("Optimizer enum provided: %s", optimizer.name)
+        optim_block = optim_blocks[optimizer]()
+    elif isinstance(optimizer, onnxblock.Block):
+        logging.info("Optimizer block provided: %s", optimizer.__class__.__name__)
+        optim_block = optimizer
+    else:
+        raise TypeError(
+            f"Unknown optimizer provided {type(optimizer)}. Expected optimizer to be either one of"
+            "onnxruntime.training.artifacts.OptimType or onnxruntime.training.onnxblock.Block."
+        )
 
-    optim_block = optim_blocks[optimizer]()
     with onnxblock.empty_base(opset_version=opset_version):
         _ = optim_block(model_params)
         optim_model = optim_block.to_model_proto()
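
A hedged usage sketch for the extended generate_artifacts signature introduced above (optimizer accepted as an onnxblock.Block, plus loss_input_names); the model path, parameter names, and output name are hypothetical placeholders:

import onnx
from onnxruntime.training import artifacts, onnxblock

base_model = onnx.load("base_model.onnx")  # hypothetical model path

artifacts.generate_artifacts(
    base_model,
    requires_grad=["fc.weight", "fc.bias"],    # hypothetical parameter names
    loss=artifacts.LossType.CrossEntropyLoss,  # a LossType enum or a custom onnxblock.Block
    optimizer=onnxblock.optim.AdamW(),         # an OptimType enum or, now, an onnxblock.Block
    loss_input_names=["output"],               # only these names are fed to the loss
    artifact_directory="training_artifacts",
)
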
diff --git a/orttraining/orttraining/python/training/optim/_apex_amp_modifier.py b/orttraining/orttraining/python/training/optim/_apex_amp_modifier.py
index d7bbd249a000e..ff128c4da4259 100644
--- a/orttraining/orttraining/python/training/optim/_apex_amp_modifier.py
+++ b/orttraining/orttraining/python/training/optim/_apex_amp_modifier.py
@@ -15,7 +15,6 @@
 class ApexAMPModifier(FP16OptimizerModifier):
     def __init__(self, optimizer, **kwargs) -> None:
         super().__init__(optimizer)
-        pass
 
     def can_be_modified(self):
         return self.check_requirements(
diff --git a/orttraining/orttraining/python/training/ort_triton/__init__.py b/orttraining/orttraining/python/training/ort_triton/__init__.py
index fbb59d1354ae7..5f2d0c62ffa50 100644
--- a/orttraining/orttraining/python/training/ort_triton/__init__.py
+++ b/orttraining/orttraining/python/training/ort_triton/__init__.py
@@ -9,6 +9,7 @@
 from onnxruntime.capi import _pybind_state as _C
 
 from .kernel import *  # noqa: F403
+from .triton_op_executor import register_triton_kernel  # noqa: F401
 from .triton_op_executor import call_triton_by_name, call_triton_by_onnx, get_config
 
 
diff --git a/orttraining/orttraining/python/training/ort_triton/_ir.py b/orttraining/orttraining/python/training/ort_triton/_ir.py
index a2b8407645c46..a963d30a9e6e7 100644
--- a/orttraining/orttraining/python/training/ort_triton/_ir.py
+++ b/orttraining/orttraining/python/training/ort_triton/_ir.py
@@ -392,5 +392,8 @@ def __init__(
             for ir_node in kernel.sub_nodes:
                 if isinstance(ir_node, DropoutNode):
                     ir_node.global_offset = running_offset
+                    kernel.offset_calc.symbolic_shape_variables.update(
+                        [symbol.name for symbol in running_offset.free_symbols]
+                    )
                     running_offset = running_offset + sympy.prod(ir_node.outputs[0].shape)
                     self.has_dropout = True
diff --git a/orttraining/orttraining/python/training/ort_triton/_lowering.py b/orttraining/orttraining/python/training/ort_triton/_lowering.py
index 5c848d2cecc58..4b580a0cc86de 100644
--- a/orttraining/orttraining/python/training/ort_triton/_lowering.py
+++ b/orttraining/orttraining/python/training/ort_triton/_lowering.py
@@ -312,7 +312,7 @@ def _group_nodes(self):
             for j in range(i + 1, len(groups)):
                 if any(output in group_inputs for output in groups[j].nodes_groups[0].output):
                     group_dependencies[i].add(j)
-                    for k in range(0, i):
+                    for k in range(i):
                         if i in group_dependencies[k]:
                             group_dependencies[k].add(j)
 
diff --git a/orttraining/orttraining/python/training/ort_triton/_utils.py b/orttraining/orttraining/python/training/ort_triton/_utils.py
index 95e6703be8783..877eacc0b775f 100644
--- a/orttraining/orttraining/python/training/ort_triton/_utils.py
+++ b/orttraining/orttraining/python/training/ort_triton/_utils.py
@@ -141,13 +141,14 @@ def get_reduce_info(node: NodeProto, graph: GraphProto, input_rank: int) -> Tupl
 
 
 def next_power_of_2(n: int) -> int:
-    assert n <= 2**32, "32-bit only"
+    """Return the smallest power of 2 greater than or equal to n"""
     n -= 1
     n |= n >> 1
     n |= n >> 2
     n |= n >> 4
     n |= n >> 8
     n |= n >> 16
+    n |= n >> 32
     n += 1
     return n
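
The helper above rounds up to a power of two by smearing the highest set bit downward before adding one; the extra `n |= n >> 32` step extends the trick to 64-bit values. A small self-contained check of the intended behavior (mirroring the helper rather than importing it):

def next_power_of_2(n: int) -> int:
    """Smallest power of 2 >= n, via bit smearing (mirrors the helper above)."""
    n -= 1
    for shift in (1, 2, 4, 8, 16, 32):
        n |= n >> shift
    return n + 1

assert next_power_of_2(1) == 1
assert next_power_of_2(5) == 8
assert next_power_of_2(2**33 + 1) == 2**34
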
 
diff --git a/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py b/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py
index 03bb0f4373d8d..f7b7c1ff08300 100644
--- a/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py
+++ b/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py
@@ -694,7 +694,7 @@ def _bwd_kernel(
     LSE += off_hb * seqlen_q_rounded
     if not SEQUENCE_PARALLEL:
         num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
-        for start_n in range(0, num_block_n):
+        for start_n in range(num_block_n):
             _bwd_kernel_one_col_block(
                 start_n,
                 Q,
diff --git a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py
index f16abc71251ed..e104ea13c59a3 100644
--- a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py
+++ b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py
@@ -23,6 +23,8 @@
 
 _DEBUG_MODE = "ORTMODULE_TRITON_DEBUG" in os.environ and int(os.getenv("ORTMODULE_TRITON_DEBUG")) == 1
 
+_CUSTOM_KERNELS = {}
+
 
 @functools.lru_cache(None)
 def _gen_module_internal(sorted_graph: SortedGraph) -> Tuple[str, str, ModuleType]:
@@ -113,7 +115,10 @@ def call_triton_by_name(func_name: str, *tensors, **kwargs):
     """
 
     torch_tensors = [_from_dlpack(tensor) if tensor is not None else None for tensor in tensors]
-    func = getattr(sys.modules[".".join(__name__.split(".")[:-1])], func_name)
+    func = getattr(sys.modules[".".join(__name__.split(".")[:-1])], func_name, None)
+    if func is None:
+        func = _CUSTOM_KERNELS.get(func_name)
+    assert func is not None, f"Function {func_name} is not found in the registered kernels."
     output = func(*torch_tensors, **kwargs)
     if output is not None:
         if isinstance(output, tuple):
@@ -138,3 +143,8 @@ def call_triton_by_onnx(onnx_key: int, onnx_str: bytes, *tensors):
     if isinstance(output, tuple):
         return tuple([to_dlpack(tensor) for tensor in output])
     return to_dlpack(output)
+
+
+def register_triton_kernel(fn):
+    _CUSTOM_KERNELS[fn.__name__] = fn
+    return fn
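
A hedged sketch of how the new register_triton_kernel hook can be used: decorating a wrapper stores it in _CUSTOM_KERNELS so call_triton_by_name can resolve it by name. The add kernel below is a hypothetical example, not part of this change:

import torch
import triton
import triton.language as tl

from onnxruntime.training.ort_triton import register_triton_kernel


@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


@register_triton_kernel
def my_add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Registered under "my_add_kernel"; call_triton_by_name("my_add_kernel", ...) finds it.
    out = torch.empty_like(x)
    n = x.numel()
    _add_kernel[(triton.cdiv(n, 1024),)](x, y, out, n, BLOCK=1024)
    return out
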
diff --git a/orttraining/orttraining/python/training/ortmodule/__init__.py b/orttraining/orttraining/python/training/ortmodule/__init__.py
index fbf1b7c2bac42..4a03465cf2ead 100644
--- a/orttraining/orttraining/python/training/ortmodule/__init__.py
+++ b/orttraining/orttraining/python/training/ortmodule/__init__.py
@@ -39,7 +39,7 @@ def _defined_from_envvar(name, default_value, warn=True):
 # NOTE: To *change* values in runtime, import onnxruntime.training.ortmodule and
 # assign them new values. Importing them directly do not propagate changes.
 ################################################################################
-ONNX_OPSET_VERSION = 15
+ONNX_OPSET_VERSION = 17
 MINIMUM_RUNTIME_PYTORCH_VERSION_STR = "1.8.1"
 ORTMODULE_TORCH_CPP_DIR = os.path.join(os.path.dirname(__file__), "torch_cpp_extensions")
 _FALLBACK_INIT_EXCEPTION = None
diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
index 9288027f0188c..f81aef5f6b9c4 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
@@ -821,3 +821,27 @@ def upsample_bicubic2d(g, input, output_size, align_corners, scale_factors):
         operator_s="upsample_bicubic2d",
         overload_name_s="vec",
     )
+
+
+@register_symbolic("layer_norm")
+@parse_args("v", "is", "v", "v", "f", "none")
+def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable):
+    # normalized_shape: the trailing dimensions of the input over which normalization is applied.
+    # axis: the first normalization dimension.
+    # layer_norm normalizes over the last D dimensions,
+    # where D is the length of normalized_shape.
+    axis = -len(normalized_shape)
+
+    res, _mean, _inv_std_dev = g.op(
+        "LayerNormalization",
+        input,
+        weight,
+        bias,
+        epsilon_f=eps,
+        axis_i=axis,
+        outputs=3,  # force all 3 outputs to be exported in training mode
+        operator_s="layer_norm",
+        overload_name_s="vec",
+    )
+
+    return res
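
A tiny illustration of the axis computed in the symbolic above: the axis is the negative length of normalized_shape, so normalization always covers the trailing dimensions (shapes here are hypothetical):

# For an input of shape [batch, seq, hidden] normalized over the hidden dimension:
normalized_shape = [768]
axis = -len(normalized_shape)  # -1
# For normalized_shape == [seq, hidden], axis would be -2.
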
diff --git a/orttraining/orttraining/python/training/ortmodule/_fallback_exceptions.py b/orttraining/orttraining/python/training/ortmodule/_fallback_exceptions.py
index 12780016a9ab1..871d3fff8ce3f 100644
--- a/orttraining/orttraining/python/training/ortmodule/_fallback_exceptions.py
+++ b/orttraining/orttraining/python/training/ortmodule/_fallback_exceptions.py
@@ -10,8 +10,6 @@ class ORTModuleFallbackException(Exception):  # noqa: N818
     it can also be used for generic exception that require fallback
     """
 
-    pass
-
 
 class ORTModuleInitException(ORTModuleFallbackException):
     """Trigger fallback for ORTModule initialization related exceptions
@@ -20,8 +18,6 @@ class ORTModuleInitException(ORTModuleFallbackException):
     including PyTorch version, missing ORTModule's PyTorch C++ extension binaries, etc.
     """
 
-    pass
-
 
 class ORTModuleDeviceException(ORTModuleFallbackException):
     """Trigger fallback for device related exceptions
@@ -31,8 +27,6 @@ class ORTModuleDeviceException(ORTModuleFallbackException):
     This exception does not capture these scenarios.
     """
 
-    pass
-
 
 class ORTModuleIOError(ORTModuleFallbackException):
     """Trigger fallback for I/O related exceptions
@@ -42,8 +36,6 @@ class ORTModuleIOError(ORTModuleFallbackException):
     This exception does not capture these scenarios.
     """
 
-    pass
-
 
 class ORTModuleTorchModelException(ORTModuleFallbackException):
     """Trigger fallback for PyTorch modules related exceptions
@@ -52,8 +44,6 @@ class ORTModuleTorchModelException(ORTModuleFallbackException):
     checking type(model) over a hardcoded list of incompatible models.
     """
 
-    pass
-
 
 class ORTModuleONNXModelException(ORTModuleFallbackException):
     """Trigger fallback for ONNX model related exceptions
@@ -61,8 +51,6 @@ class ORTModuleONNXModelException(ORTModuleFallbackException):
     This exception is raised during model conversion to ONNX and post-processing validation within ORTModule frontend.
     """
 
-    pass
-
 
 def wrap_exception(
     new_exception: ORTModuleFallbackException, raised_exception: Exception
diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
index 779b6bfe50422..568c92b71277f 100755
--- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
+++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
@@ -20,7 +20,6 @@
 from onnxruntime.capi import _pybind_state as C
 from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
 from onnxruntime.training.utils import ORTModelInputOutputSchemaType, PTable, onnx_dtype_to_pytorch_dtype
-from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3
 
 from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils
 from ._fallback import (
@@ -143,6 +142,9 @@ def __init__(
 
         self._zero_stage3_param_map = {}
         if self._runtime_options.enable_zero_stage3_support:
+            # Import here to avoid a circular dependency error
+            from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3  # type: ignore[import]
+
             # Cannot toggle feature enabling/disabling after the first time enabled.
 
             configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="ort_output", stats_overwrite=True)
@@ -186,7 +188,6 @@ def forward(self):
         This is an abstract method and must be overridden by a concrete implementation.
         This is the only method that the user should call on a concrete instance of the ExecutionManager
         All other methods are internal"""
-        pass
 
     def _build_graph(self, config):
         if self._runtime_options.use_static_shape:
@@ -410,9 +411,9 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu
                     # From some PyTorch version, autograd_inlining is a valid argument.
                     # We allow it to be True if custom autograd function is disabled (where autograd.Function
                     # anyway is not supported in ONNX until it can be inlined).
-                    required_export_kwargs[
-                        "autograd_inlining"
-                    ] = not self._runtime_options.enable_custom_autograd_function
+                    required_export_kwargs["autograd_inlining"] = (
+                        not self._runtime_options.enable_custom_autograd_function
+                    )
 
                 invalid_args = self._export_extra_kwargs.keys() & required_export_kwargs.keys()
 
@@ -679,11 +680,15 @@ def _enable_conditional_optimizations(
                     )
 
                 if self._runtime_options.enable_embedding_sparse_optimizer and len(embed_sparsity_results) > 0:
-                    graph_transformer_config.sparse_embedding_input_names = list(embed_sparsity_results.keys())
-                    self._logger.info("Embedding sparsity-based optimization is ON for %s", embed_sparsity_results)
-                    self._runtime_options.embed_sparsity_ratio = ",".join(
-                        [f"{k}:{v:.0f}%" for k, v in embed_sparsity_results.items()]
-                    )
+                    if detected_device.type == "cuda":
+                        # Embedding sparsity optimization is only supported on CUDA devices.
+                        graph_transformer_config.sparse_embedding_input_names = list(embed_sparsity_results.keys())
+                        self._logger.info("Embedding sparsity-based optimization is ON for %s", embed_sparsity_results)
+                        self._runtime_options.embed_sparsity_ratio = ",".join(
+                            [f"{k}:{v:.0f}%" for k, v in embed_sparsity_results.items()]
+                        )
+                    else:
+                        self._logger.info("Embedding sparsity-based optimization is not supported on non-CUDA devices.")
 
             # If users don't want to print input density, disable the input density observer to avoid overhead
             # when looping through inputs during training.
@@ -748,6 +753,11 @@ def _add_record(tbl, columns):
 
         if self._runtime_options.memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE:
             opt_config_to_display = "ALL_RECOMPUTE_FOR_EACH_LAYER"
+        elif (
+            self._runtime_options.memory_optimization_level
+            == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE
+        ):
+            opt_config_to_display = "ALL_RECOMPUTE_FOR_EACH_LAYER_WITH_COMPROMISE"
         else:
             opt_config_to_display = self._runtime_options.memory_optimizer_config
 
@@ -760,7 +770,7 @@ def _add_record(tbl, columns):
                     f"Memory Optimization Level: [{_MemoryOptimizationLevel.to_string(self._runtime_options.memory_optimization_level)}], "
                     f"Optimization Config: [{opt_config_to_display}]"
                     if len(self._runtime_options.memory_optimizer_config) > 0
-                    else "Enable with env ORTMODULE_MEMORY_OPT_LEVEL=1 or ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,..."
+                    else "Enable with env ORTMODULE_MEMORY_OPT_LEVEL=1/2 or ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,..."
                 ),
             ],
         )
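
As a hedged configuration example of the level-2 mode surfaced in the log message above (layer-wise recompute with compromise), the environment variables can be set before ORTModule wraps the model; the plan-string placeholders are kept exactly as printed by the inspector:

import os

# 1: transformer layer-wise recompute, 2: layer-wise recompute with compromise.
os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = "2"
# Alternatively, selectively enable plans reported by the memory inspector:
# os.environ["ORTMODULE_MEMORY_OPT_CONFIG"] = "<plan1 config>,<plan2 config>"
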
diff --git a/orttraining/orttraining/python/training/ortmodule/_logger.py b/orttraining/orttraining/python/training/ortmodule/_logger.py
index a01db28374b8d..91b99d4323d6f 100644
--- a/orttraining/orttraining/python/training/ortmodule/_logger.py
+++ b/orttraining/orttraining/python/training/ortmodule/_logger.py
@@ -267,9 +267,11 @@ def wrapper(graph_execution_manager, *args, **kwargs):
                 on_exit=partial(
                     _log_with_filter,
                     graph_execution_manager._logger,
-                    graph_execution_manager._debug_options.onnxruntime_log_filter
-                    if self.is_ort_filter
-                    else graph_execution_manager._debug_options.torch_exporter_filter,
+                    (
+                        graph_execution_manager._debug_options.onnxruntime_log_filter
+                        if self.is_ort_filter
+                        else graph_execution_manager._debug_options.torch_exporter_filter
+                    ),
                     self.phase.to_string(),
                 ),
             ):
diff --git a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py
index 078ce4d27cd6f..d3fe132609a90 100644
--- a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py
+++ b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py
@@ -14,7 +14,7 @@
 from sympy import Symbol, simplify
 from sympy.parsing.sympy_parser import parse_expr
 
-from onnxruntime.training.utils import PTable
+from onnxruntime.training.utils import PTable, log_memory_usage
 
 from ._execution_agent import TrainingAgent
 from .options import _MemoryOptimizationLevel, _RuntimeOptions
@@ -433,9 +433,7 @@ def _print_embed_label_stats(self):
                 total_token,
                 valid_token_per_batch,
             ) in self._stats:
-                stat += "\t| {:<10} | {:<10} | {:<15} | {:<10} | {:<9.2f}% | {:<15} | {:<15} | {:<15} |\n".format(
-                    step, input_type, input_name, padding_idx, density, valid_token, total_token, valid_token_per_batch
-                )
+                stat += f"\t| {step:<10} | {input_type:<10} | {input_name:<15} | {padding_idx:<10} | {density:<9.2f}% | {valid_token:<15} | {total_token:<15} | {valid_token_per_batch:<15} |\n"
             stat += "<<<\n"
             self._logger.info(stat)
             self._stats.clear()
@@ -509,6 +507,8 @@ def __init__(self, m: torch.nn.Module, logger: Logger):
 
         self._is_first_inspect = True
 
+        self._m = m
+
     def is_enabled(self) -> bool:
         """Check if memory inspector is enabled."""
         return self._is_enabled
@@ -543,7 +543,10 @@ def find_memory_optimization_opportunity(self, execution_agent: TrainingAgent, r
 
         # If the memory optimization level is aggressive, we will first collect all
         # recompute subgraph by passing empty memory_optimizer_config to get_serialized_ortmodule_memory_stat.
-        if runtime_options.memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE:
+        if runtime_options.memory_optimization_level in [
+            _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE,
+            _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE,
+        ]:
             memory_optimizer_config = ""
 
         (
@@ -579,16 +582,27 @@ def find_memory_optimization_opportunity(self, execution_agent: TrainingAgent, r
             self.cluster_id_combination_to_saving_symbolics_map[cluster_id] = values
 
         # For aggressive memory optimization, we update the memory_optimizer_config using all.
-        if runtime_options.memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE:
+        if runtime_options.memory_optimization_level > 0:
             recompute_configs = []
             for cluster_id in self.cluster_id_combination_to_saving_symbolics_map:
                 config_values = cluster_id.split(":")
                 opt_type = int(config_values[1])
-                # TODO(pengwa): use enum instead of 1 here.
-                if opt_type != 1:
-                    continue
-
-                recompute_configs.append(cluster_id)
+                if (
+                    runtime_options.memory_optimization_level
+                    == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE
+                    and opt_type == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE
+                ):
+                    recompute_configs.append(cluster_id)
+                elif (
+                    runtime_options.memory_optimization_level
+                    == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE
+                    and opt_type
+                    in [
+                        _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE,
+                        _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE,
+                    ]
+                ):
+                    recompute_configs.append(cluster_id)
 
             runtime_options.memory_optimizer_config = ",".join(recompute_configs)
 
@@ -621,29 +635,13 @@ def inspect_memory(self, cur_phase: Phase):
         need_print = self._current_step < 10 or (self._current_step & (self._current_step - 1) == 0)
 
         if need_print:
-            cur_mem_allocated = self._normalize(torch.cuda.memory_allocated())
-            max_mem_allocated = self._normalize(torch.cuda.max_memory_allocated())
-            cur_mem_cached = self._normalize(torch.cuda.memory_reserved())
-            max_mem_cached = self._normalize(torch.cuda.max_memory_reserved())
-            torch_mem_stat = torch.cuda.memory_stats()
-            cur_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.current", 0))
-            max_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.peak", 0))
-
-            mem_stats = [
-                ["phase", _convert_phase_to_string(cur_phase)],
-                ["allocated", cur_mem_allocated],  # current memory allocated for tensors
-                ["max allocated", max_mem_allocated],  # peak memory allocated for tensors
-                ["cached", cur_mem_cached],  # current memory cached for the caching allocator
-                ["max cached", max_mem_cached],  # peak memory cached for caching allocator.
-                ["inactive", cur_mem_inactive],  # amount of inactive, non-releasable memory
-                ["max inactive", max_mem_inactive],  # peak of inactive, non-releasable memory
-            ]
-
-            summ = f"{self._rank_info} step {self._current_step} memory ({MemoryObserver.NORMALIZER_UNIT})"
-            for stat in mem_stats:
-                summ += f" | {stat[0]}: {stat[1]}"
-
-            self._logger.info(summ)
+            log_memory_usage(
+                _convert_phase_to_string(cur_phase),
+                rank_0_only=True,
+                step_info=f"step {self._current_step}",
+                logger=self._logger,
+                module=self._m,
+            )
 
         if cur_phase == self._last_phase:
             self._increase_step()
@@ -655,9 +653,6 @@ def inspect_memory(self, cur_phase: Phase):
     def _increase_step(self):
         self._current_step += 1
 
-    def _normalize(self, mem_size_in_bytes: Union[float, int]) -> str:
-        return f"{float(mem_size_in_bytes) / MemoryObserver.NORMALIZER_FACTOR:.0f}"
-
     def display_memory_optimization_plans(self, memory_optimizer_config, details=False) -> Tuple[List[str], PTable]:
         mem_plan_count = len(self.cluster_id_combination_to_saving_symbolics_map)
 
@@ -700,9 +695,11 @@ def _get_user_config_without_freq(configs: str):
                     [
                         f" - Plan {index}",
                         ":",
-                        "ON"
-                        if all(cluster_id in user_configs_with_out_freq for cluster_id in cluster_ids_without_freq)
-                        else "OFF",
+                        (
+                            "ON"
+                            if all(cluster_id in user_configs_with_out_freq for cluster_id in cluster_ids_without_freq)
+                            else "OFF"
+                        ),
                         ":",
                         cluster_id,
                         saving_symbolic.freq if details else "",
@@ -716,14 +713,16 @@ def _get_user_config_without_freq(configs: str):
             notes = []
             if details:
                 notes.append(
-                    "[Memory Optimizer] Use ORTMODULE_MEMORY_OPT_LEVEL=1 to enable all recomputable subgraphs per transformer layer."
+                    "[Memory Optimizer] Use ORTMODULE_MEMORY_OPT_LEVEL=1/2 to enable all recomputable subgraphs per transformer layer."
                 )
                 saving_recommendation = "[Memory Optimizer] Or use comma as a delimiter to selectively enable multiple memory optimization plans:\n"
                 saving_recommendation += "  export ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,..."
 
                 notes.append(saving_recommendation)
 
-                saving_recommendation = "memory saving is calculated based on the 1st batch symbolic dim values:\n"
+                saving_recommendation = (
+                    "[Memory Optimizer] memory saving is calculated based on the 1st batch symbolic dim values:\n"
+                )
                 for dim_param, dim_value in self.symbolic_dim_name_to_value_map.items():
                     saving_recommendation += f"  {dim_param}={dim_value},"
                 notes.append(saving_recommendation)
diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py
index cc533e549db92..5fa332d12f01c 100644
--- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py
+++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py
@@ -171,10 +171,10 @@ def backward(ctx, *grad_outputs):
                 for idx, grad_output in enumerate(grad_outputs):
                     if idx in self._graph_info.output_grad_indices_non_differentiable:
                         assert grad_output is None, (
-                            "ORT found the {}-th module output '{}' is "
+                            f"ORT found the {idx}-th module output '{self._graph_info.user_output_names[idx]}' is "
                             "non-differentiable according to the onnx graph. "
                             "However, the gradient value is still provided by "
-                            "PyTorch's autograd engine.".format(idx, self._graph_info.user_output_names[idx])
+                            "PyTorch's autograd engine."
                         )
                         continue
 
@@ -196,18 +196,20 @@ def backward(ctx, *grad_outputs):
 
                 # Run and get results
                 backward_outputs = C.OrtValueVector()
-                self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state)
-                # Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not
-                # affect peak memory usage in a subsequent graph run.
-                del ctx.run_info.state
-
-                # Fast version: all backward_outputs are converted first.
-                # This version only works if backward_outputs is an OrtValueVector.
-                transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device)
-
-                self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_BACKWARD)
-
-                return tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map)
+                try:
+                    self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state)
+                    # The state is destroyed in the finally block below (rather than being left to the
+                    # garbage collector) so it does not affect peak memory usage in a subsequent graph run.
+
+                    # Fast version: all backward_outputs are converted first.
+                    # This version only works if backward_outputs is an OrtValueVector.
+                    transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device)
+
+                    self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_BACKWARD)
+                    res = tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map)
+                    return res
+                finally:
+                    del ctx.run_info.state
 
         return _ORTModuleFunction
 
diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py
index 91825fc492208..5faa1c62bae4f 100644
--- a/orttraining/orttraining/python/training/ortmodule/_utils.py
+++ b/orttraining/orttraining/python/training/ortmodule/_utils.py
@@ -91,7 +91,7 @@ def _ortvalues_to_torch_tensor(
         # Second option makes it impossible to directly use `_from_dlpack` or
         # or `from_dlpack` from torch.
         # The best option would be to add boolean type in DLDataTypeCode.
-        for i in range(0, len(bool_indices)):
+        for i in range(len(bool_indices)):
             j = bool_indices[i]
             res[j] = res[j].to(torch.bool)
 
diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py
index 3d3538a62da61..368d1b238fd9e 100644
--- a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py
+++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py
@@ -13,7 +13,7 @@
 if (
     "ORTMODULE_USE_EFFICIENT_ATTENTION" in os.environ
     and int(os.getenv("ORTMODULE_USE_EFFICIENT_ATTENTION")) == 1
-    and Version(torch.__version__) >= Version("2.1.1")
+    and Version(torch.__version__) >= Version("2.3.0")
 ):
     from ._aten_attn import optimize_graph_for_aten_efficient_attention  # noqa: F401
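
For completeness, a hedged note on the gating above: the ATen efficient-attention optimizer is only imported when the environment variable is set and the installed PyTorch is at least 2.3.0, for example:

import os

# Must be set before onnxruntime.training.ortmodule.graph_optimizers is imported,
# and requires torch >= 2.3.0 per the version check above.
os.environ["ORTMODULE_USE_EFFICIENT_ATTENTION"] = "1"
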
 
diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py
index b1e8809f03fc0..c1fb6e68568f5 100644
--- a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py
+++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py
@@ -5,9 +5,12 @@
 
 """
 PyTorch's _efficient_attention_forward/_efficient_attention_backward APIs is keep changing. Current implementation
-is tested well on version 2.2.0.dev20231010+cu121, and should be run well since official version 2.2.0. If may fail to
+is well tested on version 2.3.0.dev20240221+cu118 and should run well since the official 2.3.0 release. It may fail to
 run if you are using an older PyTorch version.
 
+This file serves mainly as an example of how to add a new graph optimizer. Ideally, users add graph optimizers tailored
+to their own specific model instead of putting every possible graph optimizer here.
+
 PyTorch also has API for flash attention (currently doesn't support random attention mask or Dropout), we can add
 support if we want to try in the future.
 """
@@ -40,13 +43,14 @@ def _make_efficient_attention_nodes(
     scale_node = make_constant_node("scale_" + str(idx), TensorProto.FLOAT, [], [scale])
     dropout_ratio_node = make_constant_node("dropout_ratio_" + str(idx), TensorProto.FLOAT, [], [dropout_ratio])
     causal_node = make_constant_node("causal_" + str(idx), TensorProto.INT64, [], [1 if causal else 0])
-    int_zero_node = make_constant_node("int_zero_" + str(idx), TensorProto.INT64, [], [0])
-    true_node = make_constant_node("true_" + str(idx), TensorProto.BOOL, [], [True])
-    false_node = make_constant_node("false_" + str(idx), TensorProto.BOOL, [], [False])
+    one_node = make_constant_node("one_" + str(idx), TensorProto.INT64, [], [1])
+    zero_node = make_constant_node("zero_" + str(idx), TensorProto.INT64, [], [0])
     logsumexp = helper.make_tensor_value_info("logsumexp" + str(idx), TensorProto.FLOAT, [])
     seed = helper.make_tensor_value_info("seed" + str(idx), TensorProto.INT64, [])
     offset = helper.make_tensor_value_info("offset" + str(idx), TensorProto.INT64, [])
-    new_value_infos = [logsumexp, seed, offset]
+    msb_q = helper.make_tensor_value_info("msb_q_" + str(idx), TensorProto.INT64, [])
+    msb_k = helper.make_tensor_value_info("msb_k_" + str(idx), TensorProto.INT64, [])
+    new_value_infos = [logsumexp, seed, offset, msb_q, msb_k]
     if expand_bias:
         shape_0 = helper.make_node("Shape", [q], ["shape_0_" + str(idx)], start=0, end=1)
         shape_1 = helper.make_node("Shape", [q], ["shape_1_" + str(idx)], start=2, end=3)
@@ -54,13 +58,13 @@ def _make_efficient_attention_nodes(
         shape_3 = helper.make_node("Shape", [k], ["shape_3_" + str(idx)], start=1, end=2)
         concat = helper.make_node(
             "Concat",
-            ["shape_0_" + str(idx), "shape_1_" + str(idx), "shape_2_" + str(idx), "shape_3_" + str(idx)],
+            [shape_0.output[0], shape_1.output[0], shape_2.output[0], shape_3.output[0]],
             ["concated_shape_" + str(idx)],
             axis=0,
         )
-        expand = helper.make_node("Expand", [bias, "concated_shape_" + str(idx)], ["expanded_bias_" + str(idx)])
+        expand = helper.make_node("Expand", [bias, concat.output[0]], ["expanded_bias_" + str(idx)])
         nodes_to_add.extend([shape_0, shape_1, shape_2, shape_3, concat, expand])
-        bias = "expanded_bias_" + str(idx)
+        bias = expand.output[0]
     fwd_node = helper.make_node(
         "ATen",
         [
@@ -71,18 +75,21 @@ def _make_efficient_attention_nodes(
             "",
             "",
             "",
+            "",
             dropout_ratio_node.output[0],
             causal_node.output[0],
-            true_node.output[0],
+            one_node.output[0],
             scale_node.output[0],
             "",
             "",
         ],
-        [y, logsumexp.name, seed.name, offset.name],
+        [y, logsumexp.name, seed.name, offset.name, msb_q.name, msb_k.name],
         "efficient_attention_forward_" + str(idx),
         None,
         "org.pytorch.aten",
         operator="_efficient_attention_forward",
+        cpu_input_args=[4, 5, 12, 13],
+        cpu_output_args=[2, 3, 4, 5],
     )
     bwd_node = helper.make_node(
         "ATen",
@@ -95,14 +102,14 @@ def _make_efficient_attention_nodes(
             y,
             "",
             "",
-            int_zero_node.output[0],
-            int_zero_node.output[0],
+            msb_q.name,
+            msb_k.name,
             logsumexp.name,
             dropout_ratio_node.output[0],
             seed.name,
             offset.name,
             causal_node.output[0],
-            false_node.output[0],
+            zero_node.output[0],
             scale_node.output[0],
             "",
         ],
@@ -111,10 +118,9 @@ def _make_efficient_attention_nodes(
         None,
         "org.pytorch.aten",
         operator="_efficient_attention_backward",
+        cpu_input_args=[6, 7, 12, 13],
     )
-    nodes_to_add.extend(
-        [scale_node, dropout_ratio_node, causal_node, int_zero_node, true_node, false_node, fwd_node, bwd_node]
-    )
+    nodes_to_add.extend([scale_node, dropout_ratio_node, causal_node, one_node, zero_node, fwd_node, bwd_node])
     return nodes_to_add, new_value_infos
 
 
@@ -240,140 +246,9 @@ def _optimize_for_pattern_1(matcher: GraphMatcher, idx: int, nodes: List[NodePro
     return nodes, nodes_to_add, new_value_infos
 
 
-# No causal mask, no attention mask, without Dropout.
-_PATTERN_2: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [
-    ("MatMul", False, []),  # 0
-    ("Mul", True, [(0, 0, 0)]),  # 1
-    ("Mul", True, [(0, 0, 1)]),  # 2
-    ("Transpose", True, [(1, 0, 0)]),  # 3
-    ("Transpose", True, [(2, 0, 0)]),  # 4
-    ("Softmax", False, [(0, 0, 0)]),  # 5
-    ("MatMul", False, [(5, 0, 0)]),  # 6
-    ("Transpose", True, [(6, 0, 1)]),  # 7
-    ("Transpose", False, [(6, 0, 0)]),  # 8
-    ("FusedMatMul", False, [(7, 0, 1)]),  # 9
-    ("SoftmaxGrad_13", False, [(9, 0, 0), (5, 0, 1)]),  # 10
-    ("FusedMatMul", False, [(2, 0, 1), (10, 0, 0)]),  # 11
-    ("FusedMatMul", False, [(1, 0, 0), (10, 0, 1)]),  # 12
-    ("Mul", False, [(11, 0, 0)]),  # 13
-    ("Mul", False, [(12, 0, 0)]),  # 14
-    ("Identity", False, [(13, 0, 0)]),  # 15
-    ("Identity", False, [(14, 0, 0)]),  # 16
-    ("Transpose", False, [(15, 0, 0)]),  # 17
-    ("Transpose", False, [(16, 0, 0)]),  # 18
-    ("FusedMatMul", False, [(5, 0, 0)]),  # 19
-    ("Transpose", True, [(19, 0, 1)]),  # 20
-    ("Transpose", False, [(19, 0, 0)]),  # 21
-]
-
-
-def _optimize_for_pattern_2(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]):
-    # Check forward only as the backward is expected to be consistent if it's built correctly.
-    scale_value_1 = matcher.get_constant_value(nodes[1].input[1])
-    scale_value_1 = scale_value_1[0] if isinstance(scale_value_1, list) else scale_value_1
-    scale_value_2 = matcher.get_constant_value(nodes[2].input[1])
-    scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2
-    if not (
-        check_attribute_value(nodes[3], "perm", [0, 2, 1, 3])
-        and check_attribute_value(nodes[4], "perm", [0, 2, 3, 1])
-        and check_attribute_value(nodes[7], "perm", [0, 2, 1, 3])
-        and check_attribute_value(nodes[8], "perm", [0, 2, 1, 3])
-        and scale_value_1 == scale_value_2
-    ):
-        return [], [], []
-
-    nodes_to_add, new_value_infos = _make_efficient_attention_nodes(
-        idx,
-        nodes[3].input[0],
-        nodes[4].input[0],
-        nodes[7].input[0],
-        nodes[8].output[0],
-        nodes[20].input[0],
-        nodes[17].output[0],
-        nodes[18].output[0],
-        nodes[21].output[0],
-        "",
-        False,
-        scale_value_1,
-        0.0,
-        False,
-    )
-    return nodes, nodes_to_add, new_value_infos
-
-
-# Has causal mask, no attention mask, without Dropout.
-_PATTERN_3: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [
-    ("MatMul", False, []),  # 0
-    ("Mul", True, [(0, 0, 0)]),  # 1
-    ("Mul", True, [(0, 0, 1)]),  # 2
-    ("Transpose", True, [(1, 0, 0)]),  # 3
-    ("Transpose", True, [(2, 0, 0)]),  # 4
-    ("Add", False, [(0, 0, 0)]),  # 5
-    ("Slice", True, [(5, 0, 1)]),  # 6
-    ("Slice", True, [(6, 0, 0)]),  # 7
-    ("Unsqueeze", True, [(6, 0, 2)]),  # 8
-    ("Gather", True, [(8, 0, 0)]),  # 9
-    ("Shape", True, [(9, 0, 0)]),  # 10
-    ("Softmax", False, [(5, 0, 0)]),  # 11
-    ("MatMul", False, [(11, 0, 0)]),  # 12
-    ("Transpose", True, [(12, 0, 1)]),  # 13
-    ("Transpose", False, [(12, 0, 0)]),  # 14
-    ("FusedMatMul", False, [(13, 0, 1)]),  # 15
-    ("SoftmaxGrad_13", False, [(15, 0, 0), (11, 0, 1)]),  # 16
-    ("Identity", False, [(16, 0, 0)]),  # 17
-    ("FusedMatMul", False, [(2, 0, 1), (17, 0, 0)]),  # 18
-    ("FusedMatMul", False, [(1, 0, 0), (17, 0, 1)]),  # 19
-    ("Mul", False, [(18, 0, 0)]),  # 20
-    ("Mul", False, [(19, 0, 0)]),  # 21
-    ("Identity", False, [(20, 0, 0)]),  # 22
-    ("Identity", False, [(21, 0, 0)]),  # 23
-    ("Transpose", False, [(22, 0, 0)]),  # 24
-    ("Transpose", False, [(23, 0, 0)]),  # 25
-    ("FusedMatMul", False, [(11, 0, 0)]),  # 26
-    ("Transpose", True, [(26, 0, 1)]),  # 27
-    ("Transpose", False, [(26, 0, 0)]),  # 28
-]
-
-
-def _optimize_for_pattern_3(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]):
-    # Check forward only as the backward is expected to be consistent if it's built correctly.
-    scale_value_1 = matcher.get_constant_value(nodes[1].input[1])
-    scale_value_1 = scale_value_1[0] if isinstance(scale_value_1, list) else scale_value_1
-    scale_value_2 = matcher.get_constant_value(nodes[2].input[1])
-    scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2
-    if not (
-        check_attribute_value(nodes[3], "perm", [0, 2, 1, 3])
-        and check_attribute_value(nodes[4], "perm", [0, 2, 3, 1])
-        and check_attribute_value(nodes[13], "perm", [0, 2, 1, 3])
-        and check_attribute_value(nodes[14], "perm", [0, 2, 1, 3])
-        and scale_value_1 == scale_value_2
-    ):
-        return [], [], []
-
-    nodes_to_add, new_value_infos = _make_efficient_attention_nodes(
-        idx,
-        nodes[3].input[0],
-        nodes[4].input[0],
-        nodes[13].input[0],
-        nodes[14].output[0],
-        nodes[27].input[0],
-        nodes[24].output[0],
-        nodes[25].output[0],
-        nodes[28].output[0],
-        "",
-        False,
-        scale_value_1,
-        0.0,
-        True,
-    )
-    return nodes, nodes_to_add, new_value_infos
-
-
 _PATTERNS = [
     (_PATTERN_0, _optimize_for_pattern_0),
     (_PATTERN_1, _optimize_for_pattern_1),
-    (_PATTERN_2, _optimize_for_pattern_2),
-    (_PATTERN_3, _optimize_for_pattern_3),
 ]
 
 
diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py
index 539859a0d58a6..1bde07dc29ba9 100644
--- a/orttraining/orttraining/python/training/ortmodule/options.py
+++ b/orttraining/orttraining/python/training/ortmodule/options.py
@@ -196,7 +196,10 @@ class _MemoryOptimizationLevel(IntFlag):
     """Enumeration to specify memory optimization level"""
 
     USER_SPECIFIED = 0  # Fully respect user-specified config
-    TRANSFORMER_LAYERWISE_RECOMPUTE = 1  # Enable all recomputable subgraphs per layer
+    TRANSFORMER_LAYERWISE_RECOMPUTE = (
+        1  # Enable all recomputable subgraphs (excluding compromised recomputable subgraphs) per layer
+    )
+    TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE = 2  # Enable all recomputable subgraphs (including compromised recomputable subgraphs) per layer
 
     @staticmethod
     def to_string(memory_optimization_level):
@@ -206,6 +209,9 @@ def to_string(memory_optimization_level):
         if memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE:
             return "TRANSFORMER_LAYERWISE_RECOMPUTE"
 
+        if memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE:
+            return "TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE"
+
         return ""
 
 
@@ -271,7 +277,7 @@ def __init__(self, logger: Logger):
         self.enable_sparse_optimizer = True
         self.label_sparsity_ratio = ""
         self.embed_sparsity_ratio = ""
-        self.enable_embedding_sparse_optimizer = False  # TODO(pengwa): remove once validation on more models are done.
+        self.enable_embedding_sparse_optimizer = True
 
         # Configuration for memory optimization.
         self.memory_optimization_level = (
@@ -344,7 +350,10 @@ def _override_from_env_vars(self):
         self.memory_optimization_level = int(os.getenv("ORTMODULE_MEMORY_OPT_LEVEL", self.memory_optimization_level))
         user_given_memory_optimizer_config = os.getenv("ORTMODULE_MEMORY_OPT_CONFIG", self.memory_optimizer_config)
         self.memory_optimizer_config = ",".join([c for c in user_given_memory_optimizer_config.split(",") if c])
-        if self.memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE:
+        if self.memory_optimization_level in [
+            _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE,
+            _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE,
+        ]:
             # For transformer layer-wise recompute, we enable layer boundary when detecting subgraphs.
             # Then all detected subgraphs will not cross different layers.
             self.recompute_probe_config = "1:1"
@@ -378,7 +387,6 @@ def _override_from_env_vars(self):
             try:
                 import triton  # noqa: F401
             except ImportError:
-                pass
                 self._logger.warning(
                     "triton library missing. Please install triton with `pip install triton`. Triton feature will be off."
                 )
@@ -419,7 +427,10 @@ def memory_optimizer_is_enabled(self) -> bool:
         """Check whether memory optimizer is enabled."""
         if self.memory_optimization_level == _MemoryOptimizationLevel.USER_SPECIFIED:
             return len(self.memory_optimizer_config) > 0
-        elif self.memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE:
+        elif self.memory_optimization_level in [
+            _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE,
+            _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE,
+        ]:
             return True
 
         return False
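+
+
+# A minimal configuration sketch, assuming ORTMODULE_MEMORY_OPT_LEVEL is read as an integer
+# mapping onto _MemoryOptimizationLevel (see _override_from_env_vars above); `pt_model` is a
+# hypothetical placeholder:
+#
+#   import os
+#   from onnxruntime.training.ortmodule import ORTModule
+#
+#   os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = "2"  # TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE
+#   model = ORTModule(pt_model)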
diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.cc b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.cc
index 88e93b26e0e22..d511743c4b698 100644
--- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.cc
+++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.cc
@@ -60,9 +60,10 @@ std::vector<PyObject*> custom_function_backward_runner(const char* func_name_cha
         tensor = torch::utils::tensor_fromDLPack(args[arg_index]);
       } else {
         TORCH_CHECK(args[arg_index] == Py_None, "Only None is supported for non-tensor input.");
-        PyObject* fw_kernel_invoke_id = PyObject_GetAttrString(ctx.ptr(), "fw_kernel_invoke_id");
+        py::object fw_kernel_invoke_id = PyObject_FastGetAttrString(ctx.ptr(), "fw_kernel_invoke_id");
+        TORCH_CHECK(fw_kernel_invoke_id.ptr() != nullptr, "fw_kernel_invoke_id is not found in the context.");
         std::string fw_kernel_invoke_id_str =
-            py::cast<std::string>(py::reinterpret_borrow<py::object>(fw_kernel_invoke_id));
+            py::cast<std::string>(fw_kernel_invoke_id);
         CustomFuncOpKernelInfo& fw_kernel_info =
             KernelInfoStore::GetInstance().GetKernelInfoMap().at(fw_kernel_invoke_id_str);
         if (fw_kernel_info.materialize_grads) {
diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.cc b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.cc
index 599bdf813907b..3bb5151265eff 100644
--- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.cc
+++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.cc
@@ -255,7 +255,7 @@ static py::object get_mockup_context_class() {
       throw std::runtime_error("Fails to import the module.");
     }
 
-    auto python_class = py::reinterpret_steal<py::object>(PyObject_GetAttrString(module.ptr(), "FakeContext"));
+    auto python_class = PyObject_FastGetAttrString(module.ptr(), "FakeContext");
     if (!PyCallable_Check(python_class.ptr())) {
       throw std::runtime_error("Cannot instantiate the Python class");
     }
diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py
index fa72f3b134917..898c242bb3c32 100644
--- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py
+++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py
@@ -23,7 +23,7 @@
     cur_file_dir,
 ]
 
-extra_compile_args = {"cxx": ["-O3"]}
+extra_compile_args = {"cxx": ["-O3", "-std=c++17"]}
 setup(
     name="torch_interop_utils",
     ext_modules=[
diff --git a/orttraining/orttraining/python/training/utils/__init__.py b/orttraining/orttraining/python/training/utils/__init__.py
index b4a518d573998..ecfb7d7907f3c 100644
--- a/orttraining/orttraining/python/training/utils/__init__.py
+++ b/orttraining/orttraining/python/training/utils/__init__.py
@@ -12,6 +12,7 @@
     unflatten_data_using_schema,
 )
 from onnxruntime.training.utils.torch_profile_utils import (
+    log_memory_usage,
     nvtx_function_decorator,
     torch_nvtx_range_pop,
     torch_nvtx_range_push,
@@ -31,6 +32,7 @@
     "torch_nvtx_range_push",
     "torch_nvtx_range_pop",
     "nvtx_function_decorator",
+    "log_memory_usage",
     "pytorch_type_to_onnx_dtype",
     "onnx_dtype_to_pytorch_dtype",
     "pytorch_scalar_type_to_pytorch_dtype",
diff --git a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py
index 68b78f8df70f1..a8e730488d76d 100644
--- a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py
+++ b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py
@@ -14,6 +14,7 @@
 import torch
 
 from ._subscriber_base import RuntimeStates, SubscriberBase
+from ._subscriber_manager import ORT_NO_INCREASE_GLOBAL_STEP
 
 
 class _InspectActivation(torch.autograd.Function):
@@ -176,21 +177,23 @@ def _summarize_activations(self, tensor: torch.Tensor, depth: int, name: str, st
         display_name = name + " forward run" if is_forward is True else name + " backward run"
         output_file_name = name + "_forward" if is_forward is True else name + "_backward"
 
-        if tensor is None or not isinstance(tensor, torch.Tensor):
-            print(f"{display_name} not a torch tensor, value: {tensor}")
-            return
+        # Skip dumping during the model's pre-export output schema preparation run and the export run.
+        if ORT_NO_INCREASE_GLOBAL_STEP[0] is False:
+            if tensor is None or not isinstance(tensor, torch.Tensor):
+                print(f"{display_name} not a torch tensor, value: {tensor}")
+                return
 
-        step_path = Path(step_folder)
-        if not step_path.exists():
-            step_path.mkdir(parents=True, exist_ok=False)
-        order_file_path = step_path / "order.txt"
-        tensor_file_path = step_path / output_file_name
+            step_path = Path(step_folder)
+            if not step_path.exists():
+                step_path.mkdir(parents=True, exist_ok=False)
+            order_file_path = step_path / "order.txt"
+            tensor_file_path = step_path / output_file_name
 
-        with order_file_path.open(mode="a", encoding="utf-8") as f:
-            f.write(f"{output_file_name}\n")
+            with order_file_path.open(mode="a", encoding="utf-8") as f:
+                f.write(f"{output_file_name}\n")
 
-        with tensor_file_path.open(mode="w", encoding="utf-8") as f:
-            _summarize_tensor(display_name, tensor, f, depth, self._run_on_cpu, self._bucket_size)
+            with tensor_file_path.open(mode="w", encoding="utf-8") as f:
+                _summarize_tensor(display_name, tensor, f, depth, self._run_on_cpu, self._bucket_size)
 
 
 def _summarize_tensor(
diff --git a/orttraining/orttraining/python/training/utils/torch_profile_utils.py b/orttraining/orttraining/python/training/utils/torch_profile_utils.py
index 382d7dac142fe..9e8a41e0dc7c8 100644
--- a/orttraining/orttraining/python/training/utils/torch_profile_utils.py
+++ b/orttraining/orttraining/python/training/utils/torch_profile_utils.py
@@ -3,6 +3,8 @@
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
 
+from __future__ import annotations
+
 import torch
 
 
@@ -26,3 +28,77 @@ def wrapped_fn(*args, **kwargs):
         return ret_val
 
     return wrapped_fn
+
+
+def log_memory_usage(cur_phase: str, rank_0_only=True, step_info="", logger=None, module=None):
+    """Log memory usage for the current phase.
+    Args:
+        cur_phase (str): The current phase.
+        rank_0_only (bool, optional): Only log the memory usage for rank 0. Defaults to True.
+        step_info (str, optional): The step information. Defaults to "".
+        logger (logging.Logger, optional): The logger to log the memory usage. Defaults to None, which means print to stdout.
+        module (torch.nn.Module, optional): The module to get parameter, buffer and grad sizes. Defaults to None.
+    """
+    rank = 0
+    if rank_0_only is True:
+        if torch.distributed.is_initialized():
+            rank = torch.distributed.get_rank()
+        if rank != 0:
+            return
+
+    _normalizer_factor = float(1024 * 1024)
+    _normalizer_unit = "MiB"
+
+    def _normalize(mem_size_in_bytes: float | int) -> str:
+        return f"{float(mem_size_in_bytes) / _normalizer_factor:.0f}"
+
+    cur_mem_allocated = _normalize(torch.cuda.memory_allocated())
+    max_mem_allocated = _normalize(torch.cuda.max_memory_allocated())
+    cur_mem_cached = _normalize(torch.cuda.memory_reserved())
+    max_mem_cached = _normalize(torch.cuda.max_memory_reserved())
+    torch_mem_stat = torch.cuda.memory_stats()
+    cur_mem_inactive = _normalize(torch_mem_stat.get("inactive_split_bytes.all.current", 0))
+    max_mem_inactive = _normalize(torch_mem_stat.get("inactive_split_bytes.all.peak", 0))
+
+    mem_stats = [
+        ["phase", cur_phase],
+        ["allocated", cur_mem_allocated],  # current memory allocated for tensors
+        ["max allocated", max_mem_allocated],  # peak memory allocated for tensors
+        ["cached", cur_mem_cached],  # current memory cached for the caching allocator
+        ["max cached", max_mem_cached],  # peak memory cached for the caching allocator
+        ["inactive", cur_mem_inactive],  # amount of inactive, non-releasable memory
+        ["max inactive", max_mem_inactive],  # peak of inactive, non-releasable memory
+    ]
+
+    # Calculate the total size of parameters and gradients in the model
+    if module:
+        param_total_size = 0
+        grad_total_size = 0
+        for p in module.parameters():
+            if p.is_cuda:
+                param_total_size += p.numel() * p.element_size()
+            if p.grad is not None and p.grad.is_cuda:
+                grad_total_size += p.grad.numel() * p.grad.element_size()
+
+        # Calculate the total size of buffers in the model
+        buffer_total_size = 0
+        for b in module.buffers():
+            if b.is_cuda:
+                buffer_total_size += b.numel() * b.element_size()
+
+        mem_stats.extend(
+            [
+                ["param", _normalize(param_total_size)],
+                ["grad", _normalize(grad_total_size)],
+                ["buffer", _normalize(buffer_total_size)],
+            ]
+        )
+
+    summ = f"rank-{rank} {step_info} memory ({_normalizer_unit})"
+    for stat in mem_stats:
+        summ += f" | {stat[0]}: {stat[1]}"
+
+    if logger is None:
+        print(summ)
+    else:
+        logger.info(summ)
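+
+
+# A minimal usage sketch, assuming a CUDA device and (optionally) an initialized
+# torch.distributed process group; `model` and `batch` below are hypothetical placeholders:
+#
+#   from onnxruntime.training.utils import log_memory_usage
+#
+#   log_memory_usage("before forward", step_info="step-0", module=model)
+#   loss = model(batch)
+#   loss.backward()
+#   log_memory_usage("after backward", step_info="step-0", module=model)
+#
+# Each call emits one line of the form:
+#   rank-0 step-0 memory (MiB) | phase: after backward | allocated: ... | max allocated: ... | param: ... | grad: ...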
diff --git a/orttraining/orttraining/test/external_custom_ops/setup.py b/orttraining/orttraining/test/external_custom_ops/setup.py
index 435b83b818380..29383e3618346 100644
--- a/orttraining/orttraining/test/external_custom_ops/setup.py
+++ b/orttraining/orttraining/test/external_custom_ops/setup.py
@@ -28,9 +28,7 @@ def build_extension(self, ext):
         subprocess.check_call(
             [
                 "cmake",
-                "-DPYBIND11_PYTHON_VERSION={}.{}.{}".format(
-                    sys.version_info.major, sys.version_info.minor, sys.version_info.micro
-                ),
+                f"-DPYBIND11_PYTHON_VERSION={sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
                 f"-Dpybind11_DIR={pybind11.get_cmake_dir()}",
                 f"-DONNX_INCLUDE={os.path.dirname(os.path.dirname(onnx.__file__))}",
                 "-DONNXRUNTIME_EXTERNAL_INCLUDE={}".format(
diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc
index feca94ae27c13..94ca96c68f2ce 100644
--- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc
+++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc
@@ -607,6 +607,10 @@ TEST(GradientCheckerTest, ReduceMeanGrad) {
 
   OpDef op_def_opset13{"ReduceMean", kOnnxDomain, 13};
   RunReductionTests(op_def_opset13);
+
+  // axes is an input since opset 18.
+  OpDef op_def_opset18{"ReduceMean", kOnnxDomain, 18};
+  RunReductionTests(op_def_opset18, true, true);
 }
 
 TEST(GradientCheckerTest, ReduceSumGrad) {
@@ -619,6 +623,10 @@ TEST(GradientCheckerTest, ReduceSumGrad) {
   OpDef op_def_13{"ReduceSum", kOnnxDomain, 13};
 
   RunReductionTests(op_def_13, true, true);
+
+  OpDef op_def_18{"ReduceSum", kOnnxDomain, 18};
+
+  RunReductionTests(op_def_18, true, true);
 }
 
 TEST(GradientCheckerTest, ReduceL2Grad) {
@@ -641,6 +649,11 @@ TEST(GradientCheckerTest, ReduceL2Grad) {
                                                            {MakeAttribute("axes", axes)}));
     EXPECT_IS_TINY(max_error);
   }
+
+  // axes is an input since opset 18.
+  OpDef op_def_18{"ReduceL2", kOnnxDomain, 18};
+
+  RunReductionTests(op_def_18, true, true);
 }
 
 TEST(GradientCheckerTest, ReduceLogSumExpGrad) {
@@ -648,6 +661,10 @@ TEST(GradientCheckerTest, ReduceLogSumExpGrad) {
   OpDef op_def{"ReduceLogSumExp", kOnnxDomain, 11};
 
   RunReductionTests(op_def);
+
+  OpDef op_def_opset18{"ReduceLogSumExp", kOnnxDomain, 18};
+
+  RunReductionTests(op_def_opset18, true, true);
 }
 
 TEST(GradientCheckerTest, ReluGrad) {
@@ -698,6 +715,13 @@ TEST(GradientCheckerTest, SplitGrad) {
   ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def_13, {shape}, {{3, 5}, {3, 5}, {3, 5}}, &max_error,
                                                          {MakeAttribute("axis", int64_t(0))}));
   EXPECT_IS_TINY(max_error);
+
+  // opset18 test
+  OpDef op_def_18{"Split", kOnnxDomain, 18};
+  ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def_18, {shape}, {{3, 5}, {3, 5}, {3, 5}}, &max_error,
+                                                         {MakeAttribute("axis", int64_t(0)),
+                                                          MakeAttribute("num_outputs", int64_t(3))}));
+  EXPECT_IS_TINY(max_error);
 }
 
 template <typename T>
@@ -2733,7 +2757,7 @@ TEST(GradientCheckerTest, TileGrad) {
 TEST(GradientCheckerTest, PadGrad) {
   float max_error;
   GradientChecker<float, float, float> gradient_checker;
-  OpDef op_def{"Pad", kOnnxDomain, 11};
+  OpDef op_def{"Pad", kOnnxDomain, 18};
 
   {
     TensorInfo x_info({2, 4}, true);
@@ -2803,7 +2827,7 @@ TEST(GradientCheckerTest, PadGrad) {
 TEST(GradientCheckerTest, ScatterNDGrad) {
   float max_error;
   GradientChecker<float, float, float> gradient_checker;
-  OpDef op_def{"ScatterND", kOnnxDomain, 11};
+  OpDef op_def{"ScatterND", kOnnxDomain, 18};
 
   {
     TensorInfo data_info({8}, true);
@@ -2887,7 +2911,7 @@ TEST(GradientCheckerTest, ScatterNDGrad) {
 TEST(GradientCheckerTest, ScatterElementsGrad) {
   float max_error;
   GradientChecker<float, float, float> gradient_checker;
-  OpDef op_def{"ScatterElements", kOnnxDomain, 13};
+  OpDef op_def{"ScatterElements", kOnnxDomain, 18};
 
   {  // without axis
     TensorInfo data_info({3, 3}, true);
diff --git a/orttraining/orttraining/test/gradient/optimizer_ops_test.cc b/orttraining/orttraining/test/gradient/optimizer_ops_test.cc
index bfb59f1525e47..18c1364f5d1f6 100644
--- a/orttraining/orttraining/test/gradient/optimizer_ops_test.cc
+++ b/orttraining/orttraining/test/gradient/optimizer_ops_test.cc
@@ -144,6 +144,8 @@ TEST(OptimizerTest, AdamBiasCorrection) {
   test.AddOutput<float>("Moment_2_Out", {3}, {1.7400e-04f, 8.9966e-04f, 1.5102e-03f});
   test.AddOutput<float>("W_Out", {3}, {-1.4634f, -0.6416f, -1.2121f});
 
+  test.SetOutputTolerance(0.0001f);
+
   test.AddAttribute("do_bias_correction", static_cast<int64_t>(1));
   test.AddAttribute("weight_decay_mode", static_cast<int64_t>(0));
 
@@ -167,6 +169,8 @@ TEST(OptimizerTest, AdamWeightDecayMode0NoBiasCorrection) {
   test.AddOutput<float>("W_Out", {3}, {-3.6210f, -2.8075f, -3.3723f});
   test.AddOutput<float>("G_Out", {3}, {-3.1576f, -3.1658f, -3.1601f});
 
+  test.SetOutputTolerance(0.0001f);
+
   test.AddAttribute("do_bias_correction", static_cast<int64_t>(0));
   test.AddAttribute("lambda", 0.01f);
   test.AddAttribute("weight_decay_mode", static_cast<int64_t>(0));
@@ -191,6 +195,8 @@ TEST(OptimizerTest, AdamWeightDecayMode0WithBiasCorrection) {
   test.AddOutput<float>("W_Out", {3}, {-1.4587f, -0.6452f, -1.2099f});
   test.AddOutput<float>("G_Out", {3}, {-0.9954f, -1.0036f, -0.9979f});
 
+  test.SetOutputTolerance(0.0001f);
+
   test.AddAttribute("do_bias_correction", static_cast<int64_t>(1));
   test.AddAttribute("lambda", 0.01f);
   test.AddAttribute("weight_decay_mode", static_cast<int64_t>(0));
@@ -214,6 +220,8 @@ TEST(OptimizerTest, AdamWeightDecayMode1NoBiasCorrection) {
   test.AddOutput<float>("Moment_2_Out", {3}, {1.7400e-04f, 8.9966e-04f, 1.5102e-03f});
   test.AddOutput<float>("W_Out", {3}, {-3.5894f, -2.7758f, -3.3406f});
 
+  test.SetOutputTolerance(0.0001f);
+
   test.AddAttribute("do_bias_correction", static_cast<int64_t>(0));
   test.AddAttribute("lambda", 0.01f);
   test.AddAttribute("weight_decay_mode", static_cast<int64_t>(1));
@@ -237,6 +245,8 @@ TEST(OptimizerTest, AdamWeightDecayMode1WithBiasCorrection) {
   test.AddOutput<float>("Moment_2_Out", {3}, {1.7400e-04f, 8.9966e-04f, 1.5102e-03f});
   test.AddOutput<float>("W_Out", {3}, {-1.4488f, -0.6352f, -1.1999f});
 
+  test.SetOutputTolerance(0.0001f);
+
   test.AddAttribute("do_bias_correction", static_cast<int64_t>(1));
   test.AddAttribute("lambda", 0.01f);
   test.AddAttribute("weight_decay_mode", static_cast<int64_t>(1));
@@ -368,6 +378,11 @@ TEST(OptimizerTest, AdamOptimizerMixPrecision_FP16Weight_ClipNorm_Test) {
   test.AddOptionalOutputEdge<MLFloat16>();
   test.AddOutput<MLFloat16>("FP16_W_Out", {3}, w_new_half);
 
+  test.SetOutputAbsErr("Moment_1_Out", 0.005f);
+  test.SetOutputAbsErr("Moment_2_Out", 0.005f);
+  test.SetOutputAbsErr("W_Out", 0.001f);
+  test.SetOutputAbsErr("FP16_W_Out", 0.005f);
+
   test.AddAttribute("do_bias_correction", static_cast<int64_t>(0));
   test.AddAttribute("weight_decay_mode", static_cast<int64_t>(0));
   test.AddAttribute("max_norm_clip", 0.001f);
@@ -617,6 +632,8 @@ void run_lamb_test_with_baseline(
     test.AddOptionalOutputEdge<MLFloat16>();
   }
 
+  test.SetOutputTolerance(0.005f);
+
   test.Run();
 }
 
@@ -737,6 +754,8 @@ void run_multi_tensor_lamb_test_with_baseline(
   test.AddAttribute("ratio_min", ratio_min);
   test.AddAttribute("ratio_max", ratio_max);
 
+  test.SetOutputTolerance(0.005f);
+
   test.Run();
 }
 
diff --git a/orttraining/orttraining/test/optimizer/compute_optimizer_test.cc b/orttraining/orttraining/test/optimizer/compute_optimizer_test.cc
index cf510ea43c89f..509937bdd0c3a 100644
--- a/orttraining/orttraining/test/optimizer/compute_optimizer_test.cc
+++ b/orttraining/orttraining/test/optimizer/compute_optimizer_test.cc
@@ -135,7 +135,7 @@ TEST(ComputeOptimizerTests, InsertGatherBeforeSceLoss_Allowed) {
       }
     };
 
-    std::vector<int> opsets{12, 13, 14, 15};
+    std::vector<int> opsets{12, 13, 14, 15, 17};
     for (auto opset : opsets) {
       std::unique_ptr<GraphTransformer> transformer =
           std::make_unique<InsertGatherBeforeSceLoss>(compatible_eps, std::vector<std::string>{"label"});
@@ -206,7 +206,7 @@ TEST(ComputeOptimizerTests, InsertGatherBeforeSceLoss_NotAllowed_LabelNameNotMat
       }
     };
 
-    std::vector<int> opsets{12, 13, 14, 15};
+    std::vector<int> opsets{12, 13, 14, 15, 17};
     for (auto opset : opsets) {
       std::unique_ptr<GraphTransformer> transformer =
           std::make_unique<InsertGatherBeforeSceLoss>(compatible_eps, std::vector<std::string>{"label"});
@@ -277,7 +277,7 @@ TEST(ComputeOptimizerTests, InsertGatherBeforeSceLoss_NotAllowed_ReduceNone) {
       }
     };
 
-    std::vector<int> opsets{12, 13, 14, 15};
+    std::vector<int> opsets{12, 13, 14, 15, 17};
     for (auto opset : opsets) {
       std::unique_ptr<GraphTransformer> transformer =
           std::make_unique<InsertGatherBeforeSceLoss>(compatible_eps, std::vector<std::string>{"label"});
@@ -344,7 +344,7 @@ TEST(ComputeOptimizerTests, InsertGatherBeforeSceLoss_NotAllowed_NoIgnoreIndex)
       }
     };
 
-    std::vector<int> opsets{12, 13, 14, 15};
+    std::vector<int> opsets{12, 13, 14, 15, 17};
     for (auto opset : opsets) {
       std::unique_ptr<GraphTransformer> transformer =
           std::make_unique<InsertGatherBeforeSceLoss>(compatible_eps, std::vector<std::string>{"label"});
diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc
index b774fec11cc8d..109937ff96d1d 100644
--- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc
+++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc
@@ -1200,7 +1200,7 @@ TEST_P(QDQFusionTestsParameterized, CheckModelComposition) {
   ASSERT_EQ(op_to_count_post_fusion["com.microsoft.FakeQuant"], 1);
 }
 
-TEST_F(GraphTransformationTests, Conv1dReplacement) {
+TEST_F(GraphTransformationTests, Conv1dReplacement_TakeEffect) {
   auto pre_graph_checker = [&](Graph& graph) {
     auto op_count_map = CountOpsInGraph(graph);
     TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1);
@@ -1208,7 +1208,7 @@ TEST_F(GraphTransformationTests, Conv1dReplacement) {
   };
 
   for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) {
-    for (auto group : {1, 2}) {
+    for (auto group : {1, 2, 4}) {
       auto build_test_case = [&](ModelTestBuilder& builder) {
         auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128);
         auto out_channel = 64;
@@ -1222,6 +1222,8 @@ TEST_F(GraphTransformationTests, Conv1dReplacement) {
         conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{1});
         conv_node.AddAttribute("strides", std::vector<int64_t>{1});
         conv_node.AddAttribute("group", static_cast<int64_t>(group));
+        conv_node.AddAttribute("pads", std::vector<int64_t>{0, 0});
+        conv_node.AddAttribute("auto_pad", "NOTSET");
       };
 
       auto post_graph_checker = [&](Graph& graph) {
@@ -1243,28 +1245,64 @@ TEST_F(GraphTransformationTests, Conv1dReplacement) {
   }
 }
 
-TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect) {
+// node has a bias input, so conv is not replaced
+TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect1) {
   auto pre_graph_checker = [&](Graph& graph) {
     auto op_count_map = CountOpsInGraph(graph);
     TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1);
     return Status::OK();
   };
 
-  // "group" is 3 so conv not replaced
   for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) {
     auto build_test_case = [&](ModelTestBuilder& builder) {
       auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128);
       auto out_channel = 64;
       auto* data_arg = builder.MakeInput<float>({{batch_size, in_channel, in_length}});
 
-      auto* weight_arg = builder.MakeInitializer<float>({out_channel, in_channel / 3, 1}, {-1.0f, 1.0f});
+      auto* weight_arg = builder.MakeInitializer<float>({out_channel, in_channel, 1}, {-1.0f, 1.0f});
+      auto* bias_arg = builder.MakeInitializer<float>({out_channel}, {-1.0f, 1.0f});
+      auto* conv_output = builder.MakeOutput();
+
+      auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg, bias_arg}, {conv_output});
+      conv_node.AddAttribute("dilations", std::vector<int64_t>{1});
+      conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{1});
+      conv_node.AddAttribute("strides", std::vector<int64_t>{1});
+      conv_node.AddAttribute("group", static_cast<int64_t>(1));
+      conv_node.AddAttribute("pads", std::vector<int64_t>{0, 0});
+      conv_node.AddAttribute("auto_pad", "NOTSET");
+    };
+
+    std::unique_ptr<GraphTransformer> transformer = std::make_unique<Conv1dReplacement>();
+    ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer),
+                                          TransformerLevel::Level1, 1,
+                                          pre_graph_checker, pre_graph_checker));
+  }
+}
+
+// "auto_pad" is not NOTSET, so conv is not replaced
+TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect2) {
+  auto pre_graph_checker = [&](Graph& graph) {
+    auto op_count_map = CountOpsInGraph(graph);
+    TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1);
+    return Status::OK();
+  };
+
+  for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) {
+    auto build_test_case = [&](ModelTestBuilder& builder) {
+      auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128);
+      auto out_channel = 64;
+      auto* data_arg = builder.MakeInput<float>({{batch_size, in_channel, in_length}});
+
+      auto* weight_arg = builder.MakeInitializer<float>({out_channel, in_channel, 1}, {-1.0f, 1.0f});
       auto* conv_output = builder.MakeOutput();
 
       auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output});
       conv_node.AddAttribute("dilations", std::vector<int64_t>{1});
       conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{1});
       conv_node.AddAttribute("strides", std::vector<int64_t>{1});
-      conv_node.AddAttribute("group", static_cast<int64_t>(3));
+      conv_node.AddAttribute("group", static_cast<int64_t>(1));
+      conv_node.AddAttribute("pads", std::vector<int64_t>{0, 0});
+      conv_node.AddAttribute("auto_pad", "VALID");
     };
 
     std::unique_ptr<GraphTransformer> transformer = std::make_unique<Conv1dReplacement>();
@@ -1272,8 +1310,16 @@ TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect) {
                                           TransformerLevel::Level1, 1,
                                           pre_graph_checker, pre_graph_checker));
   }
+}
+
+// "pads" is not all zeros, so conv is not replaced
+TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect3) {
+  auto pre_graph_checker = [&](Graph& graph) {
+    auto op_count_map = CountOpsInGraph(graph);
+    TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1);
+    return Status::OK();
+  };
 
-  // "kernel_shape" is not 1 so conv not replaced
   for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) {
     auto build_test_case = [&](ModelTestBuilder& builder) {
       auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128);
@@ -1285,9 +1331,11 @@ TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect) {
 
       auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output});
       conv_node.AddAttribute("dilations", std::vector<int64_t>{1});
-      conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{2});
+      conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{1});
       conv_node.AddAttribute("strides", std::vector<int64_t>{1});
       conv_node.AddAttribute("group", static_cast<int64_t>(1));
+      conv_node.AddAttribute("pads", std::vector<int64_t>{1, 0});
+      conv_node.AddAttribute("auto_pad", "NOTSET");
     };
 
     std::unique_ptr<GraphTransformer> transformer = std::make_unique<Conv1dReplacement>();
@@ -1523,7 +1571,7 @@ TEST_F(GraphTransformationTests, ScaledSumFusionThreeInputs) {
       builder.AddNode("Identity", {add2_out}, {graph_out});
     };
 
-    const std::vector<int> opsets{12, 13, 14, 15};
+    const std::vector<int> opsets{12, 13, 14, 15, 17};
     for (auto& opset_version : opsets) {
       std::unique_ptr<GraphTransformer> transformer = std::make_unique<ScaledSumFusion>();
       ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger_, std::move(transformer),
@@ -1616,7 +1664,7 @@ TEST_F(GraphTransformationTests, ScaledSumFusionThreeInputs_LastAddNotHaveScaleI
       builder.AddNode("Identity", {add2_out}, {graph_out});
     };
 
-    const std::vector<int> opsets{12, 13, 14, 15};
+    const std::vector<int> opsets{12, 13, 14, 15, 17};
     for (auto& opset_version : opsets) {
       std::unique_ptr<GraphTransformer> transformer = std::make_unique<ScaledSumFusion>();
       ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger_, std::move(transformer),
@@ -1710,7 +1758,7 @@ TEST_F(GraphTransformationTests, ScaledSumFusionTwoInputs) {
       builder.AddNode("Identity", {add1_out}, {graph_output2});
     };
 
-    const std::vector<int> opsets{12, 13, 14, 15};
+    const std::vector<int> opsets{12, 13, 14, 15, 17};
     for (auto& opset_version : opsets) {
       std::unique_ptr<GraphTransformer> transformer = std::make_unique<ScaledSumFusion>();
       ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger_, std::move(transformer),
diff --git a/orttraining/orttraining/test/optimizer/shape_optimizer_test.cc b/orttraining/orttraining/test/optimizer/shape_optimizer_test.cc
index ea05b29c8668b..a1629eb73eeb6 100644
--- a/orttraining/orttraining/test/optimizer/shape_optimizer_test.cc
+++ b/orttraining/orttraining/test/optimizer/shape_optimizer_test.cc
@@ -67,7 +67,7 @@ TEST(ShapeOptimizerTests, Shape15CannotFold) {
     return Status::OK();
   };
 
-  std::vector<int> opset_candidates{15};
+  std::vector<int> opset_candidates{15, 17};
   for (auto opset : opset_candidates) {
     auto build_test_case = [&](ModelTestBuilder& builder) {
       std::vector<std::variant<int64_t, std::string>> identity_input_shape;
@@ -145,7 +145,7 @@ TEST(ShapeOptimizerTests, Shape15) {
     return Status::OK();
   };
 
-  std::vector<int> opset_candidates{15};
+  std::vector<int> opset_candidates{15, 17};
   for (auto opset : opset_candidates) {
     auto build_test_case = [&](ModelTestBuilder& builder) {
       std::vector<std::variant<int64_t, std::string>> identity_input_shape;
@@ -218,7 +218,7 @@ TEST(ShapeOptimizerTests, Shape15TakesGraphInput) {
     return Status::OK();
   };
 
-  std::vector<int> opset_candidates{15};
+  std::vector<int> opset_candidates{15, 17};
   for (auto opset : opset_candidates) {
     auto build_test_case = [&](ModelTestBuilder& builder) {
       std::vector<std::variant<int64_t, std::string>> shape_input_shape;
@@ -289,7 +289,7 @@ TEST(ShapeOptimizerTests, Shape15GeneratesGraphOutput) {
     return Status::OK();
   };
 
-  std::vector<int> opset_candidates{15};
+  std::vector<int> opset_candidates{15, 17};
   for (auto opset : opset_candidates) {
     auto build_test_case = [&](ModelTestBuilder& builder) {
       std::vector<std::variant<int64_t, std::string>> identity_input_shape;
@@ -366,7 +366,7 @@ TEST(ShapeOptimizerTests, Slice) {
     return Status::OK();
   };
 
-  std::vector<int> opset_candidates{10, 11, 12, 13, 14, 15};
+  std::vector<int> opset_candidates{10, 11, 12, 13, 14, 15, 17};
   for (auto opset : opset_candidates) {
     auto build_test_case = [&](ModelTestBuilder& builder) {
       std::vector<std::variant<int64_t, std::string>> shape_input_shape;
@@ -446,7 +446,7 @@ TEST(ShapeOptimizerTests, SliceGeneratesGraphOutput) {
     return Status::OK();
   };
 
-  std::vector<int> opset_candidates{10, 11, 12, 13, 14, 15};
+  std::vector<int> opset_candidates{10, 11, 12, 13, 14, 15, 17};
   for (auto opset : opset_candidates) {
     auto build_test_case = [&](ModelTestBuilder& builder) {
       std::vector<std::variant<int64_t, std::string>> shape_input_shape;
@@ -530,7 +530,7 @@ TEST(ShapeOptimizerTests, Gather) {
     return Status::OK();
   };
 
-  std::vector<int> opset_candidates{10, 11, 12, 13, 14, 15};
+  std::vector<int> opset_candidates{10, 11, 12, 13, 14, 15, 17};
   for (auto opset : opset_candidates) {
     auto build_test_case = [&](ModelTestBuilder& builder) {
       std::vector<std::variant<int64_t, std::string>> shape_input_shape;
@@ -639,7 +639,7 @@ TEST(ShapeOptimizerTests, ConcreteDimUsedBySlice) {
     return Status::OK();
   };
 
-  std::vector<int> opset_candidates{10, 11, 12, 13, 14, 15};
+  std::vector<int> opset_candidates{10, 11, 12, 13, 14, 15, 17};
   for (auto opset : opset_candidates) {
     auto build_test_case = [&](ModelTestBuilder& builder) {
       std::vector<std::variant<int64_t, std::string>> dropout_input_shape;
@@ -810,7 +810,7 @@ TEST(ShapeOptimizerTests, ConcreteDimUsedByGatherSlice) {
     return Status::OK();
   };
 
-  std::vector<int> opset_candidates{10, 11, 12, 13, 14, 15};
+  std::vector<int> opset_candidates{10, 11, 12, 13, 14, 15, 17};
   for (auto opset : opset_candidates) {
     auto build_test_case = [&](ModelTestBuilder& builder) {
       std::vector<std::variant<int64_t, std::string>> reshape_input_shape;
@@ -976,7 +976,7 @@ TEST(ShapeOptimizerTests, SymbolicDimUsedByGather_ConcreteDimUsedByGather) {
     return Status::OK();
   };
 
-  std::vector<int> opset_candidates{10, 11, 12, 13, 14, 15};
+  std::vector<int> opset_candidates{10, 11, 12, 13, 14, 15, 17};
   for (auto opset : opset_candidates) {
     auto build_test_case = [&](ModelTestBuilder& builder) {
       std::vector<std::variant<int64_t, std::string>> reshape_input_shape;
diff --git a/orttraining/orttraining/test/python/_test_commons.py b/orttraining/orttraining/test/python/_test_commons.py
index fb7e62551de63..762c4c4d55f9f 100644
--- a/orttraining/orttraining/test/python/_test_commons.py
+++ b/orttraining/orttraining/test/python/_test_commons.py
@@ -25,5 +25,5 @@ def run_subprocess(args, cwd=None, capture=False, dll_path=None, shell=False, en
     completed_process = subprocess.run(args, cwd=cwd, check=True, stdout=stdout, stderr=stderr, env=my_env, shell=shell)
 
     if log:
-        log.debug("Subprocess completed. Return code=" + str(completed_process.returncode))
+        log.debug("Subprocess completed. Return code=%s", completed_process.returncode)
     return completed_process
diff --git a/orttraining/orttraining/test/python/_test_helpers.py b/orttraining/orttraining/test/python/_test_helpers.py
index 8f2a18b5ec00b..65043c10d8a01 100644
--- a/orttraining/orttraining/test/python/_test_helpers.py
+++ b/orttraining/orttraining/test/python/_test_helpers.py
@@ -288,7 +288,6 @@ def cpu_barrier_func():
 
     def cuda_barrier_func():
         torch.cuda.synchronize()
-        pass
 
     cuda = torch.device("cuda:0")
     run_evaluate_test_on_device_and_compare(
diff --git a/orttraining/orttraining/test/python/orttraining_test_gru.py b/orttraining/orttraining/test/python/orttraining_test_gru.py
index fcb7e13b1694f..c9e22bf7384af 100644
--- a/orttraining/orttraining/test/python/orttraining_test_gru.py
+++ b/orttraining/orttraining/test/python/orttraining_test_gru.py
@@ -355,9 +355,7 @@ def backward_np(
                 prev_h = (
                     all_hidden_states[t - 1, 0, idx, :]
                     if t > 0
-                    else initial_hidden_state[0, idx, :]
-                    if initial_hidden_state is not None
-                    else 0
+                    else initial_hidden_state[0, idx, :] if initial_hidden_state is not None else 0
                 )
 
                 grad_update_gate = (prev_h - hidden_gate) * grad_h
diff --git a/orttraining/orttraining/test/python/orttraining_test_lstm.py b/orttraining/orttraining/test/python/orttraining_test_lstm.py
index 2b296cf70c2c1..4debe73951b2f 100644
--- a/orttraining/orttraining/test/python/orttraining_test_lstm.py
+++ b/orttraining/orttraining/test/python/orttraining_test_lstm.py
@@ -480,9 +480,7 @@ def backward_np(
                 grad_forget_gate = grad_c * (
                     all_cell_states[t - 1, 0, idx, :]
                     if t > 0
-                    else initial_cell_state[0, idx, :]
-                    if initial_cell_state is not None
-                    else 0
+                    else initial_cell_state[0, idx, :] if initial_cell_state is not None else 0
                 )
                 grad_control_gate = grad_c * input_gate
 
@@ -522,9 +520,7 @@ def backward_np(
                 prev_h = (
                     all_hidden_states[t - 1, 0, idx, :]
                     if t > 0
-                    else initial_hidden_state[0, idx, :]
-                    if initial_hidden_state is not None
-                    else 0
+                    else initial_hidden_state[0, idx, :] if initial_hidden_state is not None else 0
                 )
                 grad_recurrence_weights[0, : self._hidden_size, :] += np.dot(
                     np.expand_dims(grad_input_activation, axis=0).T, np.expand_dims(prev_h, axis=0)
@@ -553,9 +549,7 @@ def backward_np(
                     grad_peephole_weights[0, : self._hidden_size] += grad_input_activation * (
                         all_cell_states[t - 1, 0, idx, :]
                         if t > 0
-                        else initial_cell_state[0, idx, :]
-                        if initial_cell_state is not None
-                        else 0
+                        else initial_cell_state[0, idx, :] if initial_cell_state is not None else 0
                     )
                     grad_peephole_weights[0, self._hidden_size : 2 * self._hidden_size] += (
                         grad_output_activation * all_cell_states[t, 0, idx, :]
@@ -565,9 +559,7 @@ def backward_np(
                     ] += grad_forget_activation * (
                         all_cell_states[t - 1, 0, idx, :]
                         if t > 0
-                        else initial_cell_state[0, idx, :]
-                        if initial_cell_state is not None
-                        else 0
+                        else initial_cell_state[0, idx, :] if initial_cell_state is not None else 0
                     )
 
                 grad_c = grad_prev_c
diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py
index 3d41c8678278c..ac49c1c2834c7 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py
@@ -190,9 +190,11 @@ def _get_training_ort_inputs(x, target, pt_model, onnx_model, target_type=None):
 
     ort_inputs = {
         onnx_model.graph.input[0].name: _to_numpy(copy.deepcopy(x)),
-        onnx_model.graph.input[1].name: _to_numpy(copy.deepcopy(target))
-        if target_type is None
-        else _to_numpy(copy.deepcopy(target).type(target_type)),
+        onnx_model.graph.input[1].name: (
+            _to_numpy(copy.deepcopy(target))
+            if target_type is None
+            else _to_numpy(copy.deepcopy(target).type(target_type))
+        ),
     }
     if target_type is not None:
         ort_inputs[onnx_model.graph.input[1].name]
@@ -1070,3 +1072,30 @@ def test_save_nominal_checkpoint():
             os.stat(os.path.join(temp_dir, "checkpoint")).st_size
             > os.stat(os.path.join(temp_dir, "nominal_checkpoint")).st_size
         )
+
+
+def test_custom_optimizer_block():
+    device = "cpu"
+    batch_size, input_size, hidden_size, output_size = 64, 784, 500, 10
+    _, base_model = _get_models(device, batch_size, input_size, hidden_size, output_size)
+    weight_decay = 123
+    optimizer = onnxblock.optim.AdamW(weight_decay=weight_decay)
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        artifacts.generate_artifacts(
+            base_model,
+            requires_grad=["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"],
+            loss=artifacts.LossType.CrossEntropyLoss,
+            optimizer=optimizer,
+            artifact_directory=temp_dir,
+        )
+
+        assert os.path.exists(os.path.join(temp_dir, "checkpoint"))
+        assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx"))
+
+        optimizer_model = onnx.load(os.path.join(temp_dir, "optimizer_model.onnx"))
+        for node in optimizer_model.graph.node:
+            if node.op_type == "AdamW":
+                for attr in node.attribute:
+                    if attr.name == "weight_decay":
+                        assert attr.f == weight_decay
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index 51aa1564cbfbe..7afad9145ed27 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -34,7 +34,7 @@
 from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient
 from onnxruntime.training.ortmodule.options import _SkipCheck
 
-DEFAULT_OPSET = 15
+DEFAULT_OPSET = 17
 
 
 # PyTorch model definitions for tests
@@ -417,24 +417,38 @@ def _get_bert_for_sequence_classification_model(
     return model
 
 
-def _get_bert_for_sequence_classification_sample_data(device):
-    """Returns sample data to be used with BertForSequenceClassification model"""
+def _generate_attention_mask_for_encoder_following_hf(batch_size, seq_length, device, past_key_values_length=0):
+    """Generate attention mask for encoder following the implementation in HuggingFace.
 
-    input_ids = torch.randint(0, 100, (32, 64), dtype=torch.long, device=device)
-    input_mask = torch.randint(0, 100, (32, 64), dtype=torch.long, device=device)
-    labels = torch.randint(0, 1, (32,), dtype=torch.long, device=device)
+    Note: past_key_values_length is 0 for training.
 
-    return input_ids, input_mask, labels
+    The mask is generated following
+        https://github.com/huggingface/transformers/blame/4f27ee936a861f56f32ea6db138978b274008006/src/transformers/models/bert/modeling_bert.py#L974C81-L974C81
+
+    """
+
+    attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
+    return attention_mask
 
 
 def _get_bert_for_sequence_classification_sample_data_with_random_shapes(device):
     """Returns sample data with random shape to be used with BertForSequenceClassification model"""
 
-    x = random.randint(1, 100)
-    y = random.randint(1, 100)
-    input_ids = torch.randint(0, 100, (x, y), dtype=torch.long, device=device)
-    input_mask = torch.randint(0, 100, (x, y), dtype=torch.long, device=device)
-    labels = torch.randint(0, 1, (x,), dtype=torch.long, device=device)
+    bsz = random.randint(1, 100)
+    seq_length = random.randint(1, 100)
+    input_ids = torch.randint(0, 100, (bsz, seq_length), dtype=torch.long, device=device)
+    input_mask = _generate_attention_mask_for_encoder_following_hf(bsz, seq_length, device)
+    labels = torch.randint(0, 1, (bsz,), dtype=torch.long, device=device)
+
+    return input_ids, input_mask, labels
+
+
+def _get_bert_for_sequence_classification_sample_data(device):
+    """Returns sample data to be used with BertForSequenceClassification model"""
+
+    input_ids = torch.randint(0, 100, (32, 64), dtype=torch.long, device=device)
+    input_mask = _generate_attention_mask_for_encoder_following_hf(32, 64, device)
+    labels = torch.randint(0, 1, (32,), dtype=torch.long, device=device)
 
     return input_ids, input_mask, labels
 
@@ -2211,32 +2225,27 @@ def run_step(model, x):
         _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
 
 
-# TODO(askhade): This test is failing with smaller tolerance, need to investigate! Disabling it right now to
-# unblock the move to a later version of transformers to resolve security vulnerability.
-# (Moving from transformers v4.4.2 to v4.30.0)
-# def test_bert_inputs_with_dynamic_shape():
-#     # create pytorch model with dropout disabled
-#     pt_model = _get_bert_for_sequence_classification_model(
-#         "cuda", is_training=True, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0
-#     )
-#     ort_model = ORTModule(copy.deepcopy(pt_model))
+def test_bert_inputs_with_dynamic_shape():
+    # create pytorch model with dropout disabled
+    pt_model = _get_bert_for_sequence_classification_model(
+        "cuda", is_training=True, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0
+    )
+    ort_model = ORTModule(copy.deepcopy(pt_model))
 
-#     def run_step(model, x, y, z):
-#         outputs = model(x, y, None, None, None, None, z)
-#         loss = outputs[0]
-#         loss.backward()
-#         return outputs[0]
+    def run_step(model, x, y, z):
+        outputs = model(x, y, None, None, None, None, z)
+        loss = outputs[0]
+        loss.backward()
+        return outputs[0]
 
-#     for _step in range(10):
-#         x, y, z = _get_bert_for_sequence_classification_sample_data_with_random_shapes("cuda")
+    for _step in range(10):
+        x, y, z = _get_bert_for_sequence_classification_sample_data_with_random_shapes("cuda")
 
-#         pt_p = run_step(pt_model, x, y, z)
-#         ort_p = run_step(ort_model, x, y, z)
+        pt_p = run_step(pt_model, x, y, z)
+        ort_p = run_step(ort_model, x, y, z)
 
-#         _test_helpers.assert_values_are_close(
-#             ort_p, pt_p, atol=1e-01
-#         )  # TODO: this assert is failing with smaller tolerance, need to investigate!!
-#         # _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)  #TODO - enable this check after the investigation
+        _test_helpers.assert_values_are_close(ort_p, pt_p, atol=1e-01)
+        _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
 
 
 @pytest.mark.parametrize("device", ["cuda", "cpu"])
@@ -3788,7 +3797,7 @@ def forward(self, input1=None, input2=None):
             model.eval()
 
         # Must work because forward() and dict order match
-        y1, y2 = model(**{"input1": input1, "input2": input2})
+        y1, y2 = model(input1=input1, input2=input2)
         assert y1 is not None
         assert y2 is not None
         if model._is_training():
@@ -3796,7 +3805,7 @@ def forward(self, input1=None, input2=None):
             loss.backward()
 
         # Must work even when forward() and dict order mismatch
-        y1, y2 = model(**{"input2": input2, "input1": input1})
+        y1, y2 = model(input2=input2, input1=input1)
         assert y1 is not None
         assert y2 is not None
         if model._is_training():
@@ -3878,17 +3887,20 @@ def run_step(expected, a, b, c, d, e, f, y, z):
             None,
             None,
         )
-        run_step(
-            a.item() + f.item(), **{"a": a, "b": None, "c": None, "d": None, "e": None, "f": f, "y": None, "z": None}
-        )
+        run_step(a.item() + f.item(), a=a, b=None, c=None, d=None, e=None, f=f, y=None, z=None)
         run_step(a.item() + z.item(), a, None, None, None, None, None, None, z)
-        run_step(
-            a.item() + z.item(), **{"a": a, "b": None, "c": None, "d": None, "e": None, "f": None, "y": None, "z": z}
-        )
+        run_step(a.item() + z.item(), a=a, b=None, c=None, d=None, e=None, f=None, y=None, z=z)
         run_step(a.item() + c.item() + y.item(), a, None, c, None, None, None, y, None)
         run_step(
             a.item() + c.item() + y.item(),
-            **{"a": a, "b": None, "c": c, "d": None, "e": None, "f": None, "y": y, "z": None},
+            a=a,
+            b=None,
+            c=c,
+            d=None,
+            e=None,
+            f=None,
+            y=y,
+            z=None,
         )
         run_step(
             a.item() + b.item() + c.item() + d.item() + e.item() + f.item() + y.item() + z.item(),
@@ -3903,7 +3915,14 @@ def run_step(expected, a, b, c, d, e, f, y, z):
         )
         run_step(
             a.item() + b.item() + c.item() + d.item() + e.item() + f.item() + y.item() + z.item(),
-            **{"a": a, "b": b, "c": c, "d": d, "e": e, "f": f, "y": y, "z": z},
+            a=a,
+            b=b,
+            c=c,
+            d=d,
+            e=e,
+            f=f,
+            y=y,
+            z=z,
         )
 
     del os.environ["ORTMODULE_SKIPCHECK_POLICY"]
@@ -5280,7 +5299,7 @@ def run_step(model, x):
     assert ort_model._torch_module._execution_manager(True)._runtime_options.onnx_opset_version == 13
 
 
-@pytest.mark.parametrize("opset_version", [12, 13, 14, 15])
+@pytest.mark.parametrize("opset_version", [12, 13, 14, 15, 17])
 def test_opset_version_change(opset_version):
     original_env = None
     if "ORTMODULE_ONNX_OPSET_VERSION" in os.environ:
@@ -6424,9 +6443,6 @@ def run_step(model, x):
         del os.environ["ORTMODULE_CONV_ALGO_SEARCH"]
 
 
-@pytest.mark.skip(
-    reason="This test fail because bert forward loss is nan in updated transformers lib, disable for now."
-)
 def test_bert_result_with_layerwise_recompute():
     original_val = os.environ.get("ORTMODULE_MEMORY_OPT_LEVEL", None)
     # Create PyTorch model with dropout disabled.
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py
index 3d92e0b323c19..a1a7d4660f266 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py
@@ -441,7 +441,7 @@ def main():
 
     # 4. Train loop (fine-tune)
     total_training_time, total_test_time, epoch_0_training, validation_accuracy = 0, 0, 0, 0
-    for epoch_i in range(0, args.epochs):
+    for epoch_i in range(args.epochs):
         total_training_time += train(model, optimizer, scheduler, train_dataloader, epoch_i, device, args)
         if not args.pytorch_only and epoch_i == 0:
             epoch_0_training = total_training_time
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py
index 87c8e66231a29..0d5aba1a1a5c4 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py
@@ -446,7 +446,7 @@ def main():
 
     # 4. Train loop (fine-tune)
     total_training_time, total_test_time, epoch_0_training, validation_accuracy = 0, 0, 0, 0
-    for epoch_i in range(0, args.epochs):
+    for epoch_i in range(args.epochs):
         total_training_time += train(model, optimizer, scaler, scheduler, train_dataloader, epoch_i, device, args)
         if not args.pytorch_only and epoch_i == 0:
             epoch_0_training = total_training_time
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_zero_stage_1.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_zero_stage_1.py
index 86e8d9aea1d37..5b28e9c52b480 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_zero_stage_1.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_zero_stage_1.py
@@ -8,6 +8,7 @@
     --deepspeed_config=orttraining_test_ortmodule_deepspeed_zero_stage_1_config.json
 ```
 """
+
 import argparse
 import time
 
@@ -36,11 +37,7 @@ def forward(self, input1):
 
 
 def train(args, model, device, optimizer, loss_fn, train_loader, epoch):
-    print(
-        "\n======== Epoch {:} / {:} with batch size {:} ========".format(
-            epoch + 1, args.epochs, model.train_batch_size()
-        )
-    )
+    print(f"\n======== Epoch {epoch + 1} / {args.epochs} with batch size {model.train_batch_size()} ========")
     model.train()
     # Measure how long the training epoch takes.
     t0 = time.time()
@@ -77,13 +74,7 @@ def train(args, model, device, optimizer, loss_fn, train_loader, epoch):
             curr_time = time.time()
             elapsed_time = curr_time - start_time
             print(
-                "[{:5}/{:5} ({:2.0f}%)]\tLoss: {:.6f}\tExecution time: {:.4f}".format(
-                    iteration * len(data),
-                    len(train_loader.dataset),
-                    100.0 * iteration / len(train_loader),
-                    loss,
-                    elapsed_time,
-                )
+                f"[{iteration * len(data):5}/{len(train_loader.dataset):5} ({100.0 * iteration / len(train_loader):2.0f}%)]\tLoss: {loss:.6f}\tExecution time: {elapsed_time:.4f}"
             )
             start_time = curr_time
 
@@ -115,13 +106,7 @@ def test(args, model, device, loss_fn, test_loader):
             correct += pred.eq(target.view_as(pred)).sum().item()
     test_loss /= len(test_loader.dataset)
     print(
-        "\nTest set: Batch size: {:}, Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
-            args.test_batch_size,
-            test_loss,
-            correct,
-            len(test_loader.dataset),
-            100.0 * correct / len(test_loader.dataset),
-        )
+        f"\nTest set: Batch size: {args.test_batch_size}, Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100.0 * correct / len(test_loader.dataset):.0f}%)\n"
     )
 
     # Report the final accuracy for this validation run.
@@ -251,7 +236,7 @@ def main():
 
     # Train loop
     total_training_time, total_test_time, epoch_0_training = 0, 0, 0
-    for epoch in range(0, args.epochs):
+    for epoch in range(args.epochs):
         total_training_time += train(args, model, device, optimizer, my_loss, train_loader, epoch)
         if not args.pytorch_only and epoch == 0:
             epoch_0_training = total_training_time
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fairscale_sharded_optimizer.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fairscale_sharded_optimizer.py
index 53e1928e2d2f3..4437611283122 100755
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fairscale_sharded_optimizer.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fairscale_sharded_optimizer.py
@@ -123,13 +123,7 @@ def train_step(args, model, device, optimizer, loss_fn, train_loader, epoch):
             curr_time = time.time()
             elapsed_time = curr_time - start_time
             print(
-                "[{:5}/{:5} ({:2.0f}%)]\tLoss: {:.6f}\tExecution time: {:.4f}".format(
-                    iteration * len(data),
-                    len(train_loader.dataset),
-                    100.0 * iteration / len(train_loader),
-                    loss,
-                    elapsed_time,
-                )
+                f"[{iteration * len(data):5}/{len(train_loader.dataset):5} ({100.0 * iteration / len(train_loader):2.0f}%)]\tLoss: {loss:.6f}\tExecution time: {elapsed_time:.4f}"
             )
             start_time = curr_time
 
@@ -160,13 +154,7 @@ def test(args, model, device, loss_fn, test_loader):
             correct += pred.eq(target.view_as(pred)).sum().item()
     test_loss /= len(test_loader.dataset)
     print(
-        "\nTest set: Batch size: {:}, Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
-            args.test_batch_size,
-            test_loss,
-            correct,
-            len(test_loader.dataset),
-            100.0 * correct / len(test_loader.dataset),
-        )
+        f"\nTest set: Batch size: {args.test_batch_size}, Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100.0 * correct / len(test_loader.dataset):.0f}%)\n"
     )
 
     # Report the final accuracy for this validation run.
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py
index 4f0925c5c855b..df0b5f195f0b9 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py
@@ -1,6 +1,7 @@
 """
 @brief      test log(time=3s)
 """
+
 import copy
 import unittest
 
@@ -79,7 +80,7 @@ def run_step(model, x):
         for onnx_model in [onnx_graph_inf, onnx_graph_train]:
             for oimp in onnx_model.opset_import:
                 if oimp.domain == "":
-                    self.assertEqual(oimp.version, 15)
+                    self.assertEqual(oimp.version, 17)  # Needs to match latest default ORTModule opset
         if op_grad_type is not None:
             if isinstance(op_grad_type, tuple):
                 text = str(onnx_graph_train)
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_poc.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_poc.py
index 1cb0b3626e54e..d6f84d94c2838 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_poc.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_poc.py
@@ -64,13 +64,7 @@ def train(args, model, device, optimizer, loss_fn, train_loader, epoch):
             curr_time = time.time()
             elapsed_time = curr_time - start_time
             print(
-                "[{:5}/{:5} ({:2.0f}%)]\tLoss: {:.6f}\tExecution time: {:.4f}".format(
-                    iteration * len(data),
-                    len(train_loader.dataset),
-                    100.0 * iteration / len(train_loader),
-                    loss,
-                    elapsed_time,
-                )
+                f"[{iteration * len(data):5}/{len(train_loader.dataset):5} ({100.0 * iteration / len(train_loader):2.0f}%)]\tLoss: {loss:.6f}\tExecution time: {elapsed_time:.4f}"
             )
             start_time = curr_time
 
@@ -102,13 +96,7 @@ def test(args, model, device, loss_fn, test_loader):
             correct += pred.eq(target.view_as(pred)).sum().item()
     test_loss /= len(test_loader.dataset)
     print(
-        "\nTest set: Batch size: {:}, Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
-            args.test_batch_size,
-            test_loss,
-            correct,
-            len(test_loader.dataset),
-            100.0 * correct / len(test_loader.dataset),
-        )
+        f"\nTest set: Batch size: {args.test_batch_size}, Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100.0 * correct / len(test_loader.dataset):.0f}%)\n"
     )
 
     # Report the final accuracy for this validation run.
@@ -221,7 +209,7 @@ def main():
 
     # Train loop
     total_training_time, total_test_time, epoch_0_training, validation_accuracy = 0, 0, 0, 0
-    for epoch in range(0, args.epochs):
+    for epoch in range(args.epochs):
         total_training_time += train(args, model, device, optimizer, my_loss, train_loader, epoch)
         if not args.pytorch_only and epoch == 0:
             epoch_0_training = total_training_time
diff --git a/orttraining/orttraining/test/python/qat_poc_example/README.md b/orttraining/orttraining/test/python/qat_poc_example/README.md
index 6840e98bd9c86..05072b410b730 100644
--- a/orttraining/orttraining/test/python/qat_poc_example/README.md
+++ b/orttraining/orttraining/test/python/qat_poc_example/README.md
@@ -48,7 +48,7 @@ We use `onnxruntime.training.onnxblock` to perform the above operations to get t
 
 > **_NOTE:_**  As of this writing, ORT does not have its own `"Observers"`. Instead, we rely on the `onnxruntime.quantization` tool to quantize the model and give us an initial estimate of the quantization parameters using its calibration process. Here the calibration process is used as a substitute for the observers to present the POC.
 
-> **_NOTE:_** Typically, the weights in the statically quantized onnx model is associated with a DQ node only (not the QDQ pair) since weights are quantized. However, QAT requires weights and biases to be non quantized. We ensure that the weights have dedicated QDQ pair by passing in the flag AddQDQPairToWeight=True`
+> **_NOTE:_** Typically, the weights in the statically quantized onnx model are associated with a DQ node only (not the full QDQ pair) since the weights are already quantized. However, QAT requires weights and biases to be non-quantized. We ensure that each weight gets a dedicated QDQ pair by passing in the flag `AddQDQPairToWeight=True`.
 
 > **_NOTE:_**  Typically, the bias term in the statically quantized onnx model is associated with a DQ node only (not the QDQ pair) since it is quantized as int32 as opposed to int8. So, we disable quantizing the bias term using the flag `QuantizeBias=False`.
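
For readers following the QAT flow, a minimal sketch (not part of this patch) of how those two flags reach `onnxruntime.quantization.quantize_static` via `extra_options`; the model paths and the calibration reader are placeholders:

```python
# Sketch only: static QDQ quantization configured so weights keep a full QDQ pair
# and the bias stays in float, matching the flags described above.
import numpy as np
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, quantize_static


class RandomCalibrationReader(CalibrationDataReader):
    """Feeds a few random samples to the calibrator (illustrative placeholder)."""

    def __init__(self, input_name, shape, num_samples=8):
        self._samples = iter(
            {input_name: np.random.rand(*shape).astype(np.float32)} for _ in range(num_samples)
        )

    def get_next(self):
        return next(self._samples, None)


quantize_static(
    "mnist_fp32.onnx",                # placeholder input model
    "mnist_qdq.onnx",                 # placeholder output model
    RandomCalibrationReader("input", (1, 784)),
    quant_format=QuantFormat.QDQ,
    extra_options={
        "AddQDQPairToWeight": True,   # weights get a dedicated QDQ pair
        "QuantizeBias": False,        # bias is left unquantized for QAT
    },
)
```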
 
diff --git a/orttraining/orttraining/test/python/qat_poc_example/model.py b/orttraining/orttraining/test/python/qat_poc_example/model.py
index 91d7ccd7294f5..601362a59e379 100644
--- a/orttraining/orttraining/test/python/qat_poc_example/model.py
+++ b/orttraining/orttraining/test/python/qat_poc_example/model.py
@@ -5,7 +5,7 @@
 import onnx
 import torch
 
-import onnxruntime.training.onnxblock as onnxblock
+from onnxruntime.training import artifacts
 
 
 class MNIST(torch.nn.Module):
@@ -96,42 +96,26 @@ def create_training_artifacts(model_path, artifacts_dir, model_prefix):
     4. The checkpoint file
     """
 
-    class MNISTWithLoss(onnxblock.TrainingModel):
-        def __init__(self):
-            super().__init__()
-            self.loss = onnxblock.loss.CrossEntropyLoss()
-
-        def build(self, output_name):
-            return self.loss(output_name)
-
-    mnist_with_loss = MNISTWithLoss()
-    onnx_model, eval_model, optimizer_model = onnx.load(model_path), None, None
-
-    # Build the training and eval graphs
-    logging.info("Using onnxblock to create the training artifacts.")
-    with onnxblock.onnx_model(onnx_model) as model_accessor:
-        _ = mnist_with_loss(onnx_model.graph.output[0].name)
-        eval_model = model_accessor.eval_model
-
-    # Build the optimizer graph
-    optimizer = onnxblock.optim.AdamW()
-    with onnxblock.onnx_model() as accessor:
-        _ = optimizer(mnist_with_loss.parameters())
-        optimizer_model = accessor.model
+    onnx_model = onnx.load(model_path)
+
+    requires_grad = [
+        param.name
+        for param in onnx_model.graph.initializer
+        if (not param.name.endswith("_scale") and not param.name.endswith("_zero_point"))
+    ]
+    artifacts.generate_artifacts(
+        onnx_model,
+        requires_grad=requires_grad,
+        loss=artifacts.LossType.CrossEntropyLoss,
+        optimizer=artifacts.OptimType.AdamW,
+        artifact_directory=artifacts_dir,
+        prefix=model_prefix,
+    )
 
     # Create the training artifacts
-    train_model_path = os.path.join(artifacts_dir, f"{model_prefix}_train.onnx")
-    logging.info(f"Saving the training model to {train_model_path}.")
-    onnx.save(onnx_model, train_model_path)
-    eval_model_path = os.path.join(artifacts_dir, f"{model_prefix}_eval.onnx")
-    logging.info(f"Saving the eval model to {eval_model_path}.")
-    onnx.save(eval_model, eval_model_path)
-    optimizer_model_path = os.path.join(artifacts_dir, f"{model_prefix}_optimizer.onnx")
-    logging.info(f"Saving the optimizer model to {optimizer_model_path}.")
-    onnx.save(optimizer_model, optimizer_model_path)
-    trainable_params, non_trainable_params = mnist_with_loss.parameters()
-    checkpoint_path = os.path.join(artifacts_dir, f"{model_prefix}_checkpoint.ckpt")
-    logging.info(f"Saving the checkpoint to {checkpoint_path}.")
-    onnxblock.save_checkpoint((trainable_params, non_trainable_params), checkpoint_path)
+    train_model_path = os.path.join(artifacts_dir, f"{model_prefix}training_model.onnx")
+    eval_model_path = os.path.join(artifacts_dir, f"{model_prefix}eval_model.onnx")
+    optimizer_model_path = os.path.join(artifacts_dir, f"{model_prefix}optimizer_model.onnx")
+    checkpoint_path = os.path.join(artifacts_dir, f"{model_prefix}checkpoint")
 
     return train_model_path, eval_model_path, optimizer_model_path, checkpoint_path
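
The rewritten helper above leans entirely on `onnxruntime.training.artifacts.generate_artifacts`; a compact sketch of that call in isolation (paths and prefix assumed) shows where the new artifact names come from:

```python
# Sketch: generate training artifacts while freezing quantization scale/zero-point initializers.
import onnx
from onnxruntime.training import artifacts

onnx_model = onnx.load("mnist_qdq.onnx")  # placeholder path

# Everything except the quantization parameters is trainable.
requires_grad = [
    init.name
    for init in onnx_model.graph.initializer
    if not init.name.endswith(("_scale", "_zero_point"))
]

artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="training_artifacts",
    prefix="mnist_qat_",
)
# Writes mnist_qat_training_model.onnx, mnist_qat_eval_model.onnx,
# mnist_qat_optimizer_model.onnx and mnist_qat_checkpoint under training_artifacts/,
# which is why qat.py now uses the trailing-underscore prefix.
```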
diff --git a/orttraining/orttraining/test/python/qat_poc_example/qat.py b/orttraining/orttraining/test/python/qat_poc_example/qat.py
index 51a15475ee911..dcc9e116fda7d 100644
--- a/orttraining/orttraining/test/python/qat_poc_example/qat.py
+++ b/orttraining/orttraining/test/python/qat_poc_example/qat.py
@@ -46,7 +46,7 @@
     )
 
     logging.info("Preparing the training artifacts for QAT.")
-    training_model_name = "mnist_qat"
+    training_model_name = "mnist_qat_"
     artifacts_dir = os.path.join(model_dir, "training_artifacts")
     utils.makedir(artifacts_dir)
     training_artifacts = create_training_artifacts(
diff --git a/orttraining/orttraining/test/python/qat_poc_example/quantize.py b/orttraining/orttraining/test/python/qat_poc_example/quantize.py
index 6d9ea284fd3ef..225fb2f8e81b4 100644
--- a/orttraining/orttraining/test/python/qat_poc_example/quantize.py
+++ b/orttraining/orttraining/test/python/qat_poc_example/quantize.py
@@ -53,7 +53,7 @@ def quantize_static(input_model_dir, output_model_dir):
     logging.info(
         "Invoking onnxruntime.quantization.quantize_static with AddQDQPairToWeight=True and QuantizeBias=False.."
     )
-    logging.info("Quantized model will be saved to %s." % output_model_dir)
+    logging.info("Quantized model will be saved to %s.", output_model_dir)
     quantization.quantize_static(
         input_model_dir,
         output_model_dir,
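
The `%s` change above follows the lazy-formatting style that flake8-logging-format (the `G` rules enabled in `pyproject.toml` later in this patch) checks for; a tiny illustration:

```python
# Lazy logging: the message is only formatted if the record is actually emitted.
import logging

logging.basicConfig(level=logging.INFO)
output_model_dir = "model_qdq.onnx"  # placeholder

logging.info("Quantized model will be saved to %s.", output_model_dir)  # preferred
# logging.info("Quantized model will be saved to %s." % output_model_dir)  # eager interpolation, flagged by the G rules
```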
diff --git a/orttraining/orttraining/test/python/qat_poc_example/train.py b/orttraining/orttraining/test/python/qat_poc_example/train.py
index 9a429d2adc6f1..a25c071c58a48 100644
--- a/orttraining/orttraining/test/python/qat_poc_example/train.py
+++ b/orttraining/orttraining/test/python/qat_poc_example/train.py
@@ -26,14 +26,10 @@ def _train_epoch(model, optimizer, train_loader):
     model.train()
     cumulative_loss = 0
     for data, target in train_loader:
-        forward_inputs = [
-            data.reshape(len(data), 784).numpy(),
-            target.numpy().astype(np.int32),
-        ]
-        train_loss = model(forward_inputs)
+        train_loss = model(data.reshape(len(data), 784).numpy(), target.numpy().astype(np.int64))
         optimizer.step()
         model.lazy_reset_grad()
-        cumulative_loss += train_loss[0]
+        cumulative_loss += train_loss
 
     return cumulative_loss / len(train_loader)
 
@@ -43,12 +39,8 @@ def _eval(model, test_loader):
     model.eval()
     cumulative_loss = 0
     for data, target in test_loader:
-        forward_inputs = [
-            data.reshape(len(data), 784).numpy(),
-            target.numpy().astype(np.int32),
-        ]
-        test_loss = model(forward_inputs)
-        cumulative_loss += test_loss[0]
+        test_loss = model(data.reshape(len(data), 784).numpy(), target.numpy().astype(np.int64))
+        cumulative_loss += test_loss
 
     return cumulative_loss / len(test_loader)
 
@@ -65,7 +57,7 @@ def train_model(qat_train_model, qat_eval_model, qat_optimizer_model, qat_checkp
     train_loader, test_loader = _get_dataloaders("data", batch_size)
 
     # Load the checkpoint state.
-    state = orttraining.CheckpointState(qat_checkpoint)
+    state = orttraining.CheckpointState.load_checkpoint(qat_checkpoint)
 
     # Create the training module.
     model = orttraining.Module(qat_train_model, state, qat_eval_model)
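
For orientation, a minimal sketch of the `onnxruntime.training.api` loop that the updated `train.py` drives (file names and data are placeholders): the checkpoint is loaded through the `load_checkpoint` class method, and the module now takes positional NumPy inputs and returns the loss directly.

```python
# Sketch of the on-device training loop shape used by train.py (placeholder data).
import numpy as np
from onnxruntime.training import api as orttraining

state = orttraining.CheckpointState.load_checkpoint("mnist_qat_checkpoint")
model = orttraining.Module("mnist_qat_training_model.onnx", state, "mnist_qat_eval_model.onnx")
optimizer = orttraining.Optimizer("mnist_qat_optimizer_model.onnx", model)

model.train()
for _ in range(10):
    data = np.random.rand(64, 784).astype(np.float32)
    target = np.random.randint(0, 10, size=(64,)).astype(np.int64)
    loss = model(data, target)   # positional inputs, loss returned directly
    optimizer.step()
    model.lazy_reset_grad()
```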
diff --git a/orttraining/orttraining/test/training_ops/cpu/nn/batchnorm_internal_test.cc b/orttraining/orttraining/test/training_ops/cpu/nn/batchnorm_internal_test.cc
index e9795a24681cb..e89883bfd4d94 100644
--- a/orttraining/orttraining/test/training_ops/cpu/nn/batchnorm_internal_test.cc
+++ b/orttraining/orttraining/test/training_ops/cpu/nn/batchnorm_internal_test.cc
@@ -37,6 +37,8 @@ TEST(BatchNormInternalTest, ForwardTrainingTest) {
   test.AddOutput<float>("saved_mean", channel_dims, {-0.306f, 0.114562f});
   test.AddOutput<float>("saved_inv_std", channel_dims, {1.2288f, 0.861317f});
 
+  test.SetOutputTolerance(0.0001f);
+
   std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
   execution_providers.emplace_back(DefaultCpuExecutionProvider());
 
diff --git a/orttraining/orttraining/test/training_ops/cuda/batch_norm_internal_test.cc b/orttraining/orttraining/test/training_ops/cuda/batch_norm_internal_test.cc
index 6335a666e0381..d842d4f1ea736 100644
--- a/orttraining/orttraining/test/training_ops/cuda/batch_norm_internal_test.cc
+++ b/orttraining/orttraining/test/training_ops/cuda/batch_norm_internal_test.cc
@@ -68,6 +68,7 @@ static void TestBatchNormInternal(bool test_double = false, bool T_is_half = fal
     test.AddOutput<double>("running_var", channel_dims, running_var_double);
     test.AddOutput<double>("saved_mean", channel_dims, saved_mean_double);
     test.AddOutput<double>("saved_inv_std", channel_dims, saved_inv_std_double);
+    test.SetOutputTolerance(0.0001f);
   } else {
     if (T_is_half) {
       std::vector<MLFloat16> X_half(X.size());
diff --git a/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc b/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc
index d9800ce0e0d3e..d36f9b307ec70 100644
--- a/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc
+++ b/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc
@@ -311,11 +311,9 @@ template <typename T, typename TOut>
 static std::vector<OrtValue> RunSCELossWithEP(const char* op,
                                               int opset_version,
                                               const char* domain,
-                                              std::function<std::unique_ptr<IExecutionProvider>()>
-                                                  ep_creator,
+                                              std::function<std::unique_ptr<IExecutionProvider>()> ep_creator,
                                               const std::string& reduction,
                                               const std::int64_t ignore_index,
-                                              const double error_tolerance,
                                               const std::vector<int64_t>* X_dims,
                                               const std::vector<int64_t>* index_dims,
                                               const std::vector<int64_t>* weight_dims,
@@ -403,7 +401,7 @@ static void TestSCELoss(const char* op, int opset_version,
     cpu_fetches = RunSCELossWithEP<float, float>(
         op, opset_version, domain,
         []() -> std::unique_ptr<IExecutionProvider> { return DefaultCpuExecutionProvider(); },
-        reduction, ignore_index, error_tolerance,
+        reduction, ignore_index,
         X_dims, index_dims, weight_dims,
         Y_dims, log_prob_dims,
         X_data_temp, index_data, weight_data_temp);
@@ -411,7 +409,7 @@ static void TestSCELoss(const char* op, int opset_version,
     cpu_fetches = RunSCELossWithEP<T, float>(
         op, opset_version, domain,
         []() -> std::unique_ptr<IExecutionProvider> { return DefaultCpuExecutionProvider(); },
-        reduction, ignore_index, error_tolerance,
+        reduction, ignore_index,
         X_dims, index_dims, weight_dims,
         Y_dims, log_prob_dims,
         X_data, index_data, weight_data);
@@ -429,7 +427,7 @@ static void TestSCELoss(const char* op, int opset_version,
         return DefaultRocmExecutionProvider();
 #endif
       },
-      reduction, ignore_index, error_tolerance,
+      reduction, ignore_index,
       X_dims, index_dims, weight_dims,
       Y_dims, log_prob_dims,
       X_data, index_data, weight_data);
diff --git a/orttraining/orttraining/test/training_ops/cuda/layer_norm_test.cc b/orttraining/orttraining/test/training_ops/cuda/layer_norm_test.cc
index e86aa871b6c5f..13ad2f6150acf 100644
--- a/orttraining/orttraining/test/training_ops/cuda/layer_norm_test.cc
+++ b/orttraining/orttraining/test/training_ops/cuda/layer_norm_test.cc
@@ -49,7 +49,7 @@ static void TestLayerNormGrad(
 
   test.AddAttribute("axis", axis);
 
-  RandomValueGenerator random{};
+  RandomValueGenerator random{optional<RandomValueGenerator::RandomSeedType>{2345}};
   const auto Y_grad_data = random.Uniform<float>(n_x_m_dims, k_random_data_min, k_random_data_max);
   const auto X_data = random.Uniform<float>(n_x_m_dims, k_random_data_min, k_random_data_max);
   const auto scale_data = random.Uniform<float>(m_dims, k_random_data_min, k_random_data_max);
@@ -152,7 +152,7 @@ static void TestInvertibleLayerNormGrad(
 
   test.AddAttribute("axis", axis);
 
-  RandomValueGenerator random{};
+  RandomValueGenerator random{optional<RandomValueGenerator::RandomSeedType>{2345}};
   const auto Y_grad_data = random.Uniform<float>(n_x_m_dims, k_random_data_min, k_random_data_max);
   const auto X_data = random.Uniform<float>(n_x_m_dims, k_random_data_min, k_random_data_max);
   const auto scale_data = random.Uniform<float>(m_dims, k_random_data_min, k_random_data_max);
diff --git a/orttraining/orttraining/training_api/optimizer.cc b/orttraining/orttraining/training_api/optimizer.cc
index 84c35e6100385..4647f890729f4 100644
--- a/orttraining/orttraining/training_api/optimizer.cc
+++ b/orttraining/orttraining/training_api/optimizer.cc
@@ -61,32 +61,19 @@ Status GraphInputsAreExpected(gsl::span<const std::string> actual_graph_inputs,
 }  // namespace
 
 std::unique_ptr<OptimizerAlgorithmBase> OptimizerAlorithmFactory::CreateInstance(
-    std::shared_ptr<Model> model, int32_t& group_count) {
+    const GraphViewer& graph_viewer, int32_t& group_count) {
   std::map<std::pair<std::string, std::string>, int32_t> opt_type_to_freq_map;
-#if !defined(ORT_MINIMAL_BUILD)
-  if (model != nullptr) {
-    Graph& graph = model->MainGraph();
-    for (auto& node : graph.Nodes()) {
-      if (node.Domain() == kMSDomain && (node.OpType() == "AdamWOptimizer" || node.OpType() == "SGDOptimizerV2")) {
-        auto domain_type_pair = std::make_pair(node.Domain(), node.OpType());
-        if (opt_type_to_freq_map.find(domain_type_pair) == opt_type_to_freq_map.end()) {
-          opt_type_to_freq_map[domain_type_pair] = 0;
-        }
 
-        opt_type_to_freq_map[domain_type_pair] += 1;
+  for (const auto& node : graph_viewer.Nodes()) {
+    if (node.Domain() == kMSDomain && (node.OpType() == "AdamWOptimizer" || node.OpType() == "SGDOptimizerV2")) {
+      auto domain_type_pair = std::make_pair(node.Domain(), node.OpType());
+      if (opt_type_to_freq_map.find(domain_type_pair) == opt_type_to_freq_map.end()) {
+        opt_type_to_freq_map[domain_type_pair] = 0;
       }
+
+      opt_type_to_freq_map[domain_type_pair] += 1;
     }
-  } else {
-#else
-  ORT_UNUSED_PARAMETER(model);
-#endif
-    // TODO(baijumeswani): Figure out the best way to extract the optimizer type
-    // from the model (either onnx model or ort format model) or from the checkpoint.
-    // For now, assume that the optimizer type is AdamWOptimizer when using ort format models.
-    opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1;
-#if !defined(ORT_MINIMAL_BUILD)
   }
-#endif
 
   ORT_ENFORCE(opt_type_to_freq_map.size() == 1U, "Only support one type of optimizer algorithm, but got: " +
                                                      std::to_string(opt_type_to_freq_map.size()));
@@ -105,42 +92,6 @@ std::unique_ptr<OptimizerAlgorithmBase> OptimizerAlorithmFactory::CreateInstance
   }
 }
 
-std::unique_ptr<OptimizerAlgorithmBase> OptimizerAlorithmFactory::CreateInstance(
-    const PathString& optim_path, int32_t& group_count) {
-  std::shared_ptr<Model> model = nullptr;
-#if !defined(ORT_MINIMAL_BUILD)
-  if (!fbs::utils::IsOrtFormatModel(optim_path)) {
-    ORT_ENFORCE(Model::Load(optim_path, model, nullptr,
-                            logging::LoggingManager::DefaultLogger())
-                    .IsOK());
-  }
-#else
-  ORT_UNUSED_PARAMETER(optim_path);
-#endif
-  return CreateInstance(model, group_count);
-}
-
-std::unique_ptr<OptimizerAlgorithmBase> OptimizerAlorithmFactory::CreateInstance(
-    const uint8_t* optim_model_data, size_t optim_model_data_len, int32_t& group_count) {
-  std::shared_ptr<Model> model = nullptr;
-#if !defined(ORT_MINIMAL_BUILD)
-  if (!fbs::utils::IsOrtFormatModelBytes(optim_model_data, static_cast<int>(optim_model_data_len))) {
-    ONNX_NAMESPACE::ModelProto model_proto;
-    ORT_ENFORCE(model_proto.ParseFromArray(optim_model_data, static_cast<int>(optim_model_data_len)) == true,
-                "Failed to load model because protobuf parsing failed.");
-
-    ORT_ENFORCE(Model::Load(std::move(model_proto), model, nullptr,
-                            logging::LoggingManager::DefaultLogger(), ModelOptions(true, true))
-                    .IsOK());
-  }
-#else
-  ORT_UNUSED_PARAMETER(optim_model_data);
-  ORT_UNUSED_PARAMETER(optim_model_data_len);
-#endif
-
-  return CreateInstance(model, group_count);
-}
-
 Status Optimizer::GenerateMomentumNamedStates(OptimizerCheckpointState& optimizer_checkpoint_states) {
   auto group_optimizer_state_it =
       optimizer_checkpoint_states.group_named_optimizer_states.find(GROUP_ZERO_NAME);
@@ -280,17 +231,15 @@ void Optimizer::Initialize(const ModelIdentifiers& model_identifiers,
     auto optimizer_model = std::get<std::optional<std::string>>(model_identifiers.optim_model);
     // The above call to IsOptimizerModelAvailable() ensures that optimizer_model is not nullopt
     ORT_THROW_IF_ERROR(optim_sess_->Load(optimizer_model.value()));
-    optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(ToWideString(optimizer_model.value()), group_count_);
   } else {
     auto optimizer_model = std::get<gsl::span<const uint8_t>>(model_identifiers.optim_model);
     ORT_THROW_IF_ERROR(optim_sess_->Load(optimizer_model.data(),
                                          static_cast<int>(optimizer_model.size())));
-    optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(optimizer_model.data(),
-                                                                   optimizer_model.size(),
-                                                                   group_count_);
   }
 
   ORT_THROW_IF_ERROR(optim_sess_->Initialize());
+  optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(optim_sess_->GetSessionState().GetGraphViewer(),
+                                                                 group_count_);
 
   // Make sure that the checkpoint state can copy tensors
   state_->optimizer_checkpoint_state.optimizer_session_data_transfer_mgr = &optim_sess_->GetDataTransferManager();
diff --git a/orttraining/orttraining/training_api/optimizer.h b/orttraining/orttraining/training_api/optimizer.h
index 031b11426539b..5b908acf7c9e3 100644
--- a/orttraining/orttraining/training_api/optimizer.h
+++ b/orttraining/orttraining/training_api/optimizer.h
@@ -64,11 +64,8 @@ struct SGDOptimizerV2Algorithm : public OptimizerAlgorithmBase {
 };
 
 struct OptimizerAlorithmFactory {
-  static std::unique_ptr<OptimizerAlgorithmBase> CreateInstance(const PathString& optim_path,
+  static std::unique_ptr<OptimizerAlgorithmBase> CreateInstance(const GraphViewer& graph_viewer,
                                                                 int32_t& group_count);
-  static std::unique_ptr<OptimizerAlgorithmBase> CreateInstance(const uint8_t* optim_model_data,
-                                                                size_t optim_model_data_len, int32_t& group_count);
-  static std::unique_ptr<OptimizerAlgorithmBase> CreateInstance(std::shared_ptr<Model> model, int32_t& group_count);
 };
 
 struct CheckpointState;
diff --git a/orttraining/orttraining/training_ops/cpu/triton/triton_op.h b/orttraining/orttraining/training_ops/cpu/triton/triton_op.h
index f226db76f7ed7..db8e8558ab884 100644
--- a/orttraining/orttraining/training_ops/cpu/triton/triton_op.h
+++ b/orttraining/orttraining/training_ops/cpu/triton/triton_op.h
@@ -25,12 +25,15 @@ class TritonOp final : public OpKernel {
           attr.first == "onnx_string") {
         continue;
       }
-      // Support int64 and float only for now, skip other types.
+      // Support int64, float and string only for now, skip other types.
       if (attr.second.type() == ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_INT) {
         kwargs_.insert({attr.first, {std::to_string(attr.second.i()), ONNX_NAMESPACE::TensorProto_DataType_INT64}});
       } else if (attr.second.type() ==
                  ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_FLOAT) {
         kwargs_.insert({attr.first, {std::to_string(attr.second.f()), ONNX_NAMESPACE::TensorProto_DataType_FLOAT}});
+      } else if (attr.second.type() ==
+                 ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_STRING) {
+        kwargs_.insert({attr.first, {attr.second.s(), ONNX_NAMESPACE::TensorProto_DataType_STRING}});
       }
     }
   }
diff --git a/orttraining/orttraining/training_ops/cuda/communication/nccl_service.cc b/orttraining/orttraining/training_ops/cuda/communication/nccl_service.cc
index f604e4c4aaf3e..c642a87e22de6 100644
--- a/orttraining/orttraining/training_ops/cuda/communication/nccl_service.cc
+++ b/orttraining/orttraining/training_ops/cuda/communication/nccl_service.cc
@@ -233,6 +233,7 @@ void NcclService::Initialize() {
   //   CPUs
   //   Other devices
 
+#ifdef USE_MPI
   const int mpi_rank = onnxruntime::training::MPIContext::GetInstance().GetWorldRank();
   const int mpi_local_rank = onnxruntime::training::MPIContext::GetInstance().GetLocalRank();
   const int mpi_size = onnxruntime::training::MPIContext::GetInstance().GetWorldSize();
@@ -248,6 +249,7 @@ void NcclService::Initialize() {
   if (mpi_rank == 0) NCCL_CALL_THROW(ncclGetUniqueId(&id));
   MPI_CHECK(MPI_Bcast((void*)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD));
   NCCL_CALL_THROW(ncclCommInitRank(&comm_, mpi_size, id, mpi_rank));
+#endif  // USE_MPI
 }
 
 void NcclService::Launch() {
diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc
index dcf733153bdad..8b2bc7e2ef2b3 100644
--- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc
+++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc
@@ -196,6 +196,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
 
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, MixedPrecisionScale);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_BFloat16, LayerNormalizationGrad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalizationGrad);
 
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float, ReduceAllL2);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2);
@@ -452,6 +453,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
 
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, MixedPrecisionScale)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_BFloat16, LayerNormalizationGrad)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalizationGrad)>,
 
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float, ReduceAllL2)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2)>,
diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc
index f6c58445c0a5d..fc5d9b65d0f89 100644
--- a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc
+++ b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc
@@ -114,7 +114,8 @@ Status ConvGrad<T>::PrepareArgs(const Tensor& x, const Tensor& dY, const Tensor&
     ORT_RETURN_IF_ERROR(args_.y_tensor.Set(dy_dims, args_.params.data_type));
     ORT_RETURN_IF_ERROR(args_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
                                             gsl::narrow_cast<int>(conv_attrs_.group), CUDNN_CROSS_CORRELATION,
-                                            args_.params.data_type));
+                                            args_.params.data_type,
+                                            UseTF32()));
 
     if (dB) {
       const TensorShape& db_shape = dB->Shape();
diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc
index 5dc16c68f6210..9b30bd128b161 100644
--- a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc
+++ b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc
@@ -105,7 +105,8 @@ struct AlgoSearch<T_BwdDataPerf> {
         CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
         CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED};
     static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
-    ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward data algorithms.");
+    static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos,
+                  "Missing cuDNN convolution backward data algorithms.");
     int perf_count;
     std::unique_ptr<T_BwdDataPerf[]> candidates = std::make_unique<T_BwdDataPerf[]>(num_algos);
     if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) {
@@ -146,7 +147,9 @@ struct AlgoSearch<T_BwdFilterPerf> {
 
     // NOTE: - 1 because ALGO_WINOGRAD is not implemented.
     static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT - 1;
-    ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward filter algorithms.");
+    static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos,
+                  "Missing cuDNN convolution backward filter algorithms.");
+
     std::unique_ptr<T_BwdFilterPerf[]> candidates = std::make_unique<T_BwdFilterPerf[]>(num_algos);
     int perf_count;
     if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) {
@@ -188,7 +191,9 @@ struct AlgoSearch<T_FwdPerf> {
     };
 
     static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
-    ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward filter algorithms.");
+    static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos,
+                  "Missing cuDNN convolution backward filter algorithms.");
+
     std::unique_ptr<T_FwdPerf[]> candidates = std::make_unique<T_FwdPerf[]>(num_algos);
     int perf_count;
     if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) {
@@ -233,11 +238,13 @@ bool ConvParamsEqual::operator()(const ConvParams& a, const ConvParams& b) const
 }
 
 template <typename T_Perf>
-Status AlgoIterator<T_Perf>::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results) {
+Status AlgoIterator<T_Perf>::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results, bool use_tf32) {
   perf_results.resize(1);
   perf_results[0].algo = AlgoSearch<T_Perf>::DEFAULT_ALGO;
   if (args.params.data_type == CUDNN_DATA_HALF) {
     perf_results[0].mathType = CUDNN_TENSOR_OP_MATH;
+  } else if (args.params.data_type == CUDNN_DATA_FLOAT && !use_tf32) {
+    perf_results[0].mathType = CUDNN_FMA_MATH;
   } else {
     perf_results[0].mathType = CUDNN_DEFAULT_MATH;
   }
@@ -256,7 +263,7 @@ Status AlgoIterator<T_Perf>::TryAll(const CUDAExecutionProvider* provider, const
 
   std::vector<T_Perf> perf_results;
   ORT_RETURN_IF_ERROR(args_.params.algo_mode == OrtCudnnConvAlgoSearchDefault
-                          ? OnlyDefaultAlgorithm(args_, perf_results)
+                          ? OnlyDefaultAlgorithm(args_, perf_results, provider->UseTF32())
                           : AlgoSearch<T_Perf>::FindAlgorithms(args_, provider, allocator, perf_results));
   for (auto& algo_perf : perf_results) {
     if (f(algo_perf) == Status::OK()) {
diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h
index a2d4bf3bdc006..3fdb4306bfbbb 100644
--- a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h
+++ b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h
@@ -75,7 +75,7 @@ class AlgoIterator {
   Status TryAll(const CUDAExecutionProvider* provider, const AllocatorPtr& allocator,
                 std::function<Status(const T_Perf& perf)> f);
 
-  static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results);
+  static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results, bool use_tf32);
 
  private:
   const ConvArgs& args_;
diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc
index 5f7206fc121ec..5d12e0ac312c0 100644
--- a/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc
+++ b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc
@@ -53,7 +53,6 @@ Status ConvTransposeGrad<T>::ComputeInputGradient(onnxruntime::Stream* stream, c
             algo_perf.algo, workspace.get(), algo_perf.memory, &zero, args.y_tensor, args.y_data));
         return Status::OK();
       });
-  return Status::OK();
 }
 
 template <typename T>
@@ -71,7 +70,6 @@ Status ConvTransposeGrad<T>::ComputeWeightGradient(onnxruntime::Stream* stream,
             algo_perf.algo, workspace.get(), algo_perf.memory, &zero, args.w_desc, args.dw_data));
         return Status::OK();
       });
-  return Status::OK();
 }
 
 template <typename T>
@@ -182,7 +180,8 @@ Status ConvTransposeGrad<T>::PrepareConvForwardArgs(const Tensor& X, const Tenso
     ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type));
     ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
                                            gsl::narrow_cast<int>(conv_attrs_.group), CUDNN_CROSS_CORRELATION,
-                                           args.params.data_type));
+                                           args.params.data_type,
+                                           UseTF32()));
   }
 
   return Status::OK();
@@ -287,7 +286,8 @@ Status ConvTransposeGrad<T>::PrepareConvBackwardFilterArgs(const Tensor& X, cons
     ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type));
     ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
                                            gsl::narrow_cast<int>(conv_attrs_.group), CUDNN_CROSS_CORRELATION,
-                                           args.params.data_type));
+                                           args.params.data_type,
+                                           UseTF32()));
 
     if (dB) {
       const auto& b_shape = dB->Shape();
diff --git a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu
index 2d89ed05712e0..ad577afa06c18 100644
--- a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu
+++ b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu
@@ -30,8 +30,6 @@
 namespace onnxruntime {
 namespace cuda {
 
-using namespace onnxruntime::cuda;
-
 namespace {
   // This is the un-specialized struct.  Note that we prevent instantiation of this
   // struct by putting an undefined symbol in the function body so it won't compile.
diff --git a/orttraining/orttraining/training_ops/cuda/optimizer/lamb_impl.cu b/orttraining/orttraining/training_ops/cuda/optimizer/lamb_impl.cu
index c90809eb2fdcc..fd55f7c30ff75 100644
--- a/orttraining/orttraining/training_ops/cuda/optimizer/lamb_impl.cu
+++ b/orttraining/orttraining/training_ops/cuda/optimizer/lamb_impl.cu
@@ -619,7 +619,7 @@ CudaKernel::CudaAsyncBuffer<LambMultiTensorSyncRangeAndLock> compute_tensor_rang
 
 template <typename TIn1, typename TIn2, typename TOut1, typename TOut2, typename TBuf>
 void LambMultiTensorReductionFunctor<TIn1, TIn2, TOut1, TOut2, TBuf>::operator()(
-    cudaStream_t stream,
+    cudaStream_t /*stream*/,
     ChunkGroup<4> chunk_group,
     const CudaKernel& kernel,
     void* reduction_buffer,
diff --git a/orttraining/tools/amdgpu/script/rocprof.py b/orttraining/tools/amdgpu/script/rocprof.py
index e5b107ba285bf..21dd8501f3f1d 100644
--- a/orttraining/tools/amdgpu/script/rocprof.py
+++ b/orttraining/tools/amdgpu/script/rocprof.py
@@ -68,18 +68,10 @@ def gpu_kernel_calls(activities):
 for name in groups:
     activities = groups[name]
     print(
-        "{}: N={}, calls={}, absolute={:.3f}s, percent={:.2f}%".format(
-            name,
-            len(activities),
-            gpu_kernel_calls(activities),
-            gpu_absolute_time(activities),
-            gpu_percent_time(activities),
-        )
+        f"{name}: N={len(activities)}, calls={gpu_kernel_calls(activities)}, absolute={gpu_absolute_time(activities):.3f}s, percent={gpu_percent_time(activities):.2f}%"
     )
 
 total = [item for name in groups for item in groups[name]]
 print(
-    "Total: N={}, calls={}, absolute={:.3f}s, percent={:.2f}%".format(
-        len(total), gpu_kernel_calls(total), gpu_absolute_time(total), gpu_percent_time(total)
-    )
+    f"Total: N={len(total)}, calls={gpu_kernel_calls(total)}, absolute={gpu_absolute_time(total):.3f}s, percent={gpu_percent_time(total):.2f}%"
 )
diff --git a/orttraining/tools/ci_test/run_bert_perf_test.py b/orttraining/tools/ci_test/run_bert_perf_test.py
index bb15d6f5965b6..13d5e9f140958 100644
--- a/orttraining/tools/ci_test/run_bert_perf_test.py
+++ b/orttraining/tools/ci_test/run_bert_perf_test.py
@@ -99,8 +99,8 @@ def main():
 
         subprocess.run(cmds).check_returncode()  # noqa: PLW1510
         if c.expected_perf > 0.0:
-            json_filename = "onnxruntime_perf_metrics_{}.onnx_bert_{}_{}_Lamb.json".format(
-                model, precision_prefix, c.max_seq_length
+            json_filename = (
+                f"onnxruntime_perf_metrics_{model}.onnx_bert_{precision_prefix}_{c.max_seq_length}_Lamb.json"
             )
             with open(os.path.join(SCRIPT_DIR, "results", json_filename)) as json_file:
                 results = json.load(json_file)
diff --git a/orttraining/tools/scripts/nv_run_pretraining.py b/orttraining/tools/scripts/nv_run_pretraining.py
index f64460f3ff0b9..8c57101f72ddb 100644
--- a/orttraining/tools/scripts/nv_run_pretraining.py
+++ b/orttraining/tools/scripts/nv_run_pretraining.py
@@ -81,9 +81,11 @@ def __len__(self):
 
     def __getitem__(self, index):
         [input_ids, input_mask, segment_ids, masked_lm_positions, masked_lm_ids, next_sentence_labels] = [
-            torch.from_numpy(input[index].astype(np.int64))
-            if indice < 5
-            else torch.from_numpy(np.asarray(input[index].astype(np.int64)))
+            (
+                torch.from_numpy(input[index].astype(np.int64))
+                if indice < 5
+                else torch.from_numpy(np.asarray(input[index].astype(np.int64)))
+            )
             for indice, input in enumerate(self.inputs)
         ]
 
@@ -231,9 +233,7 @@ def setup_training(args):
         )
     if args.train_batch_size % args.gradient_accumulation_steps != 0:
         raise ValueError(
-            "Invalid gradient_accumulation_steps parameter: {}, batch size {} should be divisible".format(
-                args.gradient_accumulation_steps, args.train_batch_size
-            )
+            f"Invalid gradient_accumulation_steps parameter: {args.gradient_accumulation_steps}, batch size {args.train_batch_size} should be divisible"
         )
 
     args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
diff --git a/orttraining/tools/scripts/watch_experiment.py b/orttraining/tools/scripts/watch_experiment.py
index aefa1f57cfc16..d2255b63c66b5 100644
--- a/orttraining/tools/scripts/watch_experiment.py
+++ b/orttraining/tools/scripts/watch_experiment.py
@@ -57,11 +57,7 @@
     remote_root = args.remote_dir
 
     if run.get_status() in ["Completed", "Failed", "Canceled"]:
-        print(
-            "Downloading Experiment files from remote directory: '{}' to local directory: '{}'".format(
-                remote_root, local_root
-            )
-        )
+        print(f"Downloading Experiment files from remote directory: '{remote_root}' to local directory: '{local_root}'")
         files = [f for f in run.get_file_names() if f.startswith(remote_root)]
         for remote_path in files:
             local_path = os.path.join(local_root, os.path.basename(remote_path))
@@ -71,11 +67,7 @@
         event = Event()
         session = Session()
 
-        print(
-            "Streaming Experiment files from remote directory: '{}' to local directory: '{}'".format(
-                remote_root, local_root
-            )
-        )
+        print(f"Streaming Experiment files from remote directory: '{remote_root}' to local directory: '{local_root}'")
         watcher = RunWatcher(
             run, local_root=local_root, remote_root=remote_root, executor=executor, event=event, session=session
         )
diff --git a/pyproject.toml b/pyproject.toml
index 97515cb9fa62b..8fe114d4692c9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,19 +44,26 @@ reportMissingImports = false
 [tool.ruff]
 # NOTE: Do not create an exclude list. Edit .lintrunner.toml instead
 target-version = "py38"
+
+[tool.ruff.lint]
 select = [
     "B", # flake8-bugbear
     "E", # pycodestyle
     "F", # Pyflakes
+    "FURB", # refurb
+    "G", # flake8-logging-format
     "ISC", # flake8-implicit-str-concat
     "N", # pep8-naming
     "NPY", # numpy
     "PERF", # Perflint
+    "PIE", # flake8-pie
     "PLC", # pylint conventions
     "PLE", # pylint errors
     "PLW", # pylint warnings
+    "PYI", # flake8-pyi
     "RUF", # Ruff-specific rules
     "SIM", # flake8-simplify
+    "SLOT", # flake8-slots
     "T10", # flake8-debugger
     "UP", # pyupgrade
     "W", # pycodestyle
@@ -67,12 +74,15 @@ select = [
 ignore = [
     "B028", # FIXME: Add stacklevel to warnings
     "E501", # Line length controlled by black
+    "G004", # FIXME: Enable when the rule can be autofixed
     "N803", # Argument casing
     "N812", # Allow import torch.nn.functional as F
     "N999", # Module names
     "NPY002", # np.random.Generator may not always fit our use cases
     "PERF203", # "try-except-in-loop" only affects Python <3.11, and the improvement is minor; can have false positives
     "PERF401", # List comprehensions are not always readable
+    "PYI041", # May create confusion
+    "PYI024", # May create confusion
     "SIM102", # We don't perfer always combining if branches
     "SIM108", # We don't encourage ternary operators
     "SIM114", # Don't combine if branches for debugability
@@ -84,7 +94,7 @@ unfixable = [
     "SIM112", # Use upper case for env vars
 ]
 
-[tool.ruff.per-file-ignores]
+[tool.ruff.lint.per-file-ignores]
 # NOTE: Refrain from growing the ignore list unless for exceptional cases.
 # Prefer inline ignores with `noqa: xxx`.
 # Eventually this list should become empty.
diff --git a/requirements-lintrunner.txt b/requirements-lintrunner.txt
index 6836d5df69324..d19ebe379b50b 100644
--- a/requirements-lintrunner.txt
+++ b/requirements-lintrunner.txt
@@ -1,9 +1,9 @@
 # This file is auto updated by dependabot
 lintrunner-adapters>=0.11.0
 # RUFF
-ruff==0.2.1
+ruff==0.3.2
 # BLACK-ISORT
-black==23.10.1
+black==24.2.0
 isort==5.12.0
 # CLANGFORMAT
 clang-format==17.0.4
diff --git a/setup.py b/setup.py
index 03e1cb75ba581..ffe2958b357b8 100644
--- a/setup.py
+++ b/setup.py
@@ -205,18 +205,23 @@ def run(self):
                 rocm_dependencies = [
                     "libamd_comgr.so.2",
                     "libamdhip64.so.5",
+                    "libamdhip64.so.6",
                     "libdrm.so.2",
                     "libdrm_amdgpu.so.1",
                     "libelf.so.1",
                     "libhipfft.so.0",
                     "libhiprtc.so.5",
+                    "libhiprtc.so.6",
                     "libhsa-runtime64.so.1",
                     "libMIOpen.so.1",
                     "libnuma.so.1",
                     "librccl.so.1",
                     "librocblas.so.3",
+                    "librocblas.so.4",
                     "librocfft.so.0",
+                    "libroctx64.so.4",
                     "librocm_smi64.so.5",
+                    "librocm_smi64.so.6",
                     "libroctracer64.so.4",
                     "libtinfo.so.6",
                     "libmigraphx_c.so.3",
@@ -227,6 +232,8 @@ def run(self):
 
                 tensorrt_dependencies = ["libnvinfer.so.8", "libnvinfer_plugin.so.8", "libnvonnxparser.so.8"]
 
+                cann_dependencies = ["libascendcl.so", "libacl_op_compiler.so", "libfmk_onnx_parser.so"]
+
                 dest = "onnxruntime/capi/libonnxruntime_providers_openvino.so"
                 if path.isfile(dest):
                     subprocess.run(
@@ -250,9 +257,9 @@ def run(self):
                 file = glob(path.join(self.dist_dir, "*linux*.whl"))[0]
                 logger.info("repairing %s for manylinux1", file)
                 auditwheel_cmd = ["auditwheel", "-v", "repair", "-w", self.dist_dir, file]
-                for i in cuda_dependencies + rocm_dependencies + tensorrt_dependencies:
+                for i in cuda_dependencies + rocm_dependencies + tensorrt_dependencies + cann_dependencies:
                     auditwheel_cmd += ["--exclude", i]
-                logger.info("Running {}".format(" ".join([shlex.quote(arg) for arg in auditwheel_cmd])))
+                logger.info("Running %s", " ".join([shlex.quote(arg) for arg in auditwheel_cmd]))
                 try:
                     subprocess.run(auditwheel_cmd, check=True, stdout=subprocess.PIPE)
                 finally:
@@ -609,9 +616,7 @@ def reformat_run_count(count_str):
             # TODO: this is the last time we have to do this!!!
             # We shall bump up release number right after release cut.
             if ort_version.major == 1 and ort_version.minor == 8 and ort_version.micro == 0:
-                version_number = "{major}.{minor}.{macro}".format(
-                    major=ort_version.major, minor=ort_version.minor + 1, macro=ort_version.micro
-                )
+                version_number = f"{ort_version.major}.{ort_version.minor + 1}.{ort_version.micro}"
 
     version_number = version_number + ".dev" + build_suffix
 
@@ -662,9 +667,11 @@ def save_build_and_package_info(package_name, version_number, cuda_version, rocm
                 else:
                     print(
                         "Error getting cudart version. ",
-                        "did not find any cudart library"
-                        if not cudart_versions or len(cudart_versions) == 0
-                        else "found multiple cudart libraries",
+                        (
+                            "did not find any cudart library"
+                            if not cudart_versions or len(cudart_versions) == 0
+                            else "found multiple cudart libraries"
+                        ),
                     )
             elif rocm_version:
                 f.write(f"rocm_version = '{rocm_version}'\n")
diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py
index e286236ba6447..f1d3702e3245e 100644
--- a/tools/ci_build/amd_hipify.py
+++ b/tools/ci_build/amd_hipify.py
@@ -181,6 +181,8 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path):
     s = s.replace("rocm_device_prop_", "cuda_device_prop_")
     s = s.replace("rocm_device_arch_", "cuda_device_arch_")
 
+    s = s.replace("HipTuningContext", "RocmTuningContext")
+
     # We want hipfft, which needs hipDataType etc, but only do this for files that have "fft" in their names
     # And we do this last, undoing or fixing hipify mistakes.
     if "fft" in src_file_path:
diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py
index ab60819dee2e1..c64096fb32bda 100644
--- a/tools/ci_build/build.py
+++ b/tools/ci_build/build.py
@@ -38,8 +38,6 @@ def version_to_tuple(version: str) -> tuple:
 class BaseError(Exception):
     """Base class for errors originating from build.py."""
 
-    pass
-
 
 class BuildError(BaseError):
     """Error from running build steps."""
@@ -75,13 +73,14 @@ def _str_to_bool(s):
 
 
 def _openvino_verify_device_type(device_read):
-    choices = ["CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU_FP16"]
+    choices = ["CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU_FP16", "NPU"]
 
     choices1 = [
         "CPU_FP32_NO_PARTITION",
         "CPU_FP16_NO_PARTITION",
         "GPU_FP32_NO_PARTITION",
         "GPU_FP16_NO_PARTITION",
+        "NPU_NO_PARTITION",
     ]
     status_hetero = True
     res = False
@@ -89,14 +88,14 @@ def _openvino_verify_device_type(device_read):
         res = True
     elif device_read in choices1:
         res = True
-    elif device_read.startswith("HETERO:") or device_read.startswith("MULTI:") or device_read.startswith("AUTO:"):
+    elif device_read.startswith(("HETERO:", "MULTI:", "AUTO:")):
         res = True
         comma_separated_devices = device_read.split(":")
         comma_separated_devices = comma_separated_devices[1].split(",")
         if len(comma_separated_devices) < 2:
             print("At least two devices required in Hetero/Multi/Auto Mode")
             status_hetero = False
-        dev_options = ["CPU", "GPU"]
+        dev_options = ["CPU", "GPU", "NPU"]
         for dev in comma_separated_devices:
             if dev not in dev_options:
                 status_hetero = False
@@ -107,7 +106,7 @@ def invalid_hetero_build():
         print("specify the keyword HETERO or MULTI or AUTO followed by the devices ")
         print("in the order of priority you want to build\n")
         print("The different hardware devices that can be added in HETERO or MULTI or AUTO")
-        print("are ['CPU','GPU'] \n")
+        print("are ['CPU','GPU','NPU'] \n")
         print("An example of how to specify the hetero build type. Ex: HETERO:GPU,CPU \n")
         print("An example of how to specify the MULTI build type. Ex: MULTI:GPU,CPU \n")
         print("An example of how to specify the AUTO build type. Ex: AUTO:GPU,CPU \n")
@@ -118,7 +117,7 @@ def invalid_hetero_build():
         print("pick the build type for specific Hardware Device from following options: ", choices)
         print("(or) from the following options with graph partitioning disabled: ", choices1)
         print("\n")
-        if not (device_read.startswith("HETERO") or device_read.startswith("MULTI") or device_read.startswith("AUTO")):
+        if not device_read.startswith(("HETERO", "MULTI", "AUTO")):
             invalid_hetero_build()
         sys.exit("Wrong Build Type selected")
 
@@ -402,6 +401,12 @@ def convert_arg_line_to_args(self, arg_line):
 
     parser.add_argument("--ios", action="store_true", help="build for ios")
 
+    parser.add_argument(
+        "--macos",
+        choices=["MacOSX", "Catalyst"],
+        help="Specify the target platform for macOS build. Only specify this argument when --build_apple_framework is present.",
+    )
+
     parser.add_argument(
         "--apple_sysroot", default="", help="Specify the location name of the macOS platform SDK to be used"
     )
@@ -421,7 +426,7 @@ def convert_arg_line_to_args(self, arg_line):
         action="store_const",
         const="Xcode",
         dest="cmake_generator",
-        help="Use Xcode as cmake generator, this is only supported on MacOS. Equivalent to '--cmake_generator Xcode'.",
+        help="Use Xcode as cmake generator, this is only supported on MacOS. (non Catalyst build). Equivalent to '--cmake_generator Xcode'.",
     )
     parser.add_argument(
         "--osx_arch",
@@ -1222,6 +1227,7 @@ def generate_build_tree(
             "-Donnxruntime_USE_OPENVINO_GPU_FP16=" + ("ON" if args.use_openvino == "GPU_FP16" else "OFF"),
             "-Donnxruntime_USE_OPENVINO_CPU_FP32=" + ("ON" if args.use_openvino == "CPU_FP32" else "OFF"),
             "-Donnxruntime_USE_OPENVINO_CPU_FP16=" + ("ON" if args.use_openvino == "CPU_FP16" else "OFF"),
+            "-Donnxruntime_USE_OPENVINO_NPU=" + ("ON" if args.use_openvino == "NPU" else "OFF"),
             "-Donnxruntime_USE_OPENVINO_GPU_FP32_NP="
             + ("ON" if args.use_openvino == "GPU_FP32_NO_PARTITION" else "OFF"),
             "-Donnxruntime_USE_OPENVINO_GPU_FP16_NP="
@@ -1230,15 +1236,22 @@ def generate_build_tree(
             + ("ON" if args.use_openvino == "CPU_FP32_NO_PARTITION" else "OFF"),
             "-Donnxruntime_USE_OPENVINO_CPU_FP16_NP="
             + ("ON" if args.use_openvino == "CPU_FP16_NO_PARTITION" else "OFF"),
+            "-Donnxruntime_USE_OPENVINO_NPU_NP=" + ("ON" if args.use_openvino == "NPU_NO_PARTITION" else "OFF"),
             "-Donnxruntime_USE_OPENVINO_HETERO=" + ("ON" if args.use_openvino.startswith("HETERO") else "OFF"),
             "-Donnxruntime_USE_OPENVINO_DEVICE=" + (args.use_openvino),
             "-Donnxruntime_USE_OPENVINO_MULTI=" + ("ON" if args.use_openvino.startswith("MULTI") else "OFF"),
             "-Donnxruntime_USE_OPENVINO_AUTO=" + ("ON" if args.use_openvino.startswith("AUTO") else "OFF"),
         ]
 
-    # TensorRT and OpenVINO providers currently only support
-    # full_protobuf option.
-    if args.use_full_protobuf or args.use_tensorrt or args.use_openvino or args.use_vitisai or args.gen_doc:
+    # The VitisAI and OpenVINO providers currently only support the
+    # full_protobuf option. The TensorRT provider only requires it when built with the OSS parser.
+    if (
+        args.use_full_protobuf
+        or (args.use_tensorrt and args.use_tensorrt_oss_parser)
+        or args.use_openvino
+        or args.use_vitisai
+        or args.gen_doc
+    ):
         cmake_args += ["-Donnxruntime_USE_FULL_PROTOBUF=ON", "-DProtobuf_USE_STATIC_LIBS=ON"]
 
     if args.use_tvm and args.llvm_path is not None:
@@ -1319,8 +1332,12 @@ def generate_build_tree(
     if args.use_snpe:
         cmake_args += ["-Donnxruntime_USE_SNPE=ON"]
 
-    if args.build_apple_framework or args.ios:
-        if not args.cmake_generator == "Xcode":
+    if args.macos or args.ios:
+        # Note: the Xcode CMake generator doesn't have good support for Mac Catalyst yet.
+        if args.macos == "Catalyst" and args.cmake_generator == "Xcode":
+            raise BuildError("The Xcode CMake generator ('--cmake_generator Xcode') doesn't support Mac Catalyst builds.")
+
+        if (args.ios or args.macos == "MacOSX") and not args.cmake_generator == "Xcode":
             raise BuildError(
                 "iOS/MacOS framework build requires use of the Xcode CMake generator ('--cmake_generator Xcode')."
             )
@@ -1338,12 +1355,15 @@ def generate_build_tree(
                 "iOS/MacOS framework build on MacOS canceled due to missing arguments: "
                 + ", ".join(val for val, cond in zip(arg_names, needed_args) if not cond)
             )
+        # Note: this value is mainly used in the framework_info.json file to specify the osx platform type of the build
+        platform_name = "macabi" if args.macos == "Catalyst" else args.apple_sysroot
         cmake_args += [
             "-Donnxruntime_BUILD_SHARED_LIB=ON",
             "-DCMAKE_OSX_SYSROOT=" + args.apple_sysroot,
             "-DCMAKE_OSX_DEPLOYMENT_TARGET=" + args.apple_deploy_target,
             # we do not need protoc binary for ios cross build
             "-Dprotobuf_BUILD_PROTOC_BINARIES=OFF",
+            "-DPLATFORM_NAME=" + platform_name,
         ]
         if args.ios:
             cmake_args += [
@@ -1351,6 +1371,21 @@ def generate_build_tree(
                 "-DCMAKE_TOOLCHAIN_FILE="
                 + (args.ios_toolchain_file if args.ios_toolchain_file else "../cmake/onnxruntime_ios.toolchain.cmake"),
             ]
+        # For a Catalyst build, we need to manually specify compiler flags for the target, e.g. x86_64-apple-ios14.0-macabi.
+        # https://forums.developer.apple.com/forums/thread/122571
+        if args.macos == "Catalyst":
+            macabi_target = f"{args.osx_arch}-apple-ios{args.apple_deploy_target}-macabi"
+            cmake_args += [
+                "-DCMAKE_CXX_COMPILER_TARGET=" + macabi_target,
+                "-DCMAKE_C_COMPILER_TARGET=" + macabi_target,
+                "-DCMAKE_CC_COMPILER_TARGET=" + macabi_target,
+                f"-DCMAKE_CXX_FLAGS=--target={macabi_target}",
+                f"-DCMAKE_CXX_FLAGS_RELEASE=-O3 -DNDEBUG --target={macabi_target}",
+                f"-DCMAKE_C_FLAGS=--target={macabi_target}",
+                f"-DCMAKE_C_FLAGS_RELEASE=-O3 -DNDEBUG --target={macabi_target}",
+                f"-DCMAKE_CC_FLAGS=--target={macabi_target}",
+                f"-DCMAKE_CC_FLAGS_RELEASE=-O3 -DNDEBUG --target={macabi_target}",
+            ]
 
     if args.build_wasm:
         emsdk_dir = os.path.join(cmake_dir, "external", "emsdk")
@@ -1445,6 +1480,13 @@ def generate_build_tree(
             # tools need to use the symbols.
             add_default_definition(cmake_extra_defines, "CMAKE_MSVC_DEBUG_INFORMATION_FORMAT", "ProgramDatabase")
 
+        if number_of_parallel_jobs(args) > 0:
+            # https://devblogs.microsoft.com/cppblog/improved-parallelism-in-msbuild/
+            # NOTE: this disables /MP if set (according to comments on the blog post).
+            # By default, MultiProcMaxCount and CL_MPCount are equal to the number of logical CPU processors.
+            # See logic around setting CL_MPCount below
+            cmake_args += ["-DCMAKE_VS_GLOBALS=UseMultiToolTask=true;EnforceProcessCountAcrossBuilds=true"]
+
     cmake_args += [f"-D{define}" for define in cmake_extra_defines]
 
     cmake_args += cmake_extra_args
@@ -1519,7 +1561,8 @@ def generate_build_tree(
                 ldflags = ["/profile", "/DYNAMICBASE"]
                 # Address Sanitizer libs do not have a Qspectre version. So they two cannot be both enabled.
                 if not args.enable_address_sanitizer:
-                    cflags += ["/Qspectre"]
+                    # Also enable a special perf patch that was made for Intel Meteor Lake mobile CPUs
+                    cflags += ["/Qspectre", "/DONNXRUNTIME_ENABLE_INTEL_METEOR_LAKE_MOBILE_PLATFORM_PERF_PATCH"]
                 if config == "Release":
                     cflags += ["/O2", "/Ob2", "/DNDEBUG"]
                 elif config == "RelWithDebInfo":
@@ -1624,9 +1667,11 @@ def generate_build_tree(
             [
                 *temp_cmake_args,
                 f"-DCMAKE_BUILD_TYPE={config}",
-                f"-DCMAKE_PREFIX_PATH={build_dir}/{config}/installed"
-                if preinstalled_dir.exists() and not (args.arm64 or args.arm64ec or args.arm)
-                else "",
+                (
+                    f"-DCMAKE_PREFIX_PATH={build_dir}/{config}/installed"
+                    if preinstalled_dir.exists() and not (args.arm64 or args.arm64ec or args.arm)
+                    else ""
+                ),
             ],
             cwd=config_build_dir,
             cuda_home=cuda_home,
@@ -1653,15 +1698,24 @@ def build_targets(args, cmake_path, build_dir, configs, num_parallel_jobs, targe
         build_tool_args = []
         if num_parallel_jobs != 1:
             if is_windows() and args.cmake_generator != "Ninja" and not args.build_wasm:
+                # https://github.com/Microsoft/checkedc-clang/wiki/Parallel-builds-of-clang-on-Windows suggests
+                # not maxing out CL_MPCount.
+                # Start with one less than num_parallel_jobs (which defaults to the number of logical cores),
+                # limited to a range of 1..3.
+                # That allows up to maxcpucount projects to build concurrently, each using up to 3 cl.exe instances.
                 build_tool_args += [
                     f"/maxcpucount:{num_parallel_jobs}",
+                    # one less than num_parallel_jobs, at least 1, up to 3
+                    f"/p:CL_MPCount={min(max(num_parallel_jobs - 1, 1), 3)}",
                     # if nodeReuse is true, msbuild processes will stay around for a bit after the build completes
                     "/nodeReuse:False",
-                    f"/p:CL_MPCount={num_parallel_jobs}",
                 ]
             elif args.cmake_generator == "Xcode":
-                # CMake will generate correct build tool args for Xcode
-                cmd_args += ["--parallel", str(num_parallel_jobs)]
+                build_tool_args += [
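+                    # xcodebuild flags: -parallelizeTargets builds independent targets in parallel; -jobs caps concurrent build operations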
+                    "-parallelizeTargets",
+                    "-jobs",
+                    str(num_parallel_jobs),
+                ]
             else:
                 build_tool_args += [f"-j{num_parallel_jobs}"]
 
@@ -1696,9 +1750,7 @@ def setup_cuda_vars(args):
         if not cuda_home_valid or (not is_windows() and not cudnn_home_valid):
             raise BuildError(
                 "cuda_home and cudnn_home paths must be specified and valid.",
-                "cuda_home='{}' valid={}. cudnn_home='{}' valid={}".format(
-                    cuda_home, cuda_home_valid, cudnn_home, cudnn_home_valid
-                ),
+                f"cuda_home='{cuda_home}' valid={cuda_home_valid}. cudnn_home='{cudnn_home}' valid={cudnn_home_valid}",
             )
 
     return cuda_home, cudnn_home
@@ -2464,11 +2516,11 @@ def diff_file(path, regenerate_qualifiers=""):
                     nonlocal have_diff
                     have_diff = True
                     log.warning(
-                        "The updated document {} is different from the checked in version. "
-                        "Please regenerate the file{}, or copy the updated version from the "
-                        "CI build's published artifacts if applicable.".format(path, regenerate_qualifiers)
+                        f"The updated document {path} is different from the checked in version. "
+                        f"Please regenerate the file{regenerate_qualifiers}, or copy the updated version from the "
+                        "CI build's published artifacts if applicable."
                     )
-                    log.debug("diff:\n" + diff)
+                    log.debug("diff:\n" + diff)  # noqa: G003
 
             diff_file(opkernel_doc_path, " with CPU, CUDA and DML execution providers enabled")
             diff_file(contrib_op_doc_path)
@@ -2483,7 +2535,7 @@ def diff_file(path, regenerate_qualifiers=""):
 
 
 def main():
-    log.debug("Command line arguments:\n  {}".format(" ".join(shlex.quote(arg) for arg in sys.argv[1:])))
+    log.debug("Command line arguments:\n  {}".format(" ".join(shlex.quote(arg) for arg in sys.argv[1:])))  # noqa: G001
 
     args = parse_arguments()
 
@@ -2580,7 +2632,7 @@ def main():
         raise BuildError("Using --get-api-doc requires a single build config")
 
     # Disabling unit tests for GPU on nuget creation
-    if args.use_openvino != "CPU_FP32" and args.build_nuget:
+    if args.use_openvino and args.use_openvino != "CPU_FP32" and args.build_nuget:
         args.test = False
 
     # GDK builds don't support testing
@@ -2719,7 +2771,13 @@ def main():
             cmake_extra_args += ["-G", args.cmake_generator]
 
         if is_macOS():
-            if not args.ios and not args.android and args.osx_arch == "arm64" and platform.machine() == "x86_64":
+            if (
+                not args.ios
+                and args.macos != "Catalyst"
+                and not args.android
+                and args.osx_arch == "arm64"
+                and platform.machine() == "x86_64"
+            ):
                 if args.test:
                     log.warning("Cannot test ARM64 build on X86_64. Will skip test running after build.")
                     args.test = False
diff --git a/tools/ci_build/clean_docker_image_cache.py b/tools/ci_build/clean_docker_image_cache.py
index f9b41ce31f92a..8ec2b6b438176 100755
--- a/tools/ci_build/clean_docker_image_cache.py
+++ b/tools/ci_build/clean_docker_image_cache.py
@@ -237,13 +237,13 @@ def main():
     def sorted_image_names(image_infos):
         return sorted([get_image_name(image_info) for image_info in image_infos])
 
-    log.debug("All images:\n{}".format("\n".join(sorted_image_names(all_images))))
-    log.debug("Valid images:\n{}".format("\n".join(sorted_image_names(valid_images))))
+    log.debug("All images:\n{}".format("\n".join(sorted_image_names(all_images))))  # noqa: G001
+    log.debug("Valid images:\n{}".format("\n".join(sorted_image_names(valid_images))))  # noqa: G001
 
     images_to_clean = all_images - valid_images
     image_names_to_clean = sorted_image_names(images_to_clean)
 
-    log.info("Images to clean:\n{}".format("\n".join(image_names_to_clean)))
+    log.info("Images to clean:\n{}".format("\n".join(image_names_to_clean)))  # noqa: G001
 
     if args.dry_run:
         log.info("Dry run, no images will be cleaned.")
diff --git a/tools/ci_build/get_docker_image.py b/tools/ci_build/get_docker_image.py
index 2ce1764c96327..99ecaf677f339 100755
--- a/tools/ci_build/get_docker_image.py
+++ b/tools/ci_build/get_docker_image.py
@@ -56,11 +56,7 @@ def parse_args():
 def main():
     args = parse_args()
 
-    log.debug(
-        "Dockerfile: {}, context: {}, docker build args: '{}'".format(
-            args.dockerfile, args.context, args.docker_build_args
-        )
-    )
+    log.debug(f"Dockerfile: {args.dockerfile}, context: {args.context}, docker build args: '{args.docker_build_args}'")
 
     use_container_registry = args.container_registry is not None
 
diff --git a/tools/ci_build/github/android/build_aar_package.py b/tools/ci_build/github/android/build_aar_package.py
index f9688a1453e12..3aaced63dd410 100644
--- a/tools/ci_build/github/android/build_aar_package.py
+++ b/tools/ci_build/github/android/build_aar_package.py
@@ -149,9 +149,11 @@ def _build_aar(args):
         "-DminSdkVer=" + str(build_settings["android_min_sdk_version"]),
         "-DtargetSdkVer=" + str(build_settings["android_target_sdk_version"]),
         "-DbuildVariant=" + str(build_settings["build_variant"]),
-        "-DENABLE_TRAINING_APIS=1"
-        if "--enable_training_apis" in build_settings["build_params"]
-        else "-DENABLE_TRAINING_APIS=0",
+        (
+            "-DENABLE_TRAINING_APIS=1"
+            if "--enable_training_apis" in build_settings["build_params"]
+            else "-DENABLE_TRAINING_APIS=0"
+        ),
     ]
 
     # clean, build, and publish to a local directory
diff --git a/tools/ci_build/github/apple/build_and_assemble_apple_pods.py b/tools/ci_build/github/apple/build_and_assemble_apple_pods.py
index 006dc4c33ffce..6188c7d7c0678 100755
--- a/tools/ci_build/github/apple/build_and_assemble_apple_pods.py
+++ b/tools/ci_build/github/apple/build_and_assemble_apple_pods.py
@@ -86,9 +86,7 @@ def run(arg_list, cwd=None):
     import shlex
     import subprocess
 
-    log.info(
-        "Running subprocess in '{}'\n  {}".format(cwd or os.getcwd(), " ".join([shlex.quote(arg) for arg in arg_list]))
-    )
+    log.info("Running subprocess in '%s'\n  %s", cwd or os.getcwd(), " ".join([shlex.quote(arg) for arg in arg_list]))
 
     return subprocess.run(arg_list, check=True, cwd=cwd)
 
diff --git a/tools/ci_build/github/apple/build_apple_framework.py b/tools/ci_build/github/apple/build_apple_framework.py
index 5137a0644b2e7..e17bcd65d8814 100644
--- a/tools/ci_build/github/apple/build_apple_framework.py
+++ b/tools/ci_build/github/apple/build_apple_framework.py
@@ -50,9 +50,11 @@ def _build_for_apple_sysroot(
     # Build binary for each arch, one by one
     for current_arch in archs:
         build_dir_current_arch = os.path.join(intermediates_dir, sysroot + "_" + current_arch)
+        # Use the macOS SDK for Catalyst builds
+        apple_sysroot = "macosx" if sysroot == "macabi" else sysroot
         build_command = [
             *base_build_command,
-            "--apple_sysroot=" + sysroot,
+            "--apple_sysroot=" + apple_sysroot,
             "--osx_arch=" + current_arch,
             "--build_dir=" + build_dir_current_arch,
         ]
@@ -65,9 +67,11 @@ def _build_for_apple_sysroot(
             build_dir_current_arch,
             build_config,
             build_config + "-" + sysroot,
-            "onnxruntime.framework"
-            if build_dynamic_framework
-            else os.path.join("static_framework", "onnxruntime.framework"),
+            (
+                "onnxruntime.framework"
+                if build_dynamic_framework
+                else os.path.join("static_framework", "onnxruntime.framework")
+            ),
         )
         ort_libs.append(os.path.join(framework_dir, "onnxruntime"))
 
diff --git a/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json b/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json
index 86b4efdc63750..04a73ae450e5f 100644
--- a/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json
+++ b/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json
@@ -23,6 +23,7 @@
             "--cmake_extra_defines=onnxruntime_BUILD_UNIT_TESTS=OFF"
         ],
         "macosx": [
+            "--macos=MacOSX",
             "--apple_deploy_target=11.0"
         ],
         "iphoneos": [
diff --git a/tools/ci_build/github/apple/default_full_ios_framework_build_settings.json b/tools/ci_build/github/apple/default_full_ios_framework_build_settings.json
index 445bfca9889ff..4bc978956d7fc 100644
--- a/tools/ci_build/github/apple/default_full_ios_framework_build_settings.json
+++ b/tools/ci_build/github/apple/default_full_ios_framework_build_settings.json
@@ -6,25 +6,35 @@
         "iphonesimulator": [
             "arm64",
             "x86_64"
+        ],
+        "macabi": [
+            "arm64",
+            "x86_64"
         ]
     },
     "build_params": {
         "base": [
             "--parallel",
-            "--use_xcode",
             "--build_apple_framework",
             "--use_coreml",
-            "--use_xnnpack",
             "--skip_tests",
             "--cmake_extra_defines=onnxruntime_BUILD_UNIT_TESTS=OFF"
         ],
         "iphoneos": [
             "--ios",
+            "--use_xcode",
+            "--use_xnnpack",
             "--apple_deploy_target=12.0"
         ],
         "iphonesimulator": [
             "--ios",
+            "--use_xcode",
+            "--use_xnnpack",
             "--apple_deploy_target=12.0"
+        ],
+        "macabi":[
+            "--macos=Catalyst",
+            "--apple_deploy_target=14.0"
         ]
     }
 }
diff --git a/tools/ci_build/github/apple/default_training_ios_framework_build_settings.json b/tools/ci_build/github/apple/default_training_ios_framework_build_settings.json
index f88934cd44a66..2066af7843e0a 100644
--- a/tools/ci_build/github/apple/default_training_ios_framework_build_settings.json
+++ b/tools/ci_build/github/apple/default_training_ios_framework_build_settings.json
@@ -32,6 +32,7 @@
             "--apple_deploy_target=12.0"
         ],
         "macosx": [
+            "--macos=MacOSX",
             "--apple_deploy_target=11.0"
         ]
     }
diff --git a/tools/ci_build/github/apple/framework_info.json.template b/tools/ci_build/github/apple/framework_info.json.template
index b4c4fb8d16ebf..1f7eeb5948799 100644
--- a/tools/ci_build/github/apple/framework_info.json.template
+++ b/tools/ci_build/github/apple/framework_info.json.template
@@ -1,5 +1,5 @@
 {
-    "@CMAKE_OSX_SYSROOT@": {
+    "@PLATFORM_NAME@": {
         "APPLE_DEPLOYMENT_TARGET": "@CMAKE_OSX_DEPLOYMENT_TARGET@",
         "WEAK_FRAMEWORK": "@APPLE_WEAK_FRAMEWORK@"
     }
diff --git a/tools/ci_build/github/apple/test_apple_packages.py b/tools/ci_build/github/apple/test_apple_packages.py
index cd360a63a3a0f..3987a37fcc76c 100644
--- a/tools/ci_build/github/apple/test_apple_packages.py
+++ b/tools/ci_build/github/apple/test_apple_packages.py
@@ -130,22 +130,70 @@ def _test_apple_packages(args):
 
             simulator_device_info = json.loads(simulator_device_info)
 
-            subprocess.run(
-                [
-                    "xcrun",
-                    "xcodebuild",
-                    "test",
-                    "-workspace",
-                    "./apple_package_test.xcworkspace",
-                    "-scheme",
-                    "ios_package_test",
-                    "-destination",
-                    f"platform=iOS Simulator,id={simulator_device_info['device_udid']}",
-                ],
-                shell=False,
-                check=True,
-                cwd=target_proj_path,
-            )
+            # Xcode UI tests seem to be flaky: https://github.com/orgs/community/discussions/68807
+            # Add a couple of retries if we get this error:
+            #   ios_package_testUITests-Runner Failed to initialize for UI testing:
+            #   Error Domain=com.apple.dt.XCTest.XCTFuture Code=1000 "Timed out while loading Accessibility."
+            attempts = 0
+            cmd = [
+                "xcrun",
+                "xcodebuild",
+                "test",
+                "-workspace",
+                "./apple_package_test.xcworkspace",
+                "-scheme",
+                "ios_package_test",
+                "-destination",
+                f"platform=iOS Simulator,id={simulator_device_info['device_udid']}",
+            ]
+
+            while True:
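+                # Retry up to 3 attempts in total; only the known flaky "Timed out while loading Accessibility" failure is retried.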
+                attempts += 1
+                completed_process = subprocess.run(
+                    cmd,
+                    shell=False,
+                    capture_output=True,
+                    check=False,
+                    text=True,
+                    cwd=target_proj_path,
+                )
+
+                # print so it's in CI output
+                print(completed_process.stdout)
+
+                if completed_process.returncode != 0:
+                    print(f"Running ios_package_test failed. Return code was {completed_process.returncode}")
+                    print("xcrun xcodebuild test stderr:")
+                    print(completed_process.stderr)
+                    print("---")
+
+                    if "Timed out while loading Accessibility" in completed_process.stderr and attempts < 3:
+                        continue
+
+                    raise subprocess.CalledProcessError(
+                        completed_process.returncode, " ".join(cmd), completed_process.stdout, completed_process.stderr
+                    )
+
+                break
+
+            if args.mac_catalyst_enabled:
+                subprocess.run(
+                    [
+                        "xcrun",
+                        "xcodebuild",
+                        "test",
+                        "-workspace",
+                        "./apple_package_test.xcworkspace",
+                        "-scheme",
+                        "ios_package_test",
+                        "-destination",
+                        "platform=macOS,variant=Mac Catalyst",
+                        "CODE_SIGNING_ALLOWED=NO",
+                    ],
+                    shell=False,
+                    check=True,
+                    cwd=target_proj_path,
+                )
 
             if PackageVariant[args.variant] != PackageVariant.Mobile and not args.skip_macos_test:
                 subprocess.run(
@@ -215,6 +263,12 @@ def parse_args():
         help="Skip macos platform tests. Specify this argument when build targets only contain ios archs. ",
     )
 
+    parser.add_argument(
+        "--mac_catalyst_enabled",
+        action="store_true",
+        help="Run tests for mac catalyst variants. Specify this argument when build targets contains catalyst archs. ",
+    )
+
     return parser.parse_args()
 
 
diff --git a/tools/ci_build/github/apple/test_minimal_training_ios_simulator_framework_build_settings.json b/tools/ci_build/github/apple/test_minimal_training_ios_simulator_framework_build_settings.json
new file mode 100644
index 0000000000000..1a89d941e5e52
--- /dev/null
+++ b/tools/ci_build/github/apple/test_minimal_training_ios_simulator_framework_build_settings.json
@@ -0,0 +1,22 @@
+{
+  "build_osx_archs": {
+    "iphonesimulator": [
+      "x86_64"
+    ]
+  },
+  "build_params": {
+    "base": [
+      "--parallel",
+      "--use_xcode",
+      "--build_apple_framework",
+      "--minimal_build=extended",
+      "--enable_training_apis",
+      "--skip_tests",
+      "--cmake_extra_defines=onnxruntime_BUILD_UNIT_TESTS=OFF"
+    ],
+    "iphonesimulator": [
+      "--ios",
+      "--apple_deploy_target=12.0"
+    ]
+  }
+}
diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml
index 2b181810b0788..d37266a8e96d8 100644
--- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml
@@ -31,7 +31,7 @@ parameters:
 - name: QnnSdk
   displayName: QNN SDK version
   type: string
-  default: qnn-v2.18.0.240101
+  default: qnn-v2.19.2.240210
 
 jobs:
 - job: Build_QNN_EP
diff --git a/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml
index 9136b21aec626..d0a22aae07741 100644
--- a/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml
@@ -53,7 +53,7 @@ stages:
     Codeql.Enabled: false
   jobs:
   - job: Build_CPU_EP
-    pool: onnxruntime-Linux-CPU-For-Android-CI
+    pool: onnxruntime-Ubuntu2204-AMD-CPU
     workspace:
       clean: all
     timeoutInMinutes: 30
@@ -140,7 +140,7 @@ stages:
 
   jobs:
   - job: Build_NNAPI_EP
-    pool: onnxruntime-Linux-CPU-For-Android-CI
+    pool: onnxruntime-Ubuntu2204-AMD-CPU
     timeoutInMinutes: ${{ variables.JobsTimeout }}
     workspace:
       clean: all
@@ -456,7 +456,7 @@ stages:
     variables:
     - name: skipComponentGovernanceDetection
       value: true
-    pool: 'onnxruntime-Linux-CPU-For-Android-CI'
+    pool: 'onnxruntime-Ubuntu2204-AMD-CPU'
     condition: and(succeeded(), in(variables['Build.Reason'], 'IndividualCI', 'BatchedCI'))
     dependsOn:
     - NNAPI_EP_MASTER
diff --git a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml
index 65866fc9827a5..b9a47f6739fe8 100644
--- a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml
@@ -25,7 +25,7 @@ pr:
     - BUILD.md
     - 'js/web'
     - 'onnxruntime/core/providers/js'
-#### end trigger ####parameters:
+#### end trigger ####
 
 # reference: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md
 parameters:
@@ -214,7 +214,7 @@ stages:
             python3 -m pip install /Release/*.whl; \
             pushd /workspace/onnxruntime/python/tools/transformers/models/stable_diffusion; \
             python3 -m pip install -r requirements-cuda11.txt; \
-            python3 -m pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com; \
+            python3 -m pip install --upgrade polygraphy onnx-graphsurgeon ; \
             echo Generate an image guided by a text prompt; \
             python3 demo_txt2img.py --framework-model-dir /model_cache --seed 1 --deterministic "astronaut riding a horse on mars" ; \
             find $(pwd)/ORT_CUDA -name "*.png" -exec cp {} /images/ \; ; \
@@ -314,12 +314,81 @@ stages:
               pushd /workspace/onnxruntime/python/tools/transformers/ ; \
               python3 -m pip install --upgrade pip ; \
               pushd models/llama ; \
-              python3 -m pip install -r requirements-cuda.txt ; \
+              python3 -m pip install -r requirements.txt ; \
               popd ; \
               python3 -m pip install /ort-artifact/*.whl ; \
+              python3 -m pip uninstall -y torch ; \
               python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118 ; \
               python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda --input /meta-llama2 --small_gpu ;\
               popd ; \
             "
       displayName: 'Run Llama2 to Onnx F16 and parity Test'
       workingDirectory: $(Build.SourcesDirectory)
+
+- stage: Whisper_ONNX
+  dependsOn:
+  - Build_Onnxruntime_Cuda
+  jobs:
+  - job: Whisper_ONNX
+    variables:
+      skipComponentGovernanceDetection: true
+    workspace:
+      clean: all
+    pool: Onnxruntime-Linux-A10-24G
+    steps:
+    - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3
+      displayName: 'Clean Agent Directories'
+      condition: always()
+
+    - checkout: self
+      clean: true
+      submodules: none
+
+    - template: templates/flex-downloadPipelineArtifact.yml
+      parameters:
+        StepName: 'Download Onnxruntime Artifact'
+        ArtifactName: 'drop-ort-linux-gpu'
+        TargetPath: '$(Build.BinariesDirectory)/ort-artifact/'
+        SpecificArtifact: ${{ parameters.specificArtifact }}
+        BuildId: ${{ parameters.BuildId }}
+
+    - template: templates/get-docker-image-steps.yml
+      parameters:
+        Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu
+        Context: tools/ci_build/github/linux/docker/
+        ScriptName: tools/ci_build/get_docker_image.py
+        DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )"
+        Repository: onnxruntimepackagestest
+        UpdateDepsTxt: false
+
+    - task: DownloadPackage@1
+      # The model data in the artifact is downloaded from openai/whisper-large-v3 on the Hugging Face model hub.
+      # To save space, the .git directory and pickled files were removed, keeping only the safetensors model files.
+      displayName: 'Download Whisper Model'
+      inputs:
+        packageType: upack
+        feed: '/7424c8e4-5c62-490e-95c4-79446f31017c'
+        version: 1.0.0
+        definition: 'b583ce7c-1a8f-4099-ae28-5d5f56c478b1'
+        downloadPath: $(Agent.TempDirectory)/whisper_large_v3
+
+    - script: |
+        docker run --rm --gpus all -v $(Build.SourcesDirectory):/workspace \
+           -v $(Build.BinariesDirectory)/ort-artifact/:/ort-artifact \
+           -v $(Agent.TempDirectory)/whisper_large_v3:/whisper_large_v3 \
+           onnxruntimepackagestest \
+            bash -c '
+              set -ex; \
+              pushd /workspace/onnxruntime/python/tools/transformers/ ; \
+              python3 -m pip install --upgrade pip ; \
+              pushd models/whisper ; \
+              python3 -m pip install -r requirements.txt ; \
+              popd ; \
+              python3 -m pip install /ort-artifact/*.whl ; \
+              python3 -m pip uninstall -y torch ; \
+              python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118 ; \
+              python3 -m models.whisper.convert_to_onnx -m /whisper_large_v3 --output whisperlargev3 --use_external_data_format ; \
+              popd ; \
+            '
+      displayName: 'Convert Whisper Model'
+      workingDirectory: $(Build.SourcesDirectory)
diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml
index 5a50a9964bead..a63f1b74b7633 100644
--- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml
+++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml
@@ -394,12 +394,11 @@ stages:
     steps:
     - template: templates/set-version-number-variables-step.yml
 
-    - task: BatchScript@1
-      displayName: 'setup env'
-      inputs:
-        filename: '$(Build.SourcesDirectory)\tools\ci_build\github\windows\setup_env_cuda.bat'
-        modifyEnvironment: true
-        workingFolder: '$(Build.BinariesDirectory)'
+    - template: templates/jobs/download_win_gpu_library.yml
+      parameters:
+        CudaVersion: ${{ parameters.CudaVersion }}
+        DownloadCUDA: true
+        DownloadTRT: true
 
     - template: templates\flex-downloadPipelineArtifact.yml
       parameters:
@@ -507,12 +506,11 @@ stages:
       condition: always()
 
     - script: dir $(Build.SourcesDirectory)
-    - task: BatchScript@1
-      displayName: 'setup env'
-      inputs:
-        filename: '$(Build.SourcesDirectory)\onnxruntime\tools\ci_build\github\windows\setup_env_gpu.bat'
-        modifyEnvironment: true
-        workingFolder: '$(Build.BinariesDirectory)'
+    - template: templates/jobs/download_win_gpu_library.yml
+      parameters:
+        CudaVersion: ${{ parameters.CudaVersion }}
+        DownloadCUDA: true
+        DownloadTRT: true
     - template: templates/set-version-number-variables-step.yml
       parameters:
         versionFileDirectory: '$(Build.SourcesDirectory)\onnxruntime'
diff --git a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml
index a4bd24b4dd18b..82e571bf6519f 100644
--- a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml
@@ -86,7 +86,7 @@ stages:
                     -e NIGHTLY_BUILD \
                     -e BUILD_BUILDNUMBER \
                     onnxruntimecpubuildcentos8x64 \
-                    /bin/bash -c "
+                    /bin/bash -c '
                       set -ex; \
                       python3.9 /onnxruntime_src/tools/ci_build/build.py \
                         --build_dir /build --cmake_generator 'Ninja' \
@@ -105,7 +105,8 @@ stages:
                         --parallel --use_binskim_compliant_compile_flags \
                         --build_csharp \
                         --enable_onnx_tests --enable_address_sanitizer \
-                        --test;"
+                        --test;
+                      '
                 workingDirectory: $(Build.SourcesDirectory)
 
       - task: PublishTestResults@2
@@ -115,6 +116,7 @@ stages:
           searchFolder: '$(Build.BinariesDirectory)'
           testRunTitle: 'Unit Test Run'
         condition: succeededOrFailed()
+
     - job: Linux_Release
       timeoutInMinutes: 180
       workspace:
@@ -243,7 +245,50 @@ stages:
           ln -s /data/models $(Build.BinariesDirectory)/models
         displayName: link model dir
 
-      
+      - bash: |
+          mkdir -p $HOME/.onnx
+          docker run --rm \
+            --volume /data/onnx:/data/onnx:ro \
+            --volume $(Build.SourcesDirectory):/onnxruntime_src \
+            --volume $(Build.BinariesDirectory):/build \
+            --volume /data/models:/build/models:ro \
+            --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \
+            -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \
+            -e NIGHTLY_BUILD \
+            -e BUILD_BUILDNUMBER \
+            onnxruntimecpubuild \
+            /bin/bash -c "
+              set -ex; \
+              pushd /onnxruntime_src/csharp; \
+              dotnet restore /onnxruntime_src/csharp/OnnxRuntime.DesktopOnly.CSharp.sln; \
+              dotnet build /onnxruntime_src/csharp/OnnxRuntime.DesktopOnly.CSharp.sln -c Release; \
+              dotnet test /onnxruntime_src/csharp/OnnxRuntime.DesktopOnly.CSharp.sln -c Release -f net6.0 --no-build -l \"console;verbosity=normal\"; \
+              popd
+              "
+        displayName: 'Dotnet build C# sln and Test'
+
+      - bash: |
+          mkdir -p $HOME/.onnx
+          docker run --rm \
+            --volume /data/onnx:/data/onnx:ro \
+            --volume $(Build.SourcesDirectory):/onnxruntime_src \
+            --volume $(Build.BinariesDirectory):/build \
+            --volume /data/models:/build/models:ro \
+            --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \
+            -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \
+            -e NIGHTLY_BUILD \
+            -e BUILD_BUILDNUMBER \
+            onnxruntimecpubuild \
+              /bin/bash -c "
+                set -ex; \
+                /bin/bash /onnxruntime_src/tools/scripts/python_test.sh /onnxruntime_src /build Release && \
+                /bin/bash /onnxruntime_src/tools/scripts/symbolic_shape_infer_test.sh /build
+              "
+        displayName: 'Run Release tests and symbolic shape infer test'
+
+      - template: templates/check_test_result.yml
+        parameters:
+          FileName: '$(Build.BinariesDirectory)/Release/onnxruntime_test_all.Release.results.xml'
 
       - task: PublishTestResults@2
         displayName: 'Publish unit test results'
diff --git a/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml
index 1053a2518125f..bbea7a0d114e8 100644
--- a/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml
@@ -59,7 +59,7 @@ jobs:
   timeoutInMinutes: 120
   workspace:
     clean: all
-  pool: onnxruntime-Linux-CPU-For-Android-CI
+  pool: onnxruntime-Ubuntu2204-AMD-CPU
   variables:
     ORT_CACHE_DIR: $(Pipeline.Workspace)/ort_ccache
     TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)]
diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml
index b19a8b11db265..b7232e9dc4ba1 100644
--- a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml
@@ -34,6 +34,17 @@ parameters:
     values:
       - 11.8
       - 12.2
+
+  - name: SpecificArtifact
+    displayName: Use Specific Artifact
+    type: boolean
+    default: false
+
+  - name: BuildId
+    displayName: Specific Artifact's BuildId
+    type: string
+    default: '0'
+
 resources:
   repositories:
   - repository: manylinux
@@ -61,162 +72,201 @@ variables:
     ${{ if eq(parameters.CudaVersion, '12.2') }}:
       value: 'onnxruntimecuda12build'
 
-jobs:
-- job: Linux_Build
-  timeoutInMinutes: 120
-  variables:
-    skipComponentGovernanceDetection: true
-    CCACHE_DIR: $(Pipeline.Workspace)/ccache
-  workspace:
-    clean: all
-  pool: onnxruntime-Ubuntu2204-AMD-CPU
-
-  steps:
-  - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3
-    displayName: 'Clean Agent Directories'
-    condition: always()
-
-  - checkout: self
-    clean: true
-    submodules: none
-
-  - template: templates/get-docker-image-steps.yml
-    parameters:
-      Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda
-      Context: tools/ci_build/github/linux/docker
-      DockerBuildArgs: "
-      --network=host
-      --build-arg BASEIMAGE=$(docker_base_image)
-      --build-arg TRT_VERSION=$(linux_trt_version)
-      --build-arg BUILD_UID=$( id -u )
-      "
-      Repository: $(Repository)
-
-  - task: Cache@2
-    inputs:
-      key: '"ccache" | "${{parameters.CudaVersion}}" |"$(Build.SourceBranch)" | "$(Build.SourceVersion)"'
-      path: $(CCACHE_DIR)
-      restoreKeys: |
-        "ccache" | "${{parameters.CudaVersion}}" | "$(Build.SourceBranch)"
-        "ccache"
-      cacheHitVar: CACHE_RESTORED
-    displayName: Cach Task
-
-  - script: |
-      sudo mkdir -p $(Pipeline.Workspace)/ccache
-    condition: ne(variables.CACHE_RESTORED, 'true')
-    displayName: Create Cache Dir
-
-  - script: |
-      set -e -x
-      mkdir -p $HOME/.onnx
-      docker run -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" --rm \
-        --volume /data/onnx:/data/onnx:ro \
-        --volume $(Build.SourcesDirectory):/onnxruntime_src \
-        --volume $(Build.BinariesDirectory):/build \
-        --volume /data/models:/build/models:ro \
-        --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \
-        --volume $(Pipeline.Workspace)/ccache:/cache \
-        -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \
-        -e NIGHTLY_BUILD \
-        -e BUILD_BUILDNUMBER \
-        -e CCACHE_DIR=/cache \
-        $(Repository) \
-        /bin/bash -c "
-          set -ex; \
-          env; \
-          ccache -s; \
-          /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py \
-            --build_dir /build --cmake_generator Ninja \
-            --config Release --update --build \
-            --skip_submodule_sync \
-            --build_shared_lib \
-            --parallel --use_binskim_compliant_compile_flags \
-            --build_wheel \
-            --enable_onnx_tests --use_cuda --cuda_version=${{parameters.CudaVersion}} --cuda_home=/usr/local/cuda-${{parameters.CudaVersion}} --cudnn_home=/usr/local/cuda-${{parameters.CudaVersion}} \
-            --enable_cuda_profiling --enable_cuda_nhwc_ops \
-            --enable_pybind --build_java \
-            --use_cache \
-            --cmake_extra_defines  CMAKE_CUDA_ARCHITECTURES=86; \
-              ccache -sv; \
-              ccache -z"
-    workingDirectory: $(Build.SourcesDirectory)
-    displayName: Build Onnxruntime
-
-  - task: CmdLine@2
-    inputs:
-      script: |
-        rm -rf $(Build.BinariesDirectory)/Release/onnxruntime $(Build.BinariesDirectory)/Release/pybind11
-        rm -f $(Build.BinariesDirectory)/Release/models
-        find $(Build.BinariesDirectory)/Release/_deps -mindepth 1 ! -regex '^$(Build.BinariesDirectory)/Release/_deps/onnx-src\(/.*\)?' -delete
-        cd $(Build.BinariesDirectory)/Release
-        find -executable -type f > $(Build.BinariesDirectory)/Release/perms.txt
-
-  - task: PublishPipelineArtifact@0
-    displayName: 'Publish Pipeline Artifact'
-    inputs:
-      artifactName: 'drop-linux'
-      targetPath: '$(Build.BinariesDirectory)/Release'
-
-  - template: templates/explicitly-defined-final-tasks.yml
-
-- job: Linux_Test
-  timeoutInMinutes: 180
-  variables:
-    skipComponentGovernanceDetection: true
-  workspace:
-    clean: all
-  pool: onnxruntime-Linux-GPU-A10
-  dependsOn:
-  - Linux_Build
-  steps:
-  - task: DownloadPipelineArtifact@2
-    displayName: 'Download Pipeline Artifact'
-    inputs:
-      buildType: 'current'
-      artifactName: 'drop-linux'
-      targetPath: '$(Build.BinariesDirectory)/Release'
-
-  - checkout: self
-    clean: true
-    submodules: none
-
-  - template: templates/get-docker-image-steps.yml
-    parameters:
-      Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda
-      Context: tools/ci_build/github/linux/docker
-      DockerBuildArgs: "
-      --network=host
-      --build-arg BASEIMAGE=$(docker_base_image)
-      --build-arg TRT_VERSION=$(linux_trt_version)
-      --build-arg BUILD_UID=$( id -u )
-      "
-      Repository: $(Repository)
-
-  - task: CmdLine@2
-    inputs:
-      script: |
+stages:
+- stage: Linux_Build
+  jobs:
+  - job: Linux_Build
+    timeoutInMinutes: 120
+    variables:
+      skipComponentGovernanceDetection: true
+      CCACHE_DIR: $(Pipeline.Workspace)/ccache
+    workspace:
+      clean: all
+    pool: onnxruntime-Ubuntu2204-AMD-CPU
+
+    steps:
+    - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3
+      displayName: 'Clean Agent Directories'
+      condition: always()
+
+    - checkout: self
+      clean: true
+      submodules: none
+
+    - template: templates/get-docker-image-steps.yml
+      parameters:
+        Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda
+        Context: tools/ci_build/github/linux/docker
+        DockerBuildArgs: "
+        --network=host
+        --build-arg BASEIMAGE=$(docker_base_image)
+        --build-arg TRT_VERSION=$(linux_trt_version)
+        --build-arg BUILD_UID=$( id -u )
+        "
+        Repository: $(Repository)
+
+    - task: Cache@2
+      inputs:
+        key: '"ccache" | "${{parameters.CudaVersion}}" |"$(Build.SourceBranch)" | "$(Build.SourceVersion)"'
+        path: $(CCACHE_DIR)
+        restoreKeys: |
+          "ccache" | "${{parameters.CudaVersion}}" | "$(Build.SourceBranch)"
+          "ccache"
+        cacheHitVar: CACHE_RESTORED
+      displayName: Cache Task
+
+    - script: |
+        sudo mkdir -p $(Pipeline.Workspace)/ccache
+      condition: ne(variables.CACHE_RESTORED, 'true')
+      displayName: Create Cache Dir
+
+    - script: |
         set -e -x
         mkdir -p $HOME/.onnx
-        docker run --gpus all --rm \
-          --volume  $(Build.SourcesDirectory):/onnxruntime_src \
-          --volume $(Build.BinariesDirectory)/Release:/build/Release \
+        docker run -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" --rm \
+          --volume /data/onnx:/data/onnx:ro \
+          --volume $(Build.SourcesDirectory):/onnxruntime_src \
+          --volume $(Build.BinariesDirectory):/build \
           --volume /data/models:/build/models:ro \
           --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \
-          --volume /data/onnx:/data/onnx \
+          --volume $(Pipeline.Workspace)/ccache:/cache \
+          -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \
+          -e NIGHTLY_BUILD \
+          -e BUILD_BUILDNUMBER \
+          -e CCACHE_DIR=/cache \
           $(Repository) \
           /bin/bash -c "
             set -ex; \
-            cp /onnxruntime_src/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt /tmp/requirements.txt; \
-            ln -s /opt/python/cp38-cp38/bin/python3 /tmp/python3; \
-            /tmp/python3 -m pip install -r /tmp/requirements.txt; \
-            /tmp/python3 -m pip install /build/Release/dist/*.whl; \
-            cd /build/Release && xargs -a /build/Release/perms.txt chmod a+x; \
-            cd /onnxruntime_src/java && /onnxruntime_src/java/gradlew cmakeCheck -DcmakeBuildDir=/build/Release -DUSE_CUDA=1; \
-            cd /tmp; \
-            /tmp/python3 /onnxruntime_src/tools/ci_build/build.py \
-              --build_dir /build --config Release --test --skip_submodule_sync --build_shared_lib --parallel --use_binskim_compliant_compile_flags --build_wheel --enable_onnx_tests \
-              --use_cuda --cuda_version=${{parameters.CudaVersion}} --cuda_home=/usr/local/cuda --cudnn_home=/usr/local/cuda \
-              --enable_pybind --build_java --ctest_path '' "
-
-  - template: templates/clean-agent-build-directory-step.yml
+            env; \
+            ccache -s; \
+            /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py \
+              --build_dir /build --cmake_generator Ninja \
+              --config Release --update --build \
+              --skip_submodule_sync \
+              --build_shared_lib \
+              --parallel --use_binskim_compliant_compile_flags \
+              --build_wheel \
+              --enable_onnx_tests --use_cuda --cuda_version=${{parameters.CudaVersion}} --cuda_home=/usr/local/cuda-${{parameters.CudaVersion}} --cudnn_home=/usr/local/cuda-${{parameters.CudaVersion}} \
+              --enable_cuda_profiling --enable_cuda_nhwc_ops \
+              --enable_pybind --build_java \
+              --use_cache \
+              --cmake_extra_defines  CMAKE_CUDA_ARCHITECTURES=75; \
+                ccache -sv; \
+                ccache -z"
+      workingDirectory: $(Build.SourcesDirectory)
+      displayName: Build Onnxruntime
+
+    - task: CmdLine@2
+      inputs:
+        script: |
+          rm -rf $(Build.BinariesDirectory)/Release/onnxruntime $(Build.BinariesDirectory)/Release/pybind11
+          rm -f $(Build.BinariesDirectory)/Release/models
+          find $(Build.BinariesDirectory)/Release/_deps -mindepth 1 ! -regex '^$(Build.BinariesDirectory)/Release/_deps/onnx-src\(/.*\)?' -delete
+          cd $(Build.BinariesDirectory)/Release
+          find -executable -type f > $(Build.BinariesDirectory)/Release/perms.txt
+
+    - task: PublishPipelineArtifact@0
+      displayName: 'Publish Pipeline Artifact'
+      inputs:
+        artifactName: 'drop-linux'
+        targetPath: '$(Build.BinariesDirectory)/Release'
+
+    - template: templates/explicitly-defined-final-tasks.yml
+
+- stage: Linux_Test
+  dependsOn:
+    - Linux_Build
+  jobs:
+  - job: Linux_Test
+    timeoutInMinutes: 180
+    variables:
+      skipComponentGovernanceDetection: true
+    workspace:
+      clean: all
+    pool: onnxruntime-Linux-GPU-T4
+    steps:
+    - checkout: self
+      clean: true
+      submodules: none
+
+    - template: templates/flex-downloadPipelineArtifact.yml
+      parameters:
+        ArtifactName: 'drop-linux'
+        StepName: 'Download Pipeline Artifact - Linux Build'
+        TargetPath: '$(Build.BinariesDirectory)/Release'
+        SpecificArtifact: ${{ parameters.SpecificArtifact }}
+        BuildId: ${{ parameters.BuildId }}
+
+    - template: templates/get-docker-image-steps.yml
+      parameters:
+        Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda
+        Context: tools/ci_build/github/linux/docker
+        DockerBuildArgs: "
+        --network=host
+        --build-arg BASEIMAGE=$(docker_base_image)
+        --build-arg TRT_VERSION=$(linux_trt_version)
+        --build-arg BUILD_UID=$( id -u )
+        "
+        Repository: $(Repository)
+
+    - task: CmdLine@2
+      inputs:
+        script: |
+          set -e -x
+          mkdir -p $HOME/.onnx
+          docker run --gpus all --rm \
+            --volume  $(Build.SourcesDirectory):/onnxruntime_src \
+            --volume $(Build.BinariesDirectory)/Release:/build/Release \
+            --volume /data/models:/build/models:ro \
+            --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \
+            --volume /data/onnx:/data/onnx \
+            -e NVIDIA_TF32_OVERRIDE=0 \
+            $(Repository) \
+            /bin/bash -c '
+              nvidia-smi; \
+              /sbin/ldconfig -N -v $(sed "s/:/ /" <<< $LD_LIBRARY_PATH) 2>/dev/null | grep -E "libcudart.so|libcudnn.so|libnvinfer.so"; \
+              cat /usr/local/cuda/include/cuda.h | grep -m1 CUDA_VERSION; \
+              cat /usr/include/cudnn_version.h | grep CUDNN_MAJOR -m1 -A 2; \
+              ln -s /opt/python/cp38-cp38/bin/python3 /tmp/python3; \
+              /tmp/python3 -m pip install /build/Release/dist/*.whl; \
+              /tmp/python3 -u -c "from onnxruntime.capi._pybind_state import (OrtDevice as C_OrtDevice) ; \
+                        ort_device = C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0); \
+                        print(ort_device); print(ort_device.device_type(), C_OrtDevice.cuda()); \
+                        assert(ort_device.device_type()==1); assert(C_OrtDevice.cuda()==1);" \
+            '
+      displayName: 'Check GPU'
+
+    - task: CmdLine@2
+      inputs:
+        script: |
+          set -e -x
+          mkdir -p $HOME/.onnx
+          docker run --gpus all --shm-size=1g --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --rm \
+            --volume  $(Build.SourcesDirectory):/onnxruntime_src \
+            --volume $(Build.BinariesDirectory)/Release:/build/Release \
+            --volume /data/models:/build/models:ro \
+            --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \
+            --volume /data/onnx:/data/onnx \
+            -e NVIDIA_TF32_OVERRIDE=0 \
+            $(Repository) \
+            /bin/bash -c '
+              set -ex; \
+              cp /onnxruntime_src/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt /tmp/requirements.txt; \
+              ln -s /opt/python/cp38-cp38/bin/python3 /tmp/python3; \
+              /tmp/python3 -m pip install -r /tmp/requirements.txt; \
+              /tmp/python3 -m pip install /build/Release/dist/*.whl; \
+              cd /build/Release && xargs -a /build/Release/perms.txt chmod a+x; \
+              cd /onnxruntime_src/java && /onnxruntime_src/java/gradlew cmakeCheck -DcmakeBuildDir=/build/Release -DUSE_CUDA=1; \
+              cd /tmp; \
+              /tmp/python3 /onnxruntime_src/tools/ci_build/build.py \
+                --build_dir /build --config Release --test --skip_submodule_sync --build_shared_lib --parallel --use_binskim_compliant_compile_flags --build_wheel --enable_onnx_tests \
+                --use_cuda --cuda_version=${{parameters.CudaVersion}} --cuda_home=/usr/local/cuda --cudnn_home=/usr/local/cuda \
+                --enable_pybind --build_java --ctest_path "" ; \
+              '
+      displayName: 'Run Tests'
+
+    - template: templates/check_test_result.yml
+      parameters:
+        FileName: '$(Build.BinariesDirectory)/Release/onnxruntime_test_all.Release.results.xml'
+
+    - template: templates/clean-agent-build-directory-step.yml
diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml
index e75bb68a8bfeb..af2d722a6b90c 100644
--- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml
@@ -8,13 +8,19 @@ parameters:
 - name: TrtVersion
   displayName: TensorRT Version
   type: string
-  default: 8.6.1.6
+  default: 8.6.cuda_11_8_cudnn_8
   values:
-  - 8.4.1.5
-  - 8.5.1.1
-  - 8.6.1.6
+  - 8.4.cuda_11_6_cudnn_8
+  - 8.5.cuda_11_8_cudnn_8
+  - 8.6.cuda_11_8_cudnn_8
+  - 8.6.cuda_12_3_cudnn_9
   - BIN
 
+- name: UseTensorrtOssParser
+  displayName: Use TensorRT-OSS Parser (not compatible with BIN)
+  type: boolean
+  default: false
+
 - name: ModelGroups
   type: object
   default: 
@@ -22,10 +28,15 @@ parameters:
     - "partner-models"
 
 - name: MemTest
-  displayName: Run Memory Test
+  displayName: Run Memory Test and Concurrency Test
   type: boolean
   default: true
 
+- name: ConcurrencyTest
+  displayName: Specifies the number of concurrency model test to invoke simultaneously
+  type: string
+  default: 2
+
 - name: TrtEPOptions
   displayName: TensorRT EP options
   type: object
@@ -71,26 +82,38 @@ jobs:
 
     - name: image
       value: ort-image-$(Build.BuildId)
+    
+    - name: parser
+      ${{ if eq(parameters.UseTensorrtOssParser, true) }}:
+        value: --use_tensorrt_oss_parser ${{ parameters.UseTensorrtOssParser }}
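+    # $(parser) is only set when the OSS parser is selected and is appended to the post.py dashboard upload below.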
 
   steps:
-    - ${{ if eq(parameters.TrtVersion, 'BIN') }}:
+    - ${{ if and(eq(parameters.TrtVersion, 'BIN'), eq(parameters.UseTensorrtOssParser, false)) }}:
       - script: 'ls -al $(trtBinsDir)'
         displayName: 'Show available TensorRT .tar.gz packages'
 
-      - script: 'cp $(trtBinsDir)/TensorRT-$(trtVersion).Linux.x86_64-gnu.cuda-$(tarCudaVersion).cudnn$(tarCudnnVersion).tar.gz $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build/'
+      - script: 'cp $(trtBinsDir)/TensorRT-$(trtVersion).Linux.x86_64-gnu.cuda-$(tarCudaVersion).tar.gz $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build/'
         displayName: 'Copy TensorRT .tar.gz package into Docker build directory'
 
-      - script: 'python3 $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build/build_image.py -r $(Build.SourcesDirectory) -i $(image) -b $(branchName) -t $(trtVersion) -a 75 --install_bin --tar_cuda_version=$(tarCudaVersion) --tar_cudnn_version=$(tarCudnnVersion) --trt_bins_dir=.'
-        displayName: 'Install TensorRT from binaries and build latest ORT Image'
+      - script: 'python3 $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build/build_image.py -r $(Build.SourcesDirectory) -i $(image) -b $(branchName) -t $(trtVersion) -a 75 --install_bin --tar_cuda_version=$(tarCudaVersion) --trt_bins_dir=.'
+        displayName: 'Install TensorRT $(tarTrtVersion) from binaries and build latest ORT Image'
         workingDirectory: '$(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build'
-    - ${{ else }}:
+    
+    # Build ORT with TensorRT built-in parser 
+    - ${{ if and(ne(parameters.TrtVersion, 'BIN'), eq(parameters.UseTensorrtOssParser, false)) }}:
       - script: 'python3 $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build/build_image.py -r $(Build.SourcesDirectory) -i $(image) -b $(branchName) -t $(trtVersion) -a 75'
-        displayName: 'Build latest ORT Image'
+        displayName: 'Build latest ORT Image with TensorRT built-in parser'
         workingDirectory: '$(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build'
-        
+    
+    # Build ORT with TensorRT OSS parser 
+    - ${{ if and(ne(parameters.TrtVersion, 'BIN'), eq(parameters.UseTensorrtOssParser, true)) }}:
+      - script: 'python3 $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build/build_image.py -r $(Build.SourcesDirectory) -i $(image) -b $(branchName) -t $(trtVersion) -a 75 --use_tensorrt_oss_parser'
+        displayName: 'Build latest ORT Image with TensorRT OSS parser'
+        workingDirectory: '$(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build'
+    
     - ${{ if eq(parameters.MemTest, true) }}:
-      - script: '$(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/mem_test/run_mem_test_docker.sh -d $(image) -p $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/mem_test/ -w /code/ -l false'
-        displayName: 'Run Memory Test'
+      - script: '$(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/mem_test/run_mem_test_docker.sh -d $(image) -p $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/mem_test/ -w /code/ -l false -c ${{ parameters.ConcurrencyTest }}'
+        displayName: 'Run Memory Test and Concurrency Test'
         workingDirectory: '$(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/mem_test/'
 
     - ${{ each option in parameters.ModelGroups }}:
@@ -134,7 +157,7 @@ jobs:
         displayName: 'Check and Install Azure CLI'
 
       - task: AzureCLI@2
-        displayName: 'Azure CLI Post to Dashboard'
+        displayName: 'Post EP Perf Results to Dashboard'
         inputs:
           azureSubscription: AIInfraBuildOnnxRuntimeOSS
           scriptLocation: inlineScript
@@ -142,8 +165,8 @@ jobs:
           inlineScript: |
             short_hash=$(git rev-parse --short HEAD) &&
             commit_date=$(git log -1 --date=iso-strict --pretty=format:%cd) &&
-            python3 $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/post.py -r $(Build.SourcesDirectory)/Artifact/result -c $short_hash -d $commit_date -u "$(reportUrl)?buildId=$(Build.BuildId)" -t $(trtVersion) -b $(branchName) --kusto_conn $(kustoConn) --database $(database)
-    
+            python3 $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/post.py -r $(Build.SourcesDirectory)/Artifact -c $short_hash -d $commit_date -u "$(reportUrl)?buildId=$(Build.BuildId)" -t $(trtVersion) -b $(branchName) --kusto_conn $(kustoConn) --database $(database) $(parser)
+
     - template: templates/component-governance-component-detection-steps.yml
       parameters :
         condition : 'succeeded'
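
Note on the conditional `parser` variable added above: in Azure Pipelines, a value can be assigned under a compile-time `${{ if }}` block and then consumed later with macro syntax. A minimal sketch of that pattern, with illustrative names and an explicit empty `${{ else }}` branch assumed so the macro always resolves to something:

  variables:
    - name: parserFlag
      ${{ if eq(parameters.UseTensorrtOssParser, true) }}:
        value: '--use_tensorrt_oss_parser'
      ${{ else }}:
        value: ''

  steps:
    - script: python3 post.py $(parserFlag)   # expands to the flag when the parameter is true, otherwise to nothing
      displayName: 'Pass the optional parser flag through'
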
diff --git a/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml
index 9cf7a3fb42397..8b58d958ba899 100644
--- a/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml
@@ -109,6 +109,7 @@ jobs:
               --rocm_version=$(RocmVersion) \
               --rocm_home /opt/rocm \
               --nccl_home /opt/rocm \
+              --enable_nccl \
               --update \
               --build_dir /build \
               --build \
diff --git a/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml
index c92fc93abba37..03e0274fc198a 100644
--- a/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml
@@ -32,5 +32,5 @@ jobs:
   parameters:
     AgentPool : 'Linux-CPU-2019'
     JobName: 'Linux_CI_Dev'
-    RunDockerBuildArgs: '-o ubuntu20.04 -d openvino -v 2023.0.0 -x "--use_openvino CPU_FP32 --build_wheel"'
+    RunDockerBuildArgs: '-o ubuntu20.04 -d openvino -v 2024.0.0 -x "--use_openvino CPU_FP32 --build_wheel"'
     TimeoutInMinutes: 120
diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml
index 0312b70d2b1d5..8fa5bdbf90931 100644
--- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml
@@ -32,7 +32,7 @@ parameters:
 - name: QnnSdk
   displayName: QNN SDK version
   type: string
-  default: qnn-v2.18.0.240101
+  default: qnn-v2.19.2.240210
 
 jobs:
   - job: Build_QNN_EP
diff --git a/tools/ci_build/github/azure-pipelines/mac-coreml-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-coreml-ci-pipeline.yml
index a3f56f5c448a9..f0a35d809c700 100644
--- a/tools/ci_build/github/azure-pipelines/mac-coreml-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/mac-coreml-ci-pipeline.yml
@@ -32,7 +32,7 @@ jobs:
   workspace:
     clean: all
   pool:
-    vmImage: 'macOS-13'
+    vmImage: 'macOS-latest'
   variables:
     MACOSX_DEPLOYMENT_TARGET: '11.0'
     TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)]
@@ -43,6 +43,8 @@ jobs:
     displayName: Install coreutils and ninja
 
   - template: templates/use-xcode-version.yml
+    parameters:
+      xcodeVersion: 14.2
 
   - template: templates/mac-build-step-with-cache.yml
     parameters:
diff --git a/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml
index a1ca68c8279e7..255531681b039 100644
--- a/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml
@@ -30,7 +30,7 @@ pr:
 jobs:
 - job: iOS_CI_on_Mac
   pool:
-    vmImage: 'macOS-13'
+    vmImage: 'macOS-latest'
   variables:
     PROTO_CACHE_DIR: $(Pipeline.Workspace)/proto_ccache
     ORT_CACHE_DIR: $(Pipeline.Workspace)/ort_ccache
@@ -39,7 +39,7 @@ jobs:
   steps:
     - template: templates/use-xcode-version.yml
       parameters:
-        xcodeVersion: 14.3
+        xcodeVersion: 14.2
     - template: templates/mac-build-step-with-cache.yml
       parameters:
         WithCache: true
diff --git a/tools/ci_build/github/azure-pipelines/mac-ios-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-ios-packaging-pipeline.yml
index 5fd15b64e03b6..881023e1c1186 100644
--- a/tools/ci_build/github/azure-pipelines/mac-ios-packaging-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/mac-ios-packaging-pipeline.yml
@@ -53,7 +53,7 @@ stages:
     displayName: "Set common variables"
 
     pool:
-      vmImage: "macOS-13"
+      vmImage: "macOS-latest"
 
     timeoutInMinutes: 5
 
diff --git a/tools/ci_build/github/azure-pipelines/mac-react-native-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-react-native-ci-pipeline.yml
index e8f4931d5ad9f..886bacf5aac4d 100644
--- a/tools/ci_build/github/azure-pipelines/mac-react-native-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/mac-react-native-ci-pipeline.yml
@@ -61,4 +61,4 @@ stages:
   parameters:
     NpmPackagingMode: ${{ variables.NpmPackagingMode }}
     BuildConfig: 'Release'
-    PoolName: 'onnxruntime-Linux-CPU-For-Android-CI'
+    PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU'
diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml
index 9393fb07d718a..3a3375a313ca5 100644
--- a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml
+++ b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml
@@ -55,6 +55,9 @@ stages:
       - checkout: self
         clean: true
         submodules: recursive
+      - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3
+        displayName: 'Clean Agent Directories'
+        condition: always()
 
       - powershell: |
           if($env:TELEMETRYGUID)
@@ -185,7 +188,7 @@ stages:
           displayName: 'Publish unit test results'
           inputs:
             testResultsFiles: '**\*.results.xml'
-            searchFolder: '$(Build.BinariesDirectory)'
+            searchFolder: '$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig)'
             testRunTitle: 'Unit Test Run'
           condition: succeededOrFailed()
 
@@ -231,14 +234,7 @@ stages:
               searchPattern: '**/*.pdb'
               symbolServerType: teamServices
 
-      - ${{ if eq(parameters['DoCompliance'], 'true') }}:
-        - template: ../../templates/compliance.yml
-          parameters :
-            msbuildPlatform: ${{ parameters.sln_platform }}
 
-      - template: ../../templates/component-governance-component-detection-steps.yml
-        parameters :
-          condition : 'succeeded'
 
       # Node.js Publish
       - ${{ if eq(parameters['DoNodejsPack'], 'true') }}:
@@ -294,6 +290,12 @@ stages:
             targetPath: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\${{ parameters.sln_platform }}'
             artifactName: 'drop-onnxruntime-nodejs-win-${{ parameters.sln_platform }}-dml'
 
-      - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3
-        displayName: 'Clean Agent Directories'
-        condition: always()
+
+      - ${{ if eq(parameters['DoCompliance'], 'true') }}:
+        - template: ../../templates/compliance.yml
+          parameters :
+            msbuildPlatform: ${{ parameters.sln_platform }}
+
+      - template: ../../templates/component-governance-component-detection-steps.yml
+        parameters :
+          condition : 'succeeded'
diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml
index bf1ba71b7b818..4ca122f639551 100644
--- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml
+++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml
@@ -46,7 +46,7 @@ stages:
             --build-arg PYTHON_VERSION=$(PythonVersion)
             --build-arg INSTALL_DEPS_EXTRA_ARGS=-tu
             --build-arg BUILD_UID=$(id -u)
-          Repository: onnxruntimetrainingcpubuild
+          Repository: onnxruntimetrainingcpubuild_$(PythonVersion)
 
       - task: CmdLine@2
         displayName: 'build onnxruntime'
@@ -63,7 +63,7 @@ stages:
               -e BUILD_BUILDNUMBER \
               -e ORT_DISABLE_PYTHON_PACKAGE_LOCAL_VERSION \
               -e DEFAULT_TRAINING_PACKAGE_DEVICE \
-              onnxruntimetrainingcpubuild \
+              onnxruntimetrainingcpubuild_$(PythonVersion) \
                 $(PythonManylinuxDir)/bin/python3 /onnxruntime_src/tools/ci_build/build.py \
                   --build_dir /build --cmake_generator Ninja \
                   --config Debug Release \
diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml
index d9ab85ee80ce3..539a61c021cfb 100644
--- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml
+++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml
@@ -8,15 +8,28 @@ resources:
     name: pypa/manylinux
     ref: 5eda9aded5462201e6310105728d33016e637ea7
 
+parameters:
+  - name: SpecificArtifact
+    displayName: Use Specific Artifact
+    type: boolean
+    default: false
+
+  - name: BuildId
+    displayName: Specific Artifact's BuildId
+    type: string
+    default: '0'
+
 stages:
 - template: templates/py-packaging-training-cuda-stage.yml
   parameters:
     build_py_parameters: --enable_training --update --build
     torch_version: '2.0.0'
-    opset_version: '15'
+    opset_version: '17'
     cuda_version: '11.8'
     cmake_cuda_architectures: 60;61;70;75;80;86
     docker_file: Dockerfile.manylinux2_28_training_cuda11_8
     agent_pool: Onnxruntime-Linux-GPU
     upload_wheel: 'yes'
     debug_build: false
+    SpecificArtifact: ${{ parameters.SpecificArtifact }}
+    BuildId: ${{ parameters.BuildId }}
diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml
index 422fb33eec5de..86dce7ae465fc 100644
--- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml
+++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml
@@ -13,7 +13,7 @@ stages:
   parameters:
     build_py_parameters: --enable_training --update --build
     torch_version: '2.1.0'
-    opset_version: '15'
+    opset_version: '17'
     cuda_version: '12.2'
     cmake_cuda_architectures: 70;75;80;86;90
     docker_file: Dockerfile.manylinux2_28_training_cuda12_2
diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml
index 3ec5400dacc65..bb4402faeb191 100644
--- a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml
+++ b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml
@@ -17,7 +17,6 @@ stages:
 # Each group has 4 jobs that cover:
 # o Windows ARM64EC
 # o Windows ARM64
-# o Windows ARM
 # o Windows x64
 # o Windows x86
 # Now we don't have coverage for ARM64EC yet. Will add it.
@@ -35,20 +34,6 @@ stages:
     buildNodejs: false
     ort_build_pool_name: 'onnxruntime-Win-CPU-2022'
 
-- template: templates/win-ci.yml
-  parameters:
-    DoCompliance: false
-    DoEsrp: false
-    stage_name_suffix: CPU_arm_default
-    buildArch: x64
-    msbuildPlatform: arm
-    packageName: arm
-    buildparameter: --arm  --enable_onnx_tests --path_to_protoc_exe $(Build.BinariesDirectory)\RelWithDebInfo\installed\bin\protoc.exe
-    runTests: false
-    buildJava: false
-    buildNodejs: false
-    ort_build_pool_name: 'onnxruntime-Win-CPU-2022'
-
 - template: templates/win-ci.yml
   parameters:
     DoCompliance: false
@@ -106,21 +91,6 @@ stages:
     buildNodejs: false
     ort_build_pool_name: 'onnxruntime-Win-CPU-2022'
 
-- template: templates/win-ci.yml
-  parameters:
-    DoCompliance: false
-    DoEsrp: false
-    stage_name_suffix: CPU_arm_wcos
-    artifact_name_suffix: '-wcos'
-    buildArch: x64
-    msbuildPlatform: arm
-    packageName: arm
-    buildparameter: --arm  --enable_onnx_tests --enable_wcos --path_to_protoc_exe $(Build.BinariesDirectory)\RelWithDebInfo\installed\bin\protoc.exe
-    runTests: false
-    buildJava: false
-    buildNodejs: false
-    ort_build_pool_name: 'onnxruntime-Win-CPU-2022'
-
 - template: templates/win-ci.yml
   parameters:
     DoCompliance: false
@@ -181,21 +151,6 @@ stages:
     buildNodejs: false
     ort_build_pool_name: 'onnxruntime-Win-CPU-2022'
 
-- template: templates/win-ci.yml
-  parameters:
-    DoCompliance: false
-    DoEsrp: false
-    stage_name_suffix: CPU_arm_extension
-    artifact_name_suffix: '-extension'
-    buildArch: x64
-    msbuildPlatform: arm
-    packageName: arm
-    buildparameter: --arm --use_extensions  --enable_onnx_tests --path_to_protoc_exe $(Build.BinariesDirectory)\RelWithDebInfo\installed\bin\protoc.exe
-    runTests: false
-    buildJava: false
-    buildNodejs: false
-    ort_build_pool_name: 'onnxruntime-Win-CPU-2022'
-
 - template: templates/win-ci.yml
   parameters:
     DoCompliance: false
@@ -462,6 +417,7 @@ stages:
     - template: templates/use-xcode-version.yml
       parameters:
         xcodeVersion: 14.3
+
     - script: |
         pip install -r tools/ci_build/github/apple/ios_packaging.requirements.txt
       displayName: "Install Python requirements"
@@ -478,4 +434,41 @@ stages:
           --framework_info_file "$(Build.BinariesDirectory)/ios_framework/xcframework_info.json" \
           --c_framework_dir "$(Build.BinariesDirectory)/ios_framework/framework_out" \
           --variant Mobile
-      displayName: "Test pod with iOS dynamic framework"
+      displayName: "Test pod with iOS framework"
+
+- stage: IosMinimalTrainingBuild
+  dependsOn: []
+  jobs:
+  - job: IosMinimalTrainingBuild
+    timeoutInMinutes: 120
+    pool:
+      vmImage: "macOS-13"
+
+    steps:
+    - task: UsePythonVersion@0
+      inputs:
+        versionSpec: "3.9"
+        addToPath: true
+        architecture: "x64"
+
+    - template: templates/use-xcode-version.yml
+      parameters:
+        xcodeVersion: 14.3
+
+    - script: |
+        pip install -r tools/ci_build/github/apple/ios_packaging.requirements.txt
+      displayName: "Install Python requirements"
+
+    - script: |
+        python tools/ci_build/github/apple/build_apple_framework.py \
+          --build_dir "$(Build.BinariesDirectory)/ios_framework" \
+          tools/ci_build/github/apple/test_minimal_training_ios_simulator_framework_build_settings.json
+      displayName: "Build iOS framework with minimal build and training enabled"
+
+    - script: |
+        python tools/ci_build/github/apple/test_apple_packages.py \
+          --framework_info_file "$(Build.BinariesDirectory)/ios_framework/xcframework_info.json" \
+          --c_framework_dir "$(Build.BinariesDirectory)/ios_framework/framework_out" \
+          --variant Training \
+          --skip_macos_test
+      displayName: "Test pod with iOS framework"
diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml
index aee42d3675087..20646d3ba4a26 100644
--- a/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml
@@ -21,6 +21,15 @@ parameters:
     values:
       - 11.8
       - 12.2
+  - name: SpecificArtifact
+    displayName: Use Specific Artifact
+    type: boolean
+    default: false
+
+  - name: BuildId
+    displayName: Specific Artifact's BuildId
+    type: string
+    default: '0'
 
 resources:
   repositories:
@@ -36,4 +45,6 @@ stages:
       enable_linux_gpu: ${{ parameters.enable_linux_gpu }}
       enable_windows_gpu: ${{ parameters.enable_windows_gpu }}
       cmake_build_type: ${{ parameters.cmake_build_type }}
-      cuda_version: ${{ parameters.cuda_version }}
\ No newline at end of file
+      cuda_version: ${{ parameters.cuda_version }}
+      SpecificArtifact: ${{ parameters.SpecificArtifact }}
+      BuildId: ${{ parameters.BuildId }}
diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml
index 5349b1ca67ab1..6b0ae085fa4db 100644
--- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml
@@ -34,6 +34,11 @@ parameters:
   type: boolean
   default: true
 
+- name: enable_windows_x64_qnn
+  displayName: 'Whether Windows x86_64 package with QNN EP is built.'
+  type: boolean
+  default: true
+
 - name: build_py_parameters
   displayName: 'Specify extra build parameters'
   type: string
@@ -70,5 +75,6 @@ stages:
     enable_mac_cpu: ${{ parameters.enable_mac_cpu }}
     enable_linux_arm: ${{ parameters.enable_linux_arm }}
     enable_windows_arm64_qnn: ${{ parameters.enable_windows_arm64_qnn }}
+    enable_windows_x64_qnn: ${{ parameters.enable_windows_x64_qnn }}
     build_py_parameters: ${{ parameters.build_py_parameters }}
     cmake_build_type: ${{ parameters.cmake_build_type }}
diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml
index b0509467e1689..9a38513d04a79 100644
--- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml
@@ -2,7 +2,7 @@ parameters:
 - name: QnnSdk
   displayName: QNN SDK Version
   type: string
-  default: qnn-v2.18.0.240101_win
+  default: qnn-v2.19.2.240210_win
 
 - name: build_config
   displayName: Build Configuration
diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml
index 8ca3d9148b514..064e2ea91d194 100644
--- a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml
+++ b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml
@@ -213,13 +213,6 @@ stages:
             PlatformsSupported: 'linux-x64'
             VerifyNugetSigning: false
 
-        - task: PublishPipelineArtifact@0
-          displayName: 'Publish Pipeline NuGet Artifact'
-          inputs:
-            artifactName: 'drop-signed-nuget-GPU'
-            targetPath: '$(Build.ArtifactStagingDirectory)'
-
-
         - task: MSBuild@1
           displayName: 'Clean C#'
           inputs:
@@ -241,6 +234,12 @@ stages:
           parameters:
             condition: 'succeeded'
 
+        - task: PublishPipelineArtifact@0
+          displayName: 'Publish Pipeline NuGet Artifact'
+          inputs:
+            artifactName: 'drop-signed-nuget-GPU'
+            targetPath: '$(Build.ArtifactStagingDirectory)'
+
         - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3
           displayName: 'Clean Agent Directories'
           condition: always()
diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml
index f82c80d4d7e93..a2c1eeef632c1 100644
--- a/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml
+++ b/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml
@@ -34,72 +34,40 @@ parameters:
    - 11.8
    - 12.2
 
-stages:
-- stage: Python_Packaging
-  dependsOn: []
-  variables:
-  - name: docker_base_image
-    ${{ if eq(parameters.cuda_version, '11.8') }}:
-      value: nvidia/cuda:11.8.0-cudnn8-devel-ubi8
-    ${{ if eq(parameters.cuda_version, '12.2') }}:
-      value: nvidia/cuda:12.2.2-cudnn8-devel-ubi8
-  - name: linux_trt_version
-    ${{ if eq(parameters.cuda_version, '11.8') }}:
-      value: 8.6.1.6-1.cuda11.8
-    ${{ if eq(parameters.cuda_version, '12.2') }}:
-      value: 8.6.1.6-1.cuda12.0
-  - name: win_trt_home
-    ${{ if eq(parameters.cuda_version, '11.8') }}:
-      value: $(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8
-    ${{ if eq(parameters.cuda_version, '12.2') }}:
-      value: $(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0
-  - name: win_cuda_home
-    ${{ if eq(parameters.cuda_version, '11.8') }}:
-      value: $(Agent.TempDirectory)\v11.8
-    ${{ if eq(parameters.cuda_version, '12.2') }}:
-      value: $(Agent.TempDirectory)\v12.2
-  jobs:
-  - ${{ if eq(parameters.enable_windows_gpu, true) }}:
-      - template: ../templates/py-win-gpu.yml
-        parameters:
-          MACHINE_POOL: 'onnxruntime-Win2022-GPU-T4'
-          PYTHON_VERSION: '3.8'
-          EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home=${{ variables.win_trt_home }} --cuda_home=${{ variables.win_cuda_home }}  --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80"
-          EP_NAME: gpu
-          CudaVersion: ${{ parameters.cuda_version }}
-
-      - template: ../templates/py-win-gpu.yml
-        parameters:
-          MACHINE_POOL: 'onnxruntime-Win2022-GPU-T4'
-          PYTHON_VERSION: '3.9'
-          EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home=${{ variables.win_trt_home }} --cuda_home=${{ variables.win_cuda_home }}  --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80"
-          EP_NAME: gpu
-          CudaVersion: ${{ parameters.cuda_version }}
+- name: SpecificArtifact
+  displayName: Use Specific Artifact
+  type: boolean
+  default: false
 
-      - template: ../templates/py-win-gpu.yml
-        parameters:
-          MACHINE_POOL: 'onnxruntime-Win2022-GPU-T4'
-          PYTHON_VERSION: '3.10'
-          EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home=${{ variables.win_trt_home }} --cuda_home=${{ variables.win_cuda_home }}  --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80"
-          EP_NAME: gpu
-          CudaVersion: ${{ parameters.cuda_version }}
+- name: BuildId
+  displayName: Specific Artifact's BuildId
+  type: string
+  default: '0'
 
-      - template: ../templates/py-win-gpu.yml
-        parameters:
-          MACHINE_POOL: 'onnxruntime-Win2022-GPU-T4'
-          PYTHON_VERSION: '3.11'
-          EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home=${{ variables.win_trt_home }} --cuda_home=${{ variables.win_cuda_home }}  --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80"
-          EP_NAME: gpu
-          CudaVersion: ${{ parameters.cuda_version }}
+- name: PythonVersions
+  type: object
+  displayName: 'Python versions to build'
+  default:
+    - '3.8'
+    - '3.9'
+    - '3.10'
+    - '3.11'
+    - '3.12'
 
+stages:
+  - ${{ if eq(parameters.enable_windows_gpu, true) }}:
+    - ${{ each python_version in parameters.PythonVersions }}:
       - template: ../templates/py-win-gpu.yml
         parameters:
-          MACHINE_POOL: 'onnxruntime-Win2022-GPU-T4'
-          PYTHON_VERSION: '3.12'
-          EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home=${{ variables.win_trt_home }} --cuda_home=${{ variables.win_cuda_home }}  --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80"
+          PYTHON_VERSION: ${{ python_version }}
           EP_NAME: gpu
           CudaVersion: ${{ parameters.cuda_version }}
-
+          SpecificArtifact: ${{ parameters.SpecificArtifact }}
+          BuildId: ${{ parameters.BuildId }}
+          ${{ if eq(parameters.cuda_version, '11.8') }}:
+            EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home=$(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8 --cuda_home=$(Agent.TempDirectory)\v11.8  --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80"
+          ${{ if eq(parameters.cuda_version, '12.2') }}:
+            EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home=$(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0 --cuda_home=$(Agent.TempDirectory)\v12.2  --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80"
 
   - ${{ if eq(parameters.enable_linux_gpu, true) }}:
       - template: ../templates/py-linux-gpu.yml
@@ -108,6 +76,10 @@ stages:
           machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU'
           extra_build_arg: ${{ parameters.build_py_parameters }}
           cmake_build_type: ${{ parameters.cmake_build_type }}
-          docker_base_image: ${{ variables.docker_base_image }}
-          trt_version: ${{ variables.linux_trt_version }}
           cuda_version: ${{ parameters.cuda_version }}
+          ${{ if eq(parameters.cuda_version, '11.8') }}:
+            docker_base_image: nvidia/cuda:11.8.0-cudnn8-devel-ubi8
+            trt_version: 8.6.1.6-1.cuda11.8
+          ${{ if eq(parameters.cuda_version, '12.2') }}:
+            docker_base_image: nvidia/cuda:12.2.2-cudnn8-devel-ubi8
+            trt_version: 8.6.1.6-1.cuda12.0
diff --git a/tools/ci_build/github/azure-pipelines/templates/android-binary-size-check-stage.yml b/tools/ci_build/github/azure-pipelines/templates/android-binary-size-check-stage.yml
index 733cafdeeb8c0..9822950127112 100644
--- a/tools/ci_build/github/azure-pipelines/templates/android-binary-size-check-stage.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/android-binary-size-check-stage.yml
@@ -31,7 +31,7 @@ stages:
     timeoutInMinutes: 60
     workspace:
       clean: all
-    pool: onnxruntime-Linux-CPU-For-Android-CI
+    pool: onnxruntime-Ubuntu2204-AMD-CPU
     steps:
     - checkout: self
       clean: true
@@ -49,6 +49,7 @@ stages:
     - task: PythonScript@0
       displayName: 'Set variables from config file "${{ parameters.BuildConfigFile }}"'
       inputs:
+        pythonInterpreter: /usr/bin/python3
         scriptSource: inline
         script: |
           import json
diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml
index 1ba0b02560aca..0bb9fad6716b7 100644
--- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml
@@ -138,7 +138,8 @@ stages:
           --framework_info_file "$(Build.BinariesDirectory)/ios_framework/xcframework_info.json" \
           --c_framework_dir "$(Build.BinariesDirectory)/ios_framework/framework_out" \
           --variant Full \
-          --skip_macos_test
+          --skip_macos_test \
+          --mac_catalyst_enabled
       displayName: "Test Apple framework"
 
     - task: PublishBuildArtifacts@1
diff --git a/tools/ci_build/github/azure-pipelines/templates/check_test_result.yml b/tools/ci_build/github/azure-pipelines/templates/check_test_result.yml
new file mode 100644
index 0000000000000..1a68d415c44d6
--- /dev/null
+++ b/tools/ci_build/github/azure-pipelines/templates/check_test_result.yml
@@ -0,0 +1,20 @@
+parameters:
+- name: FileName
+  type: string
+
+steps:
+  - task: UsePythonVersion@0
+    inputs:
+      versionSpec: '3.x'
+      addToPath: true
+      architecture: 'x64'
+
+  - task: PythonScript@0
+    displayName: 'Check test result file'
+    inputs:
+      scriptSource: 'inline'
+      script: |
+        with open('${{parameters.FileName}}', 'r') as file:
+          content = file.read()
+        assert 'data_onnx_opset' in content, "operator test not found in test result file"
+        assert 'models_zoo_opset' in content, "models_zoo model not found in test result file"
diff --git a/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml b/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml
index c2ef565a6e9ee..3d128fdb78eee 100644
--- a/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml
@@ -5,10 +5,9 @@ parameters:
   default: 'succeeded' # could be 'ci_only', 'always', 'succeeded'
 
 steps:
-- ${{ if eq(variables['System.TeamProject'], 'Lotus') }}: 
-  - task: DeleteFiles@1
-    inputs:
-      contents: $(Build.BinariesDirectory)/*
+- ${{ if eq(variables['System.TeamProject'], 'Lotus') }}:
+  - powershell: |
+      Remove-Item $(Build.BinariesDirectory)/* -Recurse -Force
     displayName: 'Clean up build directory'
 
   - task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0
diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml
index 95e34cd863915..4fd33b4f0bc09 100644
--- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml
@@ -11,7 +11,7 @@ steps:
       packageType: upack
       feed: '/7424c8e4-5c62-490e-95c4-79446f31017c'
       definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0'
-      version: 1.0.133
+      version: 1.0.145
       downloadPath: $(Build.BinariesDirectory)/deps
 
 # The private ADO project
@@ -22,7 +22,7 @@ steps:
       packageType: upack
       feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325'
       definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a'
-      version: 1.0.133
+      version: 1.0.145
       downloadPath: $(Build.BinariesDirectory)/deps
 
 # You can add more ADO accounts at here.
diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml
index 9516753d50113..864513bc4d671 100644
--- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml
@@ -93,8 +93,17 @@ steps:
         $ccache_parent_dir = (Split-Path -parent $ccache_path)
         Copy-Item "C:\ProgramData\chocolatey\lib\ccache\tools\ccache-4.7.4-windows-x86_64\ccache.exe" -Destination "C:\ProgramData\chocolatey\bin\cl.exe"
         Get-ChildItem $ccache_parent_dir
-        ccache --version
       }
+
+      "ccache info:"
+      ccache --version
+      ccache --show-config
+
+      "cl.exe from path: $((Get-Command cl).Path). Version:"
+      (cl.exe -?) -match 'Compiler Version'
+      "C:\ProgramData\chocolatey\bin\cl.exe version:"
+      (C:\ProgramData\chocolatey\bin\cl.exe -?) -match 'Compiler Version'
+
     displayName: Install ccache and update PATH to use linked versions of gcc, cc, etc
 
   - ${{ if eq(parameters.WITHCACHE, true) }}:
diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml
index dd703f3199d9b..30e427a18509d 100644
--- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml
@@ -148,12 +148,9 @@ jobs:
       Get-Volume $("$(Build.BinariesDirectory)")[0]
     displayName: check disk size
 
-  - task: DeleteFiles@1
-    displayName: 'Delete intermedia files from $(Build.BinariesDirectory)\${{ parameters.BuildConfig }}'
-    inputs:
-      SourceFolder: '$(Build.BinariesDirectory)\${{ parameters.BuildConfig }}'
-      Contents: |
-        **/*.obj
+  - powershell: |
+      Remove-Item "$(Build.BinariesDirectory)\${{ parameters.BuildConfig }}" -Include "*.obj" -Recurse
+    displayName: 'Delete intermediate files from $(Build.BinariesDirectory)\${{ parameters.BuildConfig }}'
 
   - powershell: |
       Get-Volume $("$(Build.BinariesDirectory)")[0]
@@ -221,14 +218,6 @@ jobs:
         workingDirectory: '$(Build.BinariesDirectory)\${{ parameters.BuildConfig }}\${{ parameters.BuildConfig }}'
         displayName: 'Run tests'
 
-      - task: PublishTestResults@2
-        displayName: 'Publish unit test results'
-        inputs:
-          testResultsFiles: '**/*.results.xml'
-          searchFolder: '$(Build.BinariesDirectory)/${{ parameters.BuildConfig }}'
-          testRunTitle: 'Unit Test Run'
-        condition: succeededOrFailed()
-
   - ${{ if eq(parameters.GenerateDocumentation, true) }}:
     - task: PythonScript@0
       displayName: 'Generate documentation'
@@ -251,4 +240,4 @@ jobs:
     condition: and(failed(), eq(variables['DocUpdateNeeded'], 'true'))
     inputs:
       pathtoPublish: '$(Build.SourcesDirectory)/docs/ContribOperators.md'
-      artifactName: 'ContribOperators.md'
+      artifactName: 'ContribOperators.md'
\ No newline at end of file
diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-web-init-and-check.yml b/tools/ci_build/github/azure-pipelines/templates/linux-web-init-and-check.yml
index e788e4b3dddaa..a4d5a73118ea2 100644
--- a/tools/ci_build/github/azure-pipelines/templates/linux-web-init-and-check.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/linux-web-init-and-check.yml
@@ -31,6 +31,10 @@ steps:
     node -e "a=require('child_process').execSync('git diff --name-only').toString();if(a)throw new Error('Following source files are not formatted: (did you run \"npm run format\"?)\n'+a)"
   workingDirectory: '$(Build.SourcesDirectory)/js'
   displayName: 'Check unformatted files'
+- script: |
+    npx typedoc --emit none --treatWarningsAsErrors
+  workingDirectory: '$(Build.SourcesDirectory)/js/common'
+  displayName: 'TypeDoc Validation'
 - script: |
     npm run build:doc
   workingDirectory: '$(Build.SourcesDirectory)/js/web'
diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml
index 080079388a76c..945fbb7c4a094 100644
--- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml
@@ -71,7 +71,7 @@ stages:
         ${{ if eq(parameters.DoESRP, true)}}:
           vmImage: 'macOS-12'
         ${{ else }}:
-          vmImage: 'macOS-13'
+          vmImage: 'macOS-latest'
       steps:
       - checkout: none
       - template: flex-downloadPipelineArtifact.yml
diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml
index fd2113502478a..9e192716c3ffd 100644
--- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml
@@ -37,7 +37,7 @@ jobs:
     PROTO_CACHE_DIR: $(Pipeline.Workspace)/ccache_proto
     ORT_CACHE_DIR: $(Pipeline.Workspace)/ccache_ort
   pool:
-    vmImage: 'macOS-13'
+    vmImage: 'macOS-latest'
   timeoutInMinutes: 300
   steps:
   - checkout: self
@@ -55,6 +55,8 @@ jobs:
   - template: set-version-number-variables-step.yml
 
   - template: use-xcode-version.yml
+    parameters:
+      xcodeVersion: 14.2
 
   - template: mac-build-step-with-cache.yml
     parameters:
diff --git a/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml
index cf39be23cbdaf..bfee58e6e5ef9 100644
--- a/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml
@@ -61,21 +61,6 @@ stages:
     buildJava: false
     buildNodejs: false
 
-- template: win-ci.yml
-  parameters:
-    DoCompliance: ${{ parameters.DoCompliance }}
-    DoEsrp: ${{ parameters.DoEsrp }}
-    stage_name_suffix: Training_CPU_arm_${{ parameters.BuildVariant }}
-    artifact_name_suffix: -training
-    buildArch: x64
-    msbuildPlatform: arm
-    packageName: arm
-    buildparameter: --arm ${{ parameters.AdditionalBuildFlags }}  ${{ parameters.AdditionalWinBuildFlags}} --path_to_protoc_exe $(Build.BinariesDirectory)\RelWithDebInfo\installed\bin\protoc.exe
-    runTests: false
-    buildJava: false
-    buildNodejs: false
-    ort_build_pool_name: onnxruntime-Win-CPU-2022
-
 - template: win-ci.yml
   parameters:
     DoCompliance: ${{ parameters.DoCompliance }}
@@ -127,7 +112,6 @@ stages:
   - Linux_C_API_Packaging_Training_CPU
   - Windows_Packaging_Training_CPU_x86_${{ parameters.BuildVariant }}
   - Windows_Packaging_Training_CPU_x64_${{ parameters.BuildVariant }}
-  - Windows_Packaging_Training_CPU_arm_${{ parameters.BuildVariant }}
   - Windows_Packaging_Training_CPU_arm64_${{ parameters.BuildVariant }}
   - Android_Java_API_AAR_Packaging_Training_Full
   condition: succeeded()
@@ -164,12 +148,6 @@ stages:
         artifactName: 'onnxruntime-training-win-arm64'
         targetPath: '$(Build.BinariesDirectory)/nuget-artifact'
 
-    - task: DownloadPipelineArtifact@0
-      displayName: 'Download win-arm Pipeline Artifact'
-      inputs:
-        artifactName: 'onnxruntime-training-win-arm'
-        targetPath: '$(Build.BinariesDirectory)/nuget-artifact'
-
     - task: DownloadPipelineArtifact@0
       displayName: 'Download linux-x64 Pipeline Artifact'
       inputs:
diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml
index 8cc48aac7a3b9..318ffd21febf5 100644
--- a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml
@@ -35,62 +35,66 @@ parameters:
   values:
    - 11.8
    - 12.2
-jobs:
-- job: Linux_py_GPU_Wheels_${{ parameters.arch }}
-  timeoutInMinutes: 240
-  workspace:
-    clean: all
-  pool: ${{ parameters.machine_pool }}
-  variables:
-    # The build machine pool doesn't have dotnet, so it can't run CG.
-    - name: skipComponentGovernanceDetection
-      value: true
-    - name: extra_build_args
-      ${{ if ne(parameters.extra_build_arg, '') }}:
-        value: -x ${{ parameters.extra_build_arg }}
-      ${{ if eq(parameters.extra_build_arg, '') }}:
-        value: ''
-  steps:
-    - checkout: self
-      clean: true
-      submodules: recursive
 
-    - template: set-nightly-build-option-variable-step.yml
+stages:
+- stage: Linux_py_GPU_Wheels_${{ parameters.arch }}
+  dependsOn: []
+  jobs:
+  - job: Linux_py_GPU_Wheels_${{ parameters.arch }}
+    timeoutInMinutes: 240
+    workspace:
+      clean: all
+    pool: ${{ parameters.machine_pool }}
+    variables:
+      # The build machine pool doesn't have dotnet, so it can't run CG.
+      - name: skipComponentGovernanceDetection
+        value: true
+      - name: extra_build_args
+        ${{ if ne(parameters.extra_build_arg, '') }}:
+          value: -x ${{ parameters.extra_build_arg }}
+        ${{ if eq(parameters.extra_build_arg, '') }}:
+          value: ''
+    steps:
+      - checkout: self
+        clean: true
+        submodules: recursive
 
-    - template: get-docker-image-steps.yml
-      parameters:
-        Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda
-        Context: tools/ci_build/github/linux/docker
-        DockerBuildArgs: "
-        --network=host 
-        --build-arg BASEIMAGE=${{ parameters.docker_base_image }}
-        --build-arg TRT_VERSION=${{ parameters.trt_version }}
-        --build-arg BUILD_UID=$( id -u )
-        --build-arg PLATFORM=${{ parameters.arch }}
-        "
-        Repository: onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}xtrt86build${{ parameters.arch }}
+      - template: set-nightly-build-option-variable-step.yml
 
+      - template: get-docker-image-steps.yml
+        parameters:
+          Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda
+          Context: tools/ci_build/github/linux/docker
+          DockerBuildArgs: "
+          --network=host
+          --build-arg BASEIMAGE=${{ parameters.docker_base_image }}
+          --build-arg TRT_VERSION=${{ parameters.trt_version }}
+          --build-arg BUILD_UID=$( id -u )
+          --build-arg PLATFORM=${{ parameters.arch }}
+          "
+          Repository: onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}xtrt86build${{ parameters.arch }}
 
-    - task: Bash@3
-      displayName: 'Build Python Wheel'
-      inputs:
-        targetType: filePath
-        filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh
-        arguments: -i onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}xtrt86build${{ parameters.arch }} -d "GPU" -c ${{ parameters.cmake_build_type }} $(extra_build_args)
 
-    - task: PublishBuildArtifacts@1
-      displayName: 'Publish Artifact: ONNXRuntime python wheel'
-      inputs:
-        PathtoPublish: '$(Build.BinariesDirectory)/dist'
-        ArtifactName: onnxruntime_gpu
+      - task: Bash@3
+        displayName: 'Build Python Wheel'
+        inputs:
+          targetType: filePath
+          filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh
+          arguments: -i onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}xtrt86build${{ parameters.arch }} -d "GPU" -c ${{ parameters.cmake_build_type }} $(extra_build_args)
 
-    - task: PublishPipelineArtifact@0
-      displayName: 'Publish Test Binaries'
-      inputs:
-        artifactName: 'drop-linux-gpu-${{ parameters.arch }}'
-        targetPath: '$(Build.BinariesDirectory)/Release'
+      - task: PublishBuildArtifacts@1
+        displayName: 'Publish Artifact: ONNXRuntime python wheel'
+        inputs:
+          PathtoPublish: '$(Build.BinariesDirectory)/dist'
+          ArtifactName: onnxruntime_gpu
 
+      - task: PublishPipelineArtifact@0
+        displayName: 'Publish Test Binaries'
+        inputs:
+          artifactName: 'drop-linux-gpu-${{ parameters.arch }}'
+          targetPath: '$(Build.BinariesDirectory)/Release'
 
-    - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3
-      displayName: 'Clean Agent Directories'
-      condition: always()
+
+      - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3
+        displayName: 'Clean Agent Directories'
+        condition: always()
diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml
index 146e3e58444c1..5ac5bda8b0964 100644
--- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml
@@ -40,6 +40,11 @@ parameters:
   type: boolean
   default: true
 
+- name: enable_windows_x64_qnn
+  displayName: 'Whether Windows x86_64 package with QNN EP is built.'
+  type: boolean
+  default: true
+
 # TODO: Now the Windows jobs use a different cmake build type. Consider to merge it.
 - name: cmake_build_type
   type: string
@@ -459,3 +464,9 @@ stages:
           QNN_SDK: 'qnn-v2.18.0.240101_win'
           PYTHON_VERSION: '3.11'
           NUMPY_VERSION: '1.25.2'
+
+  - ${{ if eq(parameters.enable_windows_x64_qnn, true) }}:
+      - template: py-win-x64-qnn.yml
+        parameters:
+          MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU'
+          QNN_SDK: 'qnn-v2.18.0.240101_win'
diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage-steps.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage-steps.yml
new file mode 100644
index 0000000000000..024b9b45591ba
--- /dev/null
+++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage-steps.yml
@@ -0,0 +1,230 @@
+parameters:
+  build_py_parameters: ''
+  torch_version: ''
+  opset_version: ''
+  cuda_version: ''
+  cmake_cuda_architectures: ''
+  docker_file: ''
+  upload_wheel: ''
+  debug_build:  ''
+  python_version: ''
+  stage_name: ''
+  SpecificArtifact: false
+  BuildId: '0'
+
+stages:
+  - stage: Build_${{ parameters.stage_name }}
+    variables:
+      - name: isMain
+        value: ${{ or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-')) }}
+      - name: finalStorage
+        ${{ if eq(variables['isMain'], 'true') }}:
+          value: '--final_storage'
+        ${{ else }}:
+          value: ''
+      - name: buildConfig
+        ${{ if eq(parameters['debug_build'], 'true') }}:
+          value: 'Debug'
+        ${{ else }}:
+          value: 'Release'
+      - name: PythonVersion
+        value: ${{ parameters.python_version }}
+      - name: Repository
+        value: onnxruntimetraininggpubuild_${{ parameters.python_version }}
+    dependsOn: []
+
+    jobs:
+    - job: Build
+      pool: onnxruntime-Ubuntu2204-AMD-CPU
+      steps:
+        - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3
+          displayName: 'Clean Agent Directories'
+          condition: always()
+
+        - task: CmdLine@2
+          displayName: 'check variables'
+          inputs:
+            script: |
+              echo "Branch is "${{ variables['Build.SourceBranch'] }} && \
+              echo "isMain is "${{ variables['isMain'] }} && \
+              echo "final_storage is "${{ variables['finalStorage'] }}
+
+        - checkout: self
+          clean: true
+          submodules: recursive
+
+        - template: set-python-manylinux-variables-step.yml
+
+        - template: get-docker-image-steps.yml
+          parameters:
+            Dockerfile: tools/ci_build/github/linux/docker/${{ parameters.docker_file }}
+            Context: tools/ci_build/github/linux/docker
+            DockerBuildArgs: >-
+              --build-arg TORCH_VERSION=${{ parameters.torch_version }}
+              --build-arg OPSET_VERSION=${{ parameters.opset_version }}
+              --build-arg PYTHON_VERSION=${{ parameters.python_version }}
+              --build-arg INSTALL_DEPS_EXTRA_ARGS=-tu
+              --build-arg BUILD_UID=$(id -u)
+              --network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64
+              --build-arg DEVTOOLSET_ROOTPATH=/usr
+              --build-arg PREPEND_PATH=/usr/local/cuda/bin:
+              --build-arg LD_LIBRARY_PATH_ARG=/usr/local/lib64
+            Repository: $(Repository)
+
+        - task: CmdLine@2
+          displayName: 'build onnxruntime'
+          inputs:
+            script: |
+              set -e -x
+              mkdir -p $HOME/.onnx
+              docker run --rm -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" \
+                --volume /data/onnx:/data/onnx:ro \
+                --volume $(Build.SourcesDirectory):/onnxruntime_src \
+                --volume $(Build.BinariesDirectory):/build \
+                --volume /data/models:/build/models:ro \
+                --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \
+                -e NVIDIA_VISIBLE_DEVICES=all \
+                -e NIGHTLY_BUILD \
+                -e DEFAULT_TRAINING_PACKAGE_DEVICE \
+                -e BUILD_BUILDNUMBER \
+                -e ORT_DISABLE_PYTHON_PACKAGE_LOCAL_VERSION \
+                $(Repository) \
+                  $(PythonManylinuxDir)/bin/python3 /onnxruntime_src/tools/ci_build/build.py \
+                    --build_dir /build \
+                    --config ${{ variables['buildConfig'] }} \
+                    --skip_submodule_sync \
+                    --parallel --use_binskim_compliant_compile_flags \
+                    --build_wheel \
+                    --enable_onnx_tests \
+                    ${{ parameters.build_py_parameters }} \
+                    --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=${{ parameters.cmake_cuda_architectures }}' onnxruntime_BUILD_UNIT_TESTS=OFF \
+                    --use_cuda --cuda_version=${{ parameters.cuda_version }} --cuda_home=/usr/local/cuda-${{ parameters.cuda_version }} --cudnn_home=/usr/local/cuda-${{ parameters.cuda_version }};
+            workingDirectory: $(Build.SourcesDirectory)
+
+        - task: CopyFiles@2
+          displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)'
+          inputs:
+            SourceFolder: '$(Build.BinariesDirectory)'
+            Contents: "${{ variables['buildConfig'] }}/dist/*.whl"
+            TargetFolder: '$(Build.ArtifactStagingDirectory)'
+
+        - task: PublishBuildArtifacts@1
+          displayName: 'Publish Artifact: ONNXRuntime python wheel and documentation'
+          inputs:
+            ArtifactName: "onnxruntime_gpu_${{ variables['buildConfig'] }}_${{ parameters.python_version }}"
+
+        - template: component-governance-component-detection-steps.yml
+          parameters:
+            condition: 'succeeded'
+
+        - template: clean-agent-build-directory-step.yml
+
+  - stage: Test_${{ parameters.stage_name }}
+    variables:
+      - name: isMain
+        value: ${{ or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-')) }}
+      - name: finalStorage
+        ${{ if eq(variables['isMain'], 'true') }}:
+          value: '--final_storage'
+        ${{ else }}:
+          value: ''
+      - name: buildConfig
+        ${{ if eq(parameters['debug_build'], 'true') }}:
+          value: 'Debug'
+        ${{ else }}:
+          value: 'Release'
+      - name: PythonVersion
+        value: ${{ parameters.python_version }}
+      - name: Repository
+        value: onnxruntimetraininggpubuild_${{ parameters.python_version }}
+    dependsOn: Build_${{ parameters.stage_name }}
+    jobs:
+    - job: Test_GPU
+      pool: Onnxruntime-Linux-GPU
+      steps:
+        - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3
+          displayName: 'Clean Agent Directories'
+          condition: always()
+
+        - checkout: self
+          clean: true
+          submodules: none
+
+        - template: set-python-manylinux-variables-step.yml
+
+        - template: flex-downloadPipelineArtifact.yml
+          parameters:
+            ArtifactName: "onnxruntime_gpu_${{ variables['buildConfig'] }}_${{ parameters.python_version }}"
+            StepName: 'Download Pipeline Artifact - Linux Training Build'
+            TargetPath: '$(Build.ArtifactStagingDirectory)'
+            SpecificArtifact: ${{ parameters.SpecificArtifact }}
+            BuildId: ${{ parameters.BuildId }}
+
+        - script: |
+            set -e -x
+            whlfilename=$(ls $(Build.ArtifactStagingDirectory)/Release/dist/*.whl | head -n 1)  ; \
+            echo $whlfilename ; du -sh $whlfilename ; \
+            (( $(wc -c < "$whlfilename") -  300*1024*1024 < 0 )) ||  ( echo 'Wheel size bigger than 300M'; exit 1)
+          displayName: 'Check wheel size'
+          continueOnError: true
+
+        - template: get-docker-image-steps.yml
+          parameters:
+            Dockerfile: tools/ci_build/github/linux/docker/${{ parameters.docker_file }}
+            Context: tools/ci_build/github/linux/docker
+            UpdateDepsTxt: false
+            DockerBuildArgs: >-
+              --build-arg TORCH_VERSION=${{ parameters.torch_version }}
+              --build-arg OPSET_VERSION=${{ parameters.opset_version }}
+              --build-arg PYTHON_VERSION=${{ parameters.python_version }}
+              --build-arg INSTALL_DEPS_EXTRA_ARGS=-tu
+              --build-arg BUILD_UID=$(id -u)
+              --network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64
+              --build-arg DEVTOOLSET_ROOTPATH=/usr
+              --build-arg PREPEND_PATH=/usr/local/cuda/bin:
+              --build-arg LD_LIBRARY_PATH_ARG=/usr/local/lib64
+            Repository: $(Repository)
+
+        - bash: tools/ci_build/github/linux/docker/scripts/training/azure_scale_set_vm_mount_test_data.sh -p $(orttrainingtestdatascus-storage-key) -s "//orttrainingtestdatascus.file.core.windows.net/mnist" -d "/mnist"
+          displayName: 'Mount MNIST'
+          condition: succeededOrFailed()
+          workingDirectory: $(Build.SourcesDirectory)
+
+        - bash: tools/ci_build/github/linux/docker/scripts/training/azure_scale_set_vm_mount_test_data.sh -p $(orttrainingtestdatascus-storage-key) -s "//orttrainingtestdatascus.file.core.windows.net/bert-data" -d "/bert_data"
+          displayName: 'Mount bert-data'
+          condition: succeededOrFailed()
+          workingDirectory: $(Build.SourcesDirectory)
+
+        - bash: tools/ci_build/github/linux/docker/scripts/training/azure_scale_set_vm_mount_test_data.sh -p $(orttrainingtestdatascus-storage-key) -s "//orttrainingtestdatascus.file.core.windows.net/hf-models-cache" -d "/hf_models_cache"
+          displayName: 'Mount hf-models-cache'
+          condition: succeededOrFailed()
+          workingDirectory: $(Build.SourcesDirectory)
+
+        - task: CmdLine@2
+          displayName: 'test ortmodule'
+          inputs:
+            script: |
+              set -ex ; \
+              whlfilename=$(ls $(Build.ArtifactStagingDirectory)/Release/dist/*.whl | head -n 1) ; \
+              echo $whlfilename ; \
+              basefilename=$(basename $whlfilename) ; \
+              docker run --rm \
+                --gpus all \
+                -e NVIDIA_VISIBLE_DEVICES=all \
+                --volume $(Build.ArtifactStagingDirectory):/build \
+                --volume /mnist:/mnist \
+                --volume /bert_data:/bert_data \
+                --volume /hf_models_cache:/hf_models_cache \
+                $(Repository) \
+                  bash -c " $(PythonManylinuxDir)/bin/python3 -m pip install /build/Release/dist/$basefilename && $(PythonManylinuxDir)/bin/python3 -m onnxruntime.training.ortmodule.torch_cpp_extensions.install " ;
+            workingDirectory: $(Build.SourcesDirectory)
+
+        - task: CmdLine@2
+          displayName: 'Upload wheel'
+          condition: and(succeeded(), and(eq(variables['UploadWheel'], 'yes'), ne(variables['ORT_DISABLE_PYTHON_PACKAGE_LOCAL_VERSION'], 'true')))
+          inputs:
+            script: |
+              set -e -x
+              whlfilename=$(ls $(Build.ArtifactStagingDirectory)/Release/dist/*.whl | head -n 1) ; \
+              python3 tools/ci_build/upload_python_package_to_azure_storage.py \
+                  --python_wheel_path $whlfilename ${{ variables['finalStorage'] }}
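
The 'Check wheel size' step above enforces a 300 MB cap on the training wheel with shell arithmetic. A minimal Python sketch of the same guard, assuming the wheel sits under a Release/dist directory; the helper name and command-line handling are illustrative, not part of the pipeline:

    # Rough equivalent of the 'Check wheel size' guard; only the 300 MB threshold
    # and the Release/dist location come from the pipeline step above.
    import glob
    import os
    import sys

    MAX_WHEEL_BYTES = 300 * 1024 * 1024  # matches the pipeline's 300 MB limit

    def check_wheel_size(dist_dir):
        wheels = sorted(glob.glob(os.path.join(dist_dir, "*.whl")))
        if not wheels:
            print("no wheel found in", dist_dir)
            return 1
        size = os.path.getsize(wheels[0])
        print(f"{wheels[0]}: {size / (1024 * 1024):.1f} MB")
        return 0 if size < MAX_WHEEL_BYTES else 1

    if __name__ == "__main__":
        sys.exit(check_wheel_size(sys.argv[1] if len(sys.argv) > 1 else "Release/dist"))
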
diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage.yml
index c6921e151a029..f7ecc3cf84e48 100644
--- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage.yml
@@ -47,183 +47,42 @@ parameters:
   type: boolean
   default: false
 
-stages:
-- stage: "Cuda_Python_Packaging_debug_${{ parameters.debug_build }}"
-
-  variables:
-    - name: isMain
-      value: ${{ or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-')) }}
-    - name: finalStorage
-      ${{ if eq(variables['isMain'], 'true') }}:
-        value: '--final_storage'
-      ${{ else }}:
-        value: ''
-    - name: buildConfig
-      ${{ if eq(parameters['debug_build'], 'true') }}:
-        value: 'Debug'
-      ${{ else }}:
-        value: 'Release'
-
-  dependsOn: []
-
-  jobs:
-    - job: Linux_py_Training_Cuda_Wheels
-      timeoutInMinutes: 180
-      workspace:
-        clean: all
-      pool: ${{ parameters.agent_pool }}
-      strategy:
-        matrix:
-          Python38:
-            PythonVersion: '3.8'
-            TorchVersion: ${{ parameters.torch_version }}
-            OpsetVersion: ${{ parameters.opset_version }}
-            CudaVersion: ${{ parameters.cuda_version }}
-            UploadWheel: ${{ parameters.upload_wheel }}
-          Python39:
-            PythonVersion: '3.9'
-            TorchVersion: ${{ parameters.torch_version }}
-            OpsetVersion: ${{ parameters.opset_version }}
-            CudaVersion: ${{ parameters.cuda_version }}
-            UploadWheel: ${{ parameters.upload_wheel }}
-          Python310:
-            PythonVersion: '3.10'
-            TorchVersion: ${{ parameters.torch_version }}
-            OpsetVersion: ${{ parameters.opset_version }}
-            CudaVersion: ${{ parameters.cuda_version }}
-            UploadWheel: ${{ parameters.upload_wheel }}
-          Python311:
-            PythonVersion: '3.11'
-            TorchVersion: ${{ parameters.torch_version }}
-            OpsetVersion: ${{ parameters.opset_version }}
-            CudaVersion: ${{ parameters.cuda_version }}
-            UploadWheel: ${{ parameters.upload_wheel }}
-# TODO: enable this when we have torch support pyton 3.12
-#          Python312:
-#            PythonVersion: '3.12'
-#            TorchVersion: ${{ parameters.torch_version }}
-#            OpsetVersion: ${{ parameters.opset_version }}
-#            CudaVersion: ${{ parameters.cuda_version }}
-#            UploadWheel: ${{ parameters.upload_wheel }}
-
-      steps:
-      - task: CmdLine@2
-        displayName: 'check variables'
-        inputs:
-          script: |
-            echo "Branch is "${{ variables['Build.SourceBranch'] }} && \
-            echo "isMain is "${{ variables['isMain'] }} && \
-            echo "final_storage is "${{ variables['finalStorage'] }}
-
-      - checkout: self
-        clean: true
-        submodules: recursive
-
-      - template: set-python-manylinux-variables-step.yml
-
-      - template: get-docker-image-steps.yml
-        parameters:
-          Dockerfile: tools/ci_build/github/linux/docker/${{ parameters.docker_file }}
-          Context: tools/ci_build/github/linux/docker
-          DockerBuildArgs: >-
-            --build-arg TORCH_VERSION=$(TorchVersion)
-            --build-arg OPSET_VERSION=$(OpsetVersion)
-            --build-arg PYTHON_VERSION=$(PythonVersion)
-            --build-arg INSTALL_DEPS_EXTRA_ARGS=-tu
-            --build-arg BUILD_UID=$(id -u)
-            --network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64
-            --build-arg DEVTOOLSET_ROOTPATH=/usr
-            --build-arg PREPEND_PATH=/usr/local/cuda/bin:
-            --build-arg LD_LIBRARY_PATH_ARG=/usr/local/lib64
-          Repository: onnxruntimetraininggpubuild
-
-      - bash: tools/ci_build/github/linux/docker/scripts/training/azure_scale_set_vm_mount_test_data.sh -p $(orttrainingtestdatascus-storage-key) -s "//orttrainingtestdatascus.file.core.windows.net/mnist" -d "/mnist"
-        displayName: 'Mount MNIST'
-        condition: succeededOrFailed()
-
-      - bash: tools/ci_build/github/linux/docker/scripts/training/azure_scale_set_vm_mount_test_data.sh -p $(orttrainingtestdatascus-storage-key) -s "//orttrainingtestdatascus.file.core.windows.net/bert-data" -d "/bert_data"
-        displayName: 'Mount bert-data'
-        condition: succeededOrFailed()
-
-      - bash: tools/ci_build/github/linux/docker/scripts/training/azure_scale_set_vm_mount_test_data.sh -p $(orttrainingtestdatascus-storage-key) -s "//orttrainingtestdatascus.file.core.windows.net/hf-models-cache" -d "/hf_models_cache"
-        displayName: 'Mount hf-models-cache'
-        condition: succeededOrFailed()
-
-      - task: CmdLine@2
-        displayName: 'build onnxruntime'
-        inputs:
-          script: |
-            set -e -x
-            mkdir -p $HOME/.onnx
-            docker run --rm -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" \
-              --volume /data/onnx:/data/onnx:ro \
-              --volume $(Build.SourcesDirectory):/onnxruntime_src \
-              --volume $(Build.BinariesDirectory):/build \
-              --volume /data/models:/build/models:ro \
-              --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \
-              -e NVIDIA_VISIBLE_DEVICES=all \
-              -e NIGHTLY_BUILD \
-              -e DEFAULT_TRAINING_PACKAGE_DEVICE \
-              -e BUILD_BUILDNUMBER \
-              -e ORT_DISABLE_PYTHON_PACKAGE_LOCAL_VERSION \
-              onnxruntimetraininggpubuild \
-                $(PythonManylinuxDir)/bin/python3 /onnxruntime_src/tools/ci_build/build.py \
-                  --build_dir /build \
-                  --config ${{ variables['buildConfig'] }} \
-                  --skip_submodule_sync \
-                  --parallel --use_binskim_compliant_compile_flags \
-                  --build_wheel \
-                  --enable_onnx_tests \
-                  ${{ parameters.build_py_parameters }} \
-                  --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=${{ parameters.cmake_cuda_architectures }}' onnxruntime_BUILD_UNIT_TESTS=OFF \
-                  --use_cuda --cuda_version=$(CudaVersion) --cuda_home=/usr/local/cuda-$(CudaVersion) --cudnn_home=/usr/local/cuda-$(CudaVersion) ;
-          workingDirectory: $(Build.SourcesDirectory)
-
-      - task: CmdLine@2
-        displayName: 'test ortmodule'
-        inputs:
-          script: |
-            rm -rf $(Build.BinariesDirectory)/${{ variables['buildConfig'] }}/onnxruntime/ && \
-            files=($(Build.BinariesDirectory)/${{ variables['buildConfig'] }}/dist/*.whl) && \
-            echo ${files[0]} && \
-            whlfilename=$(basename ${files[0]}) && \
-            echo $whlfilename && \
-            docker run --rm \
-              --gpus all \
-              -e NVIDIA_VISIBLE_DEVICES=all \
-              --volume $(Build.BinariesDirectory):/build \
-              --volume /mnist:/mnist \
-              --volume /bert_data:/bert_data \
-              --volume /hf_models_cache:/hf_models_cache \
-              onnxruntimetraininggpubuild \
-                bash -c " $(PythonManylinuxDir)/bin/python3 -m pip install /build/${{ variables['buildConfig'] }}/dist/$whlfilename && $(PythonManylinuxDir)/bin/python3 -m onnxruntime.training.ortmodule.torch_cpp_extensions.install " ;
-          workingDirectory: $(Build.SourcesDirectory)
-
-      - task: CopyFiles@2
-        displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)'
-        inputs:
-          SourceFolder: '$(Build.BinariesDirectory)'
-          Contents: "${{ variables['buildConfig'] }}/dist/*.whl"
-          TargetFolder: '$(Build.ArtifactStagingDirectory)'
-
-      - task: PublishBuildArtifacts@1
-        displayName: 'Publish Artifact: ONNXRuntime python wheel and documentation'
-        inputs:
-          ArtifactName: "onnxruntime_gpu_${{ variables['buildConfig'] }}"
-
-      - task: CmdLine@2
-        displayName: 'Upload wheel'
-        condition: and(succeeded(), and(eq(variables['UploadWheel'], 'yes'), ne(variables['ORT_DISABLE_PYTHON_PACKAGE_LOCAL_VERSION'], 'true')))
-        inputs:
-          script: |
-            set -e -x
-            files=($(Build.ArtifactStagingDirectory)/${{ variables['buildConfig'] }}/dist/*.whl) && \
-            echo ${files[0]} && \
-            python3 tools/ci_build/upload_python_package_to_azure_storage.py \
-                --python_wheel_path ${files[0]} ${{ variables['finalStorage'] }}
+- name: SpecificArtifact
+  displayName: Use Specific Artifact
+  type: boolean
+  default: false
 
-      - template: component-governance-component-detection-steps.yml
-        parameters:
-          condition: 'succeeded'
+- name: BuildId
+  displayName: Specific Artifact's BuildId
+  type: string
+  default: '0'
+
+- name: PythonVersionList
+  displayName: Python Version List
+  type: object
+  default:
+    - name: '38'
+      version: '3.8'
+    - name: '39'
+      version: '3.9'
+    - name: '310'
+      version: '3.10'
+    - name: '311'
+      version: '3.11'
 
-      - template: clean-agent-build-directory-step.yml
+stages:
+- ${{ each python_version in parameters.PythonVersionList }}:
+  - template: py-packaging-training-cuda-stage-steps.yml
+    parameters:
+      build_py_parameters: ${{ parameters.build_py_parameters }}
+      torch_version: ${{ parameters.torch_version }}
+      opset_version: ${{ parameters.opset_version }}
+      cuda_version: ${{ parameters.cuda_version }}
+      cmake_cuda_architectures: ${{ parameters.cmake_cuda_architectures }}
+      docker_file: ${{ parameters.docker_file }}
+      upload_wheel: ${{ parameters.upload_wheel }}
+      debug_build: ${{ parameters.debug_build }}
+      stage_name: 'Linux_py_Training_Cuda_Wheels_${{ python_version.name }}'
+      python_version: ${{ python_version.version }}
+      SpecificArtifact: ${{ parameters.SpecificArtifact }}
+      BuildId: ${{ parameters.BuildId }}
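
The refactor above replaces the per-Python matrix with a PythonVersionList object parameter and a template loop. A small Python sketch of how that loop expands into stage names; Azure DevOps performs the expansion at template-compile time, so the code below is purely illustrative:

    # Illustrates the '${{ each python_version in parameters.PythonVersionList }}' expansion:
    # one training-wheel stage per list entry, named after the dotless version tag.
    python_version_list = [
        {"name": "38", "version": "3.8"},
        {"name": "39", "version": "3.9"},
        {"name": "310", "version": "3.10"},
        {"name": "311", "version": "3.11"},
    ]
    for pv in python_version_list:
        print(f"Linux_py_Training_Cuda_Wheels_{pv['name']} -> python_version {pv['version']}")
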
diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml
index 18368e59cad52..17915d107dbe6 100644
--- a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml
@@ -1,8 +1,4 @@
 parameters:
-
-- name: MACHINE_POOL
-  type: string
-
 - name: EP_NAME
   type: string
 
@@ -27,169 +23,257 @@ parameters:
   values:
     - 11.8
     - 12.2
-jobs:
-- job: Win_py_${{ parameters.EP_NAME }}_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }}
-  timeoutInMinutes: 240
-  workspace:
-    clean: all
-  pool:
-    name: ${{ parameters.MACHINE_POOL }}
-#    demands:
-#      - ImageVersionOverride -equals 1.0.367516
-  variables:
-    GRADLE_OPTS: '-Dorg.gradle.daemon=false'
-    VSGenerator: 'Visual Studio 17 2022'
-    CUDA_MODULE_LOADING: 'LAZY'
-  steps:
-      - checkout: self
-        clean: true
-        submodules: recursive
-
-      - template: telemetry-steps.yml
-
-      - task: UsePythonVersion@0
-        inputs:
-          versionSpec: ${{ parameters.PYTHON_VERSION }}
-          addToPath: true
-          architecture: 'x64'
-
-      - task: onebranch.pipeline.tsaoptions@1
-        displayName: 'OneBranch TSAOptions'
-        inputs:
-          tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json'
-          appendSourceBranchName: false
-
-      - task: PythonScript@0
-        inputs:
-          scriptSource: inline
-          script: |
-            import sys
-            np_version = 'numpy==1.21.6' if sys.version_info < (3, 11) else 'numpy==1.24.2'
-            import subprocess
-            subprocess.call(['pip', 'install', '-q', 'setuptools', 'wheel', np_version])
-          workingDirectory: '$(Build.BinariesDirectory)'
-          displayName: 'Install python modules'
-
-      - template: download-deps.yml
-
-      - ${{ if ne(parameters.ENV_SETUP_SCRIPT, '') }}:
-        - template: jobs/set-winenv.yml
+
+- name: SpecificArtifact
+  displayName: Use Specific Artifact
+  type: boolean
+  default: false
+
+- name: BuildId
+  displayName: Specific Artifact's BuildId
+  type: string
+  default: '0'
+
+stages:
+  - stage: Win_py_${{ parameters.EP_NAME }}_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }}_Build
+    dependsOn: []
+    jobs:
+    - job: Win_py_${{ parameters.EP_NAME }}_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }}_Build
+      timeoutInMinutes: 120
+      workspace:
+        clean: all
+      pool:
+        name: onnxruntime-Win-CPU-2022
+    #    demands:
+    #      - ImageVersionOverride -equals 1.0.367516
+      variables:
+        GRADLE_OPTS: '-Dorg.gradle.daemon=false'
+        VSGenerator: 'Visual Studio 17 2022'
+        CUDA_MODULE_LOADING: 'LAZY'
+      steps:
+          - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3
+            displayName: 'Clean Agent Directories'
+            condition: always()
+
+          - checkout: self
+            clean: true
+            submodules: recursive
+
+          - template: telemetry-steps.yml
+
+          - task: UsePythonVersion@0
+            inputs:
+              versionSpec: ${{ parameters.PYTHON_VERSION }}
+              addToPath: true
+              architecture: 'x64'
+
+          - task: onebranch.pipeline.tsaoptions@1
+            displayName: 'OneBranch TSAOptions'
+            inputs:
+              tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json'
+              appendSourceBranchName: false
+
+          - task: PythonScript@0
+            inputs:
+              scriptSource: inline
+              script: |
+                import sys
+                np_version = 'numpy==1.21.6' if sys.version_info < (3, 11) else 'numpy==1.26'
+                import subprocess
+                try:
+                  subprocess.check_call(['pip', 'install', '-q', 'setuptools', 'wheel', np_version])
+                except subprocess.CalledProcessError:
+                  sys.exit(1)
+              workingDirectory: '$(Build.BinariesDirectory)'
+              displayName: 'Install python modules'
+
+          - template: download-deps.yml
+
+          - ${{ if ne(parameters.ENV_SETUP_SCRIPT, '') }}:
+            - template: jobs/set-winenv.yml
+              parameters:
+                EnvSetupScript: ${{ parameters.ENV_SETUP_SCRIPT }}
+                ${{ if or(contains(parameters.EP_BUILD_FLAGS, 'use_cuda'), contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt')) }}:
+                  DownloadCUDA: true
+
+          - ${{ if eq(parameters.ENV_SETUP_SCRIPT, '') }}:
+            - template: jobs/download_win_gpu_library.yml
+              parameters:
+                CudaVersion: ${{ parameters.CudaVersion }}
+                ${{ if or(contains(parameters.EP_BUILD_FLAGS, 'use_cuda'), contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt')) }}:
+                  DownloadCUDA: true
+                ${{ if contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt') }}:
+                  DownloadTRT: true
+
+          - task: PythonScript@0
+            displayName: 'Update deps.txt'
+            inputs:
+              scriptPath: $(Build.SourcesDirectory)/tools/ci_build/replace_urls_in_deps.py
+              arguments: --new_dir $(Build.BinariesDirectory)/deps
+              workingDirectory: $(Build.BinariesDirectory)
+
+          - task: PowerShell@2
+            displayName: 'Install ONNX'
+            inputs:
+              filePath: '$(Build.SourcesDirectory)/tools/ci_build/github/windows/install_third_party_deps.ps1'
+              workingDirectory: '$(Build.BinariesDirectory)'
+              arguments: -cpu_arch x64 -install_prefix $(Build.BinariesDirectory)\RelWithDebInfo\installed -build_config RelWithDebInfo
+
+          # This can be removed once there's an onnx wheel for Python 3.12
+          - ${{ if eq(parameters.PYTHON_VERSION, '3.12') }}:
+            - task: PublishPipelineArtifact@1
+              displayName: 'Publish Artifact: ONNX Python 3.12 wheel'
+              inputs:
+                targetPath: '$(Agent.TempDirectory)\onnx\onnx-1.15.0\dist\'
+                publishLocation: 'pipeline'
+                artifactName: onnx_py12_wheel
+
+          - template: set-nightly-build-option-variable-step.yml
+
+          - task: PythonScript@0
+            displayName: 'Generate cmake config'
+            inputs:
+              scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py'
+              arguments: >
+                --config RelWithDebInfo
+                --build_dir $(Build.BinariesDirectory)
+                --skip_submodule_sync
+                --cmake_generator "$(VSGenerator)"
+                --enable_pybind
+                --enable_onnx_tests
+                --parallel --use_binskim_compliant_compile_flags --update
+                $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} ${{ parameters.EP_BUILD_FLAGS }}
+              workingDirectory: '$(Build.BinariesDirectory)'
+
+          # building with build.py so the parallelization parameters are added to the msbuild command
+          - task: PythonScript@0
+            displayName: 'Build'
+            inputs:
+              scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py'
+              arguments: >
+                --config RelWithDebInfo
+                --build_dir $(Build.BinariesDirectory)
+                --parallel --build
+                $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} ${{ parameters.EP_BUILD_FLAGS }}
+              workingDirectory: '$(Build.BinariesDirectory)'
+
+          # Esrp signing
+          - template: win-esrp-dll.yml
+            parameters:
+              FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime\capi'
+              DisplayName: 'ESRP - Sign Native dlls'
+              DoEsrp: true
+              Pattern: '*.pyd,*.dll'
+
+          - task: PythonScript@0
+            displayName: 'Build wheel'
+            inputs:
+              scriptPath: '$(Build.SourcesDirectory)\setup.py'
+              arguments: 'bdist_wheel ${{ parameters.BUILD_PY_PARAMETERS }} $(NightlyBuildOption) --wheel_name_suffix=${{ parameters.EP_NAME }}'
+              workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo'
+
+          - task: CopyFiles@2
+            displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)'
+            inputs:
+              SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\dist'
+              Contents: '*.whl'
+              TargetFolder: '$(Build.ArtifactStagingDirectory)'
+
+          - task: PublishBuildArtifacts@1
+            displayName: 'Publish Artifact: ONNXRuntime python wheel'
+            inputs:
+              ArtifactName: onnxruntime_${{ parameters.EP_NAME }}
+
+          - script: |
+              7z x *.whl
+            workingDirectory: '$(Build.ArtifactStagingDirectory)'
+            displayName: 'unzip the package'
+
+          - task: CredScan@3
+            displayName: 'Run CredScan'
+            inputs:
+              debugMode: false
+            continueOnError: true
+
+          - task: BinSkim@4
+            displayName: 'Run BinSkim'
+            inputs:
+              AnalyzeTargetGlob: '+:file|$(Build.ArtifactStagingDirectory)\**\*.dll;-:file|$(Build.ArtifactStagingDirectory)\**\DirectML.dll'
+
+  - stage: Win_py_${{ parameters.EP_NAME }}_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }}_Tests
+    dependsOn: Win_py_${{ parameters.EP_NAME }}_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }}_Build
+    jobs:
+    - job: Win_py_${{ parameters.EP_NAME }}_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }}_Tests
+      workspace:
+        clean: all
+      pool:
+        name: onnxruntime-Win2022-GPU-T4
+      steps:
+        - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3
+          displayName: 'Clean Agent Directories'
+          condition: always()
+
+        - checkout: self
+          clean: true
+          submodules: none
+
+        - task: UsePythonVersion@0
+          inputs:
+            versionSpec: ${{ parameters.PYTHON_VERSION }}
+            addToPath: true
+            architecture: 'x64'
+
+        - template: flex-downloadPipelineArtifact.yml
           parameters:
-            EnvSetupScript: ${{ parameters.ENV_SETUP_SCRIPT }}
-            ${{ if or(contains(parameters.EP_BUILD_FLAGS, 'use_cuda'), contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt')) }}:
-              DownloadCUDA: true
+            ArtifactName: "onnxruntime_${{ parameters.EP_NAME }}"
+            StepName: 'Download Pipeline Artifact - Windows GPU Build'
+            TargetPath: '$(Build.ArtifactStagingDirectory)'
+            SpecificArtifact: ${{ parameters.SpecificArtifact }}
+            BuildId: ${{ parameters.BuildId }}
+
+        # This can be removed once there's an onnx wheel for Python 3.12
+        - ${{ if eq(parameters.PYTHON_VERSION, '3.12') }}:
+          - template: flex-downloadPipelineArtifact.yml
+            parameters:
+              ArtifactName: "onnx_py12_wheel"
+              StepName: 'Download Pipeline Artifact - ONNX Python 3.12 wheel'
+              TargetPath: '$(Agent.TempDirectory)\onnx\'
+              SpecificArtifact: ${{ parameters.SpecificArtifact }}
+              BuildId: ${{ parameters.BuildId }}
+
+          - powershell: |
+              python -m pip install --upgrade pip
+              Get-ChildItem -Path $(Agent.TempDirectory)\onnx\*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname tabulate}
+              python -m pip install pytest
+            workingDirectory: '$(Build.SourcesDirectory)'
+            displayName: 'Install ONNX and pytest'
+        - ${{ else }}:
+          - powershell: |
+              pushd onnxruntime/test/python
+              python -m pip install --upgrade pip
+              python -m pip install -r requirements.txt
+              popd
+            workingDirectory: '$(Build.SourcesDirectory)'
+            displayName: 'Install ONNX'
+
+        - powershell: |
+            python -m pip uninstall -y ort-nightly-gpu ort-nightly onnxruntime onnxruntime-gpu -qq
+            Get-ChildItem -Path $(Build.ArtifactStagingDirectory)/*cp${{ replace(parameters.PYTHON_VERSION,'.','') }}*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname tabulate}
+            mkdir -p $(Agent.TempDirectory)\ort_test_data
+            Copy-Item -Path $(Build.sourcesDirectory)/onnxruntime/test/python/onnx_backend_test_series.py -Destination $(Agent.TempDirectory)\ort_test_data
+            Copy-Item -Recurse -Path $(Build.sourcesDirectory)/onnxruntime/test/testdata -Destination $(Agent.TempDirectory)\ort_test_data
+            cd $(Agent.TempDirectory)\ort_test_data
+            python onnx_backend_test_series.py
+          workingDirectory: '$(Build.sourcesDirectory)'
+          displayName: 'Run Python Tests'
+
+        - task: TSAUpload@2
+          displayName: 'TSA upload'
+          condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main'))
+          inputs:
+            GdnPublishTsaOnboard: false
+            GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa'
 
-      - ${{ if eq(parameters.ENV_SETUP_SCRIPT, '') }}:
-        - template: jobs/download_win_gpu_library.yml
+        - template: component-governance-component-detection-steps.yml
           parameters:
-            CudaVersion: ${{ parameters.CudaVersion }}
-            ${{ if or(contains(parameters.EP_BUILD_FLAGS, 'use_cuda'), contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt')) }}:
-              DownloadCUDA: true
-            ${{ if contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt') }}:
-              DownloadTRT: true
-
-      - task: PythonScript@0
-        displayName: 'Update deps.txt'
-        inputs:
-          scriptPath: $(Build.SourcesDirectory)/tools/ci_build/replace_urls_in_deps.py
-          arguments: --new_dir $(Build.BinariesDirectory)/deps
-          workingDirectory: $(Build.BinariesDirectory)
-
-      - task: PowerShell@2
-        displayName: 'Install ONNX'
-        inputs:
-          filePath: '$(Build.SourcesDirectory)/tools/ci_build/github/windows/install_third_party_deps.ps1'
-          workingDirectory: '$(Build.BinariesDirectory)'
-          arguments: -cpu_arch x64 -install_prefix $(Build.BinariesDirectory)\RelWithDebInfo\installed -build_config RelWithDebInfo
-
-      - template: set-nightly-build-option-variable-step.yml
-
-
-      - task: PythonScript@0
-        displayName: 'Generate cmake config'
-        inputs:
-          scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py'
-          arguments: >
-            --config RelWithDebInfo
-            --build_dir $(Build.BinariesDirectory)
-            --skip_submodule_sync
-            --cmake_generator "$(VSGenerator)"
-            --enable_pybind
-            --enable_onnx_tests
-            --parallel --use_binskim_compliant_compile_flags --update
-            $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} ${{ parameters.EP_BUILD_FLAGS }}
-          workingDirectory: '$(Build.BinariesDirectory)'
-
-      - task: VSBuild@1
-        displayName: 'Build'
-        inputs:
-          solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln'
-          platform: x64
-          configuration: RelWithDebInfo
-          msbuildArchitecture: $(buildArch)
-          maximumCpuCount: true
-          logProjectEvents: true
-          workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo'
-          createLogFile: true
-
-      # Esrp signing
-      - template: win-esrp-dll.yml
-        parameters:
-          FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime\capi'
-          DisplayName: 'ESRP - Sign Native dlls'
-          DoEsrp: true
-          Pattern: '*.pyd,*.dll'
-
-      - task: PythonScript@0
-        displayName: 'Build wheel'
-        inputs:
-          scriptPath: '$(Build.SourcesDirectory)\setup.py'
-          arguments: 'bdist_wheel ${{ parameters.BUILD_PY_PARAMETERS }} $(NightlyBuildOption) --wheel_name_suffix=${{ parameters.EP_NAME }}'
-          workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo'
-
-      - task: CopyFiles@2
-        displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)'
-        inputs:
-          SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\dist'
-          Contents: '*.whl'
-          TargetFolder: '$(Build.ArtifactStagingDirectory)'
-
-      - task: PublishBuildArtifacts@1
-        displayName: 'Publish Artifact: ONNXRuntime python wheel'
-        inputs:
-          ArtifactName: onnxruntime_${{ parameters.EP_NAME }}
-
-      - script: |
-          7z x *.whl
-        workingDirectory: '$(Build.ArtifactStagingDirectory)'
-        displayName: 'unzip the package'
-
-      - task: CredScan@3
-        displayName: 'Run CredScan'
-        inputs:
-          debugMode: false
-        continueOnError: true
-
-      - task: BinSkim@4
-        displayName: 'Run BinSkim'
-        inputs:
-          AnalyzeTargetGlob: '+:file|$(Build.ArtifactStagingDirectory)\**\*.dll;-:file|$(Build.ArtifactStagingDirectory)\**\DirectML.dll'
-
-      - powershell: |
-         python -m pip uninstall -y ort-nightly-gpu ort-nightly onnxruntime onnxruntime-gpu -qq
-         Get-ChildItem -Path $(Build.ArtifactStagingDirectory)/*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname tabulate}
-         Remove-Item -Recurse -Force onnxruntime
-         python onnx_backend_test_series.py
-        workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo'
-        displayName: 'Run Python Tests'
-
-      - task: TSAUpload@2
-        displayName: 'TSA upload'
-        condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main'))
-        inputs:
-          GdnPublishTsaOnboard: false
-          GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' 
-
-      - template: component-governance-component-detection-steps.yml
-        parameters:
-          condition: 'succeeded'
+            condition: 'succeeded'
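
The 'Run Python Tests' step in the new Tests stage picks the wheel whose file name contains the CPython tag derived from PYTHON_VERSION with the dot stripped (for example, 3.10 becomes cp310). A small sketch of that matching rule; the file names below are made up for illustration only:

    # Mirrors the '*cp${{ replace(parameters.PYTHON_VERSION,'.','') }}*.whl' filter above.
    def wheel_matches(filename, python_version):
        tag = "cp" + python_version.replace(".", "")
        return filename.endswith(".whl") and tag in filename

    # Hypothetical file names, used only to illustrate the matching behaviour.
    assert wheel_matches("onnxruntime_gpu-1.17.0-cp310-cp310-win_amd64.whl", "3.10")
    assert not wheel_matches("onnxruntime_gpu-1.17.0-cp39-cp39-win_amd64.whl", "3.10")
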
diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml
new file mode 100644
index 0000000000000..30f21e933ee36
--- /dev/null
+++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml
@@ -0,0 +1,177 @@
+parameters:
+
+- name: MACHINE_POOL
+  type: string
+  default: 'Onnxruntime-QNNEP-Windows-2022-CPU'
+
+- name: QNN_SDK
+  displayName: QNN Windows SDK path
+  type: string
+  default: qnn-v2.18.0.240101_win
+
+- name: ENV_SETUP_SCRIPT
+  type: string
+  default: ''
+
+- name: BUILD_PY_PARAMETERS
+  displayName: >
+    Extra parameters to pass to build.py. Don't put newlines in here.
+  type: string
+  default: ''
+
+jobs:
+- job: Win_py_x64_qnn_Wheels
+  timeoutInMinutes: 210
+  workspace:
+    clean: all
+  pool:
+    name: ${{ parameters.MACHINE_POOL }}
+  strategy:
+    matrix:
+      Python38_x64:
+        PythonVersion: '3.8'
+      Python39_x64:
+        PythonVersion: '3.9'
+      Python310_x64:
+        PythonVersion: '3.10'
+      Python311_x64:
+        PythonVersion: '3.11'
+      Python312_x64:
+        PythonVersion: '3.12'
+  variables:
+    GRADLE_OPTS: '-Dorg.gradle.daemon=false'
+    VSGenerator: 'Visual Studio 17 2022'
+    QNN_SDK_ROOTDIR: 'C:\data\qnnsdk\${{parameters.QNN_SDK}}'
+  steps:
+      - checkout: self
+        clean: true
+        submodules: recursive
+
+      - template: telemetry-steps.yml
+
+      - script: |
+          DIR C:\data\qnnsdk
+        displayName: Check available QNN SDKs
+
+      - task: UsePythonVersion@0
+        inputs:
+          versionSpec: $(PythonVersion)
+          addToPath: true
+          architecture: 'x64'
+
+      - task: onebranch.pipeline.tsaoptions@1
+        displayName: 'OneBranch TSAOptions'
+        inputs:
+          tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json'
+          appendSourceBranchName: false
+
+      - task: PythonScript@0
+        inputs:
+          scriptSource: inline
+          script: |
+            import sys
+            np_version = 'numpy==1.21.6' if sys.version_info < (3, 11) else 'numpy==1.24.2'
+            import subprocess
+            subprocess.call(['pip', 'install', '-q', 'setuptools', 'wheel', np_version])
+          workingDirectory: '$(Build.BinariesDirectory)'
+          displayName: 'Install python modules'
+
+      - template: download-deps.yml
+
+      - task: PythonScript@0
+        displayName: 'Update deps.txt'
+        inputs:
+          scriptPath: $(Build.SourcesDirectory)/tools/ci_build/replace_urls_in_deps.py
+          arguments: --new_dir $(Build.BinariesDirectory)/deps
+          workingDirectory: $(Build.BinariesDirectory)
+
+      - task: PowerShell@2
+        displayName: 'Install ONNX'
+        inputs:
+          filePath: '$(Build.SourcesDirectory)/tools/ci_build/github/windows/install_third_party_deps.ps1'
+          workingDirectory: '$(Build.BinariesDirectory)'
+          arguments: -cpu_arch x64 -install_prefix $(Build.BinariesDirectory)\RelWithDebInfo\installed -build_config RelWithDebInfo
+
+      - template: set-nightly-build-option-variable-step.yml
+
+      - task: PythonScript@0
+        displayName: 'Generate cmake config'
+        inputs:
+          scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py'
+          arguments: >
+            --config RelWithDebInfo
+            --build_dir $(Build.BinariesDirectory)
+            --skip_submodule_sync
+            --cmake_generator "$(VSGenerator)"
+            --use_qnn
+            --qnn_home $(QNN_SDK_ROOTDIR)
+            --enable_pybind
+            --parallel --update
+            $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }}
+          workingDirectory: '$(Build.BinariesDirectory)'
+
+      - task: VSBuild@1
+        displayName: 'Build'
+        inputs:
+          solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln'
+          platform: 'x64'
+          configuration: RelWithDebInfo
+          msbuildArchitecture: 'x64'
+          maximumCpuCount: true
+          logProjectEvents: true
+          workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo'
+          createLogFile: true
+
+      # Esrp signing
+      - template: win-esrp-dll.yml
+        parameters:
+          FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime\capi'
+          DisplayName: 'ESRP - Sign Native dlls'
+          DoEsrp: true
+          Pattern: '*.pyd,*.dll'
+
+      - task: PythonScript@0
+        displayName: 'Build wheel'
+        inputs:
+          scriptPath: '$(Build.SourcesDirectory)\setup.py'
+          arguments: 'bdist_wheel ${{ parameters.BUILD_PY_PARAMETERS }} $(NightlyBuildOption) --wheel_name_suffix=qnn'
+          workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo'
+
+      - task: CopyFiles@2
+        displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)'
+        inputs:
+          SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\dist'
+          Contents: '*.whl'
+          TargetFolder: '$(Build.ArtifactStagingDirectory)'
+
+      - task: PublishBuildArtifacts@1
+        displayName: 'Publish Artifact: ONNXRuntime python wheel'
+        inputs:
+          ArtifactName: onnxruntime_qnn
+
+      - script: |
+          7z x *.whl
+        workingDirectory: '$(Build.ArtifactStagingDirectory)'
+        displayName: 'unzip the package'
+
+      - task: CredScan@3
+        displayName: 'Run CredScan'
+        inputs:
+          debugMode: false
+        continueOnError: true
+
+      - task: BinSkim@4
+        displayName: 'Run BinSkim'
+        inputs:
+          AnalyzeTargetGlob: '+:file|$(Build.ArtifactStagingDirectory)\**\*.dll'
+
+      - task: TSAUpload@2
+        displayName: 'TSA upload'
+        condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main'))
+        inputs:
+          GdnPublishTsaOnboard: false
+          GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' 
+
+      - template: component-governance-component-detection-steps.yml
+        parameters:
+          condition: 'succeeded'
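
The new QNN template installs a Python-version-dependent numpy pin through an inline PythonScript step that uses subprocess.call, which does not fail the step if pip errors out. The py-win-gpu.yml template above was switched to check_call for exactly that reason; a sketch of the stricter variant, offered as an editorial suggestion rather than part of this patch:

    # Same numpy pin as the 'Install python modules' step, but failing fast on pip errors.
    import subprocess
    import sys

    np_version = 'numpy==1.21.6' if sys.version_info < (3, 11) else 'numpy==1.24.2'
    try:
        subprocess.check_call(['pip', 'install', '-q', 'setuptools', 'wheel', np_version])
    except subprocess.CalledProcessError:
        sys.exit(1)  # propagate the failure so the pipeline step goes red
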
diff --git a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml
index 47cd72f412c67..1b7962059e301 100644
--- a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml
@@ -279,7 +279,7 @@ stages:
 
     - script: |
         JEST_JUNIT_OUTPUT_FILE=$(Build.SourcesDirectory)/js/react_native/e2e/android-test-results.xml \
-        detox test --record-logs all --configuration android.emu.release
+        detox test --record-logs all --configuration android.emu.release --loglevel trace
       workingDirectory: '$(Build.SourcesDirectory)/js/react_native/e2e'
       displayName: Run React Native Detox Android e2e Tests
 
@@ -329,7 +329,7 @@ stages:
 
     - script: |
         JEST_JUNIT_OUTPUT_FILE=$(Build.SourcesDirectory)/js/react_native/e2e/ios-test-results.xml \
-        detox test --record-logs all --configuration ios.sim.release
+        detox test --record-logs all --configuration ios.sim.release --loglevel trace
       workingDirectory: '$(Build.SourcesDirectory)/js/react_native/e2e'
       displayName: Run React Native Detox iOS e2e Tests
 
diff --git a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml
index ed32c5d0e15be..b1cdb498bb4ae 100644
--- a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml
@@ -16,10 +16,10 @@ stages:
     displayName: "Build iOS package for variant: ${{ parameters.packageVariant}}"
 
     pool:
-      vmImage: "macOS-13"
+      vmImage: "macOS-latest"
 
     variables:
-      xcodeVersion: "14.3"
+      xcodeVersion: "14.2"
       ortPodVersion: $[stageDependencies.IosPackaging_SetCommonVariables.j.outputs['SetCommonVariables.ORT_POD_VERSION']]
 
       ${{ if eq(parameters.packageVariant, 'Mobile') }}:
diff --git a/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml b/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml
index 96e6ff89cd4f1..9ab2d3401de42 100644
--- a/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml
@@ -71,7 +71,7 @@ jobs:
     timeoutInMinutes: 20
   - script: |
       export ORT_WEB_TEST_BS_BROWSERS=BS_MAC_11_Safari_14,BS_MAC_11_Chrome_91,BS_ANDROID_11_Pixel_5
-      npm test -- suite0 --env=bs --wasm-init-timeout=30000 --file-cache
+      npm test -- suite0 -e=bs --wasm.initTimeout=30000 --file-cache
     workingDirectory: '$(Build.SourcesDirectory)/js/web'
     displayName: 'npm test (Suite0, BS_ANDROID, BS_MAC)'
     env:
@@ -80,7 +80,7 @@ jobs:
     continueOnError: true
   - script: |
       export ORT_WEB_TEST_BS_BROWSERS=BS_IOS_14_iPhoneXS
-      npm test -- suite1 --env=bs --wasm-init-timeout=30000 --file-cache --backend=wasm
+      npm test -- suite1 -e=bs --wasm.initTimeout=30000 --file-cache --backend=wasm
     workingDirectory: '$(Build.SourcesDirectory)/js/web'
     displayName: 'npm test (Suite1, BS_IOS)'
     continueOnError: true
@@ -95,4 +95,3 @@ jobs:
   - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3
     displayName: 'Clean Agent Directories'
     condition: always()
-
diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml
index 8ed22153fd947..e32956d6eb913 100644
--- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml
@@ -162,10 +162,11 @@ stages:
           platform: ${{ parameters.msbuildPlatform }}
           configuration: RelWithDebInfo
           msbuildArchitecture: ${{ parameters.buildArch }}
-          maximumCpuCount: true
+          maximumCpuCount: true  # by default, builds as many projects concurrently as there are logical cores
           logProjectEvents: true
           workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo'
           createLogFile: true
+          msbuildArgs: "/p:CL_MPCount=2"  # up to 2 cl.exe processes per project being built
 
       - task: PythonScript@0
         displayName: 'test'
diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml
index 8ba3517530edd..1eb2ee6f6409c 100644
--- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml
@@ -31,6 +31,7 @@ jobs:
   variables:
     webgpuCommandlineExtraFlags: '--chromium-flags=--ignore-gpu-blocklist --chromium-flags=--gpu-vendor-id=0x10de'
     runCodesignValidationInjection: false
+    CHROME_BIN: 'C:\Program Files\Google\Chrome\Application\chrome.exe'
   timeoutInMinutes: 60
   workspace:
     clean: all
@@ -95,18 +96,6 @@ jobs:
       targetFolder: $(Build.SourcesDirectory)\js\web\lib\wasm\binding
       flattenFolders: true
     displayName: 'Binplace js files'
-  - script: |
-      npm i -g puppeteer
-    workingDirectory: '$(Build.SourcesDirectory)'
-    displayName: 'Use puppeteer to prepare Chrome for tests'
-  - script: |
-      FOR /F "tokens=* USEBACKQ" %%F IN (`where /r %HOMEDRIVE%%HOMEPATH%\.cache\puppeteer chrome.exe`) DO (
-        SET var=%%F
-        ECHO found chrome.exe: %%F
-      )
-      ECHO ##vso[task.setvariable variable=CHROME_BIN;]%var%
-    workingDirectory: '$(Build.SourcesDirectory)'
-    displayName: 'Set CHROME_BIN'
   - script: |
      npm ci
     workingDirectory: '$(Build.SourcesDirectory)\js'
@@ -155,12 +144,7 @@ jobs:
       path: $(Build.SourcesDirectory)/js/test/
       cacheHitVar: CACHE_RESTORED
     displayName: 'Cache ONNX node test data'
-  - task: Bash@3
-    inputs:
-      targetType: 'inline'
-      script: find "$(Build.SourcesDirectory)/js/test/" -type f
-    condition: and(not(canceled()), eq(variables.CACHE_RESTORED, 'true'))
-    displayName: 'List ONNX node test data'
+
   - task: PowerShell@2
     inputs:
       filePath: '$(Build.SourcesDirectory)\tools\ci_build\github\js\pack-npm-packages.ps1'
@@ -169,31 +153,31 @@ jobs:
       errorActionPreference: stop
     displayName: 'Pack NPM packages'
   - script: |
-     npm test -- -e=chrome -b=webgl,wasm
+     npm test -- -e=chrome -b=webgl,wasm --karma-debug
     workingDirectory: '$(Build.SourcesDirectory)\js\web'
     displayName: 'Run ort-web tests (wasm,webgl backend)'
     condition: eq('${{ parameters.RunWebGpuTests }}', 'false')
   - script: |
-     npm test -- -e=chrome -b=webgl,wasm,webgpu $(webgpuCommandlineExtraFlags)
+     npm test -- -e=chrome -b=webgl,wasm,webgpu --karma-debug $(webgpuCommandlineExtraFlags)
     workingDirectory: '$(Build.SourcesDirectory)\js\web'
     displayName: 'Run ort-web tests (ALL backends)'
     condition: eq('${{ parameters.RunWebGpuTests }}', 'true')
   - script: |
-     npm test -- suite1 -e=chrome -b=webgpu --io-binding=gpu-tensor $(webgpuCommandlineExtraFlags)
+     npm test -- suite1 -e=chrome -b=webgpu --io-binding=gpu-tensor --karma-debug $(webgpuCommandlineExtraFlags)
     workingDirectory: '$(Build.SourcesDirectory)\js\web'
     displayName: 'Run ort-web tests (Suite1, webgpu, IO-binding=gpu-tensor)'
     condition: eq('${{ parameters.RunWebGpuTests }}', 'true')
   - script: |
-     npm test -- suite1 -e=chrome -b=webgpu --io-binding=gpu-location $(webgpuCommandlineExtraFlags)
+     npm test -- suite1 -e=chrome -b=webgpu --io-binding=gpu-location --karma-debug $(webgpuCommandlineExtraFlags)
     workingDirectory: '$(Build.SourcesDirectory)\js\web'
     displayName: 'Run ort-web tests (Suite1, webgpu, IO-binding=gpu-location)'
     condition: eq('${{ parameters.RunWebGpuTests }}', 'true')
   - script: |
-     npm test -- --webgl-texture-pack-mode -b=webgl -e=chrome
+     npm test -- --webgl.pack -b=webgl -e=chrome --karma-debug
     workingDirectory: '$(Build.SourcesDirectory)\js\web'
     displayName: 'Run ort-web tests - WebGL: packed mode'
   - script: |
-     npm test -- --wasm-enable-proxy -b=wasm -e=chrome
+     npm test -- --wasm.proxy -b=wasm -e=chrome --karma-debug
     workingDirectory: '$(Build.SourcesDirectory)\js\web'
     displayName: 'Run ort-web tests - WebAssembly: proxy'
     condition: and(succeeded(), eq('${{ parameters.BuildConfig }}', 'Release'))
diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml
index 31ee488318a0b..79bf0b5e71363 100644
--- a/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml
@@ -68,15 +68,15 @@ jobs:
     workingDirectory: '$(Build.SourcesDirectory)\js\web'
     displayName: 'npm ci /js/web/'
   - script: |
-      npm test -- suite0 -b=wasm,webgl --wasm-init-timeout=30000 --file-cache
+      npm test -- suite0 -b=wasm,webgl --wasm.initTimeout=30000 --file-cache
     workingDirectory: '$(Build.SourcesDirectory)\js\web'
     displayName: 'npm test (Suite0, Chrome)'
   - script: |
-      npm test -- suite0 -b=wasm,webgl --env=firefox --wasm-init-timeout=30000 --file-cache
+      npm test -- suite0 -b=wasm,webgl -e=firefox --wasm.initTimeout=30000 --file-cache
     workingDirectory: '$(Build.SourcesDirectory)\js\web'
     displayName: 'npm test (Suite0, Firefox)'
   - script: |
-      npm test -- suite0 -b=wasm,webgl --env=edge --wasm-init-timeout=30000 --file-cache
+      npm test -- suite0 -b=wasm,webgl -e=edge --wasm.initTimeout=30000 --file-cache
     workingDirectory: '$(Build.SourcesDirectory)\js\web'
     displayName: 'npm test (Suite0, Edge)'
   - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3
diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml
index 13d4589a67cdc..dc861f7f1ed79 100644
--- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml
@@ -32,7 +32,7 @@ parameters:
 - name: QnnSdk
   displayName: QNN SDK version
   type: string
-  default: qnn-v2.18.0.240101_win
+  default: qnn-v2.19.2.240210_win
 
 jobs:
 - job: 'build'
diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml
index 6246bb83566e5..534d5c6d6135b 100644
--- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml
@@ -32,7 +32,7 @@ parameters:
 - name: QnnSdk
   displayName: QNN SDK version
   type: string
-  default: qnn-v2.18.0.240101_win
+  default: qnn-v2.19.2.240210_win
 
 jobs:
 - job: 'build'
diff --git a/tools/ci_build/github/linux/copy_strip_binary.sh b/tools/ci_build/github/linux/copy_strip_binary.sh
index 42973a8fcb5b8..65d6d97ebf0a8 100755
--- a/tools/ci_build/github/linux/copy_strip_binary.sh
+++ b/tools/ci_build/github/linux/copy_strip_binary.sh
@@ -44,17 +44,10 @@ elif [[ $LIB_NAME == *.so.* ]]
 then
     ln -s $LIB_NAME $BINARY_DIR/$ARTIFACT_NAME/lib/libonnxruntime.so
 fi
-cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_c_api.h  $BINARY_DIR/$ARTIFACT_NAME/include
-cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_cxx_api.h  $BINARY_DIR/$ARTIFACT_NAME/include
-cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_cxx_inline.h  $BINARY_DIR/$ARTIFACT_NAME/include
-cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_float16.h  $BINARY_DIR/$ARTIFACT_NAME/include
-cp $SOURCE_DIR/include/onnxruntime/core/providers/cpu/cpu_provider_factory.h  $BINARY_DIR/$ARTIFACT_NAME/include
-cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h  $BINARY_DIR/$ARTIFACT_NAME/include
-cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h  $BINARY_DIR/$ARTIFACT_NAME/include
+cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_*.h $BINARY_DIR/$ARTIFACT_NAME/include
 cp $SOURCE_DIR/include/onnxruntime/core/framework/provider_options.h  $BINARY_DIR/$ARTIFACT_NAME/include
-cp $SOURCE_DIR/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h  $BINARY_DIR/$ARTIFACT_NAME/include
-cp $SOURCE_DIR/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h  $BINARY_DIR/$ARTIFACT_NAME/include
-cp $SOURCE_DIR/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h  $BINARY_DIR/$ARTIFACT_NAME/include
+cp $SOURCE_DIR/include/onnxruntime/core/providers/cpu/cpu_provider_factory.h  $BINARY_DIR/$ARTIFACT_NAME/include
+cp $SOURCE_DIR/orttraining/orttraining/training_api/include/onnxruntime_training_*.h  $BINARY_DIR/$ARTIFACT_NAME/include
 
 if [[ -f "$BINARY_DIR/$BUILD_CONFIG/libonnxruntime_providers_cuda.so" ]]; then
 # copy headers for context context used in custom ops
diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm
index dd7c669c37885..e1914d5fe2f06 100644
--- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm
+++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm
@@ -178,7 +178,7 @@ CMD ["/bin/bash"]
 #Build manylinux2014 docker image end
 
 ARG PYTHON_VERSION=3.8
-ARG OPSET_VERSION=15
+ARG OPSET_VERSION=17
 ARG INSTALL_DEPS_EXTRA_ARGS
 
 
diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8
index a6a75afb0f4c3..fed29689fbe5e 100644
--- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8
+++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8
@@ -161,7 +161,7 @@ CMD ["/bin/bash"]
 #Build manylinux2014 docker image end
 ARG PYTHON_VERSION=3.9
 ARG TORCH_VERSION=2.0.0
-ARG OPSET_VERSION=15
+ARG OPSET_VERSION=17
 ARG INSTALL_DEPS_EXTRA_ARGS
 
 #Add our own dependencies
diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2
index d29157daef611..e1caa141ef317 100644
--- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2
+++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2
@@ -161,7 +161,7 @@ CMD ["/bin/bash"]
 #Build manylinux2014 docker image end
 ARG PYTHON_VERSION=3.9
 ARG TORCH_VERSION=2.1.0
-ARG OPSET_VERSION=15
+ARG OPSET_VERSION=17
 ARG INSTALL_DEPS_EXTRA_ARGS
 
 #Add our own dependencies
diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu
index 9b9dc9ecae822..331eb6472070c 100644
--- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu
+++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu
@@ -16,15 +16,14 @@ ENV DEBIAN_FRONTEND=noninteractive
 ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH_ARG}:${LD_LIBRARY_PATH}
 
 RUN apt-get update &&\
-    apt-get install -y git bash wget
+    apt-get install -y git bash wget diffutils
 
 # Install python3
 RUN apt-get install -y --no-install-recommends \
     python3 \
     python3-pip \
     python3-dev \
-    python3-wheel 
-   
+    python3-wheel
 
 RUN pip install --upgrade pip
 
diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6
index 04a6af962b5e6..f1ffba3b3e1c9 100644
--- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6
+++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6
@@ -82,8 +82,9 @@ RUN if [ -z "$ONNXRUNTIME_COMMIT_ID" ] ; then echo "Building branch ${ONNXRUNTIM
     git reset --hard ${ONNXRUNTIME_COMMIT_ID} && git submodule update --recursive ; fi
 
 # Build ORT
-ENV CUDA_MODULE_LOADING "LAZY"
-RUN /bin/sh build.sh --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"'
+ENV CUDA_MODULE_LOADING "LAZY" 
+ARG PARSER_CONFIG=""
+RUN /bin/sh build.sh ${PARSER_CONFIG} --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"'
 
 # Switch to root to continue following steps of CI
 USER root
diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6
new file mode 100644
index 0000000000000..9493480784e81
--- /dev/null
+++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6
@@ -0,0 +1,96 @@
+# --------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------
+# Dockerfile to run ONNXRuntime with TensorRT integration
+
+# Build base image with required system packages
+FROM nvidia/cuda:12.3.1-devel-ubuntu20.04 AS base
+
+# The local directory into which to build and install CMake
+ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code
+
+ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.27.3-linux-x86_64/bin:/opt/miniconda/bin:${PATH}
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get update &&\
+    apt-get install -y sudo git bash unattended-upgrades wget
+RUN unattended-upgrade
+
+# Install python3
+RUN apt-get install -y --no-install-recommends \
+    python3 \
+    python3-pip \
+    python3-dev \
+    python3-wheel &&\
+    cd /usr/local/bin &&\
+    ln -s /usr/bin/python3 python &&\
+    ln -s /usr/bin/pip3 pip;
+
+RUN pip install --upgrade pip 
+RUN pip install setuptools>=68.2.2
+
+# Install cuDNN v9
+RUN apt-get -y install cudnn9-cuda-12
+
+# Install TensorRT
+RUN v="8.6.1.6-1+cuda12.0" &&\
+    apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub &&\
+    apt-get update &&\
+    sudo apt-get install -y libnvinfer8=${v} libnvonnxparsers8=${v} libnvparsers8=${v} libnvinfer-plugin8=${v} libnvinfer-lean8=${v} libnvinfer-vc-plugin8=${v} libnvinfer-dispatch8=${v}\
+        libnvinfer-headers-dev=${v} libnvinfer-headers-plugin-dev=${v} libnvinfer-dev=${v} libnvonnxparsers-dev=${v} libnvparsers-dev=${v} libnvinfer-plugin-dev=${v} libnvinfer-lean-dev=${v} libnvinfer-vc-plugin-dev=${v}  libnvinfer-dispatch-dev=${v}\
+        python3-libnvinfer=${v} libnvinfer-samples=${v} tensorrt-dev=${v} tensorrt-libs=${v}
+
+# Compile trtexec
+RUN cd /usr/src/tensorrt/samples/trtexec && make
+
+# Install Valgrind
+RUN apt-get install -y valgrind
+
+# Build final image from base. Builds ORT.
+FROM base as final
+ARG BUILD_USER=onnxruntimedev
+ARG BUILD_UID=1000
+RUN adduser --gecos 'onnxruntime Build User' --disabled-password $BUILD_USER --uid $BUILD_UID
+USER $BUILD_USER
+
+# ONNX Runtime arguments
+
+# URL to the github repo from which to clone ORT.
+ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
+
+# The local directory into which to clone ORT.
+ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code
+
+# The git branch of ORT to checkout and build.
+ARG ONNXRUNTIME_BRANCH=main
+
+# Optional. The specific commit to pull and build from. If not set, the latest commit is used.
+ARG ONNXRUNTIME_COMMIT_ID
+
+# The supported CUDA architecture
+ARG CMAKE_CUDA_ARCHITECTURES=75
+
+WORKDIR ${ONNXRUNTIME_LOCAL_CODE_DIR}
+
+# Clone ORT repository with branch
+RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\
+    /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh
+
+WORKDIR ${ONNXRUNTIME_LOCAL_CODE_DIR}/onnxruntime
+
+# Reset to a specific commit if specified by build args.
+RUN if [ -z "$ONNXRUNTIME_COMMIT_ID" ] ; then echo "Building branch ${ONNXRUNTIME_BRANCH}" ;\
+    else echo "Building branch ${ONNXRUNTIME_BRANCH} @ commit ${ONNXRUNTIME_COMMIT_ID}" &&\
+    git reset --hard ${ONNXRUNTIME_COMMIT_ID} && git submodule update --recursive ; fi
+
+# Build ORT
+ENV CUDA_MODULE_LOADING "LAZY" 
+ARG PARSER_CONFIG=""
+RUN /bin/sh build.sh ${PARSER_CONFIG} --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"'
+
+# Switch to root to continue following steps of CI
+USER root
+
+# Install ORT wheel
+RUN pip install ${ONNXRUNTIME_LOCAL_CODE_DIR}/onnxruntime/build/Linux/Release/dist/*.whl
\ No newline at end of file
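
The new CUDA 12.3 / TensorRT 8.6 Dockerfile above builds ONNX Runtime with --use_tensorrt and installs the resulting wheel into the image. A minimal smoke test, assuming it is run inside that container; the provider names are the standard ONNX Runtime identifiers:

    # Quick check that the wheel installed in the image exposes the expected providers.
    import onnxruntime as ort

    providers = ort.get_available_providers()
    print("available providers:", providers)
    assert "TensorrtExecutionProvider" in providers
    assert "CUDAExecutionProvider" in providers
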
diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino
index a0ba5ea232ca3..45682c797bbb8 100644
--- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino
+++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_openvino
@@ -1,8 +1,8 @@
 ARG UBUNTU_VERSION=20.04
 FROM ubuntu:${UBUNTU_VERSION}
 
-ARG OPENVINO_VERSION=2023.0.0
-ARG PYTHON_VERSION=3.8
+ARG OPENVINO_VERSION=2024.0.0
+ARG PYTHON_VERSION=3.9
 
 ADD scripts /tmp/scripts
 RUN /tmp/scripts/install_ubuntu.sh -p ${PYTHON_VERSION} -d EdgeDevice && \
@@ -14,15 +14,14 @@ RUN apt update && apt install -y libnuma1 ocl-icd-libopencl1 && \
 
 ENV INTEL_OPENVINO_DIR /opt/intel/openvino_${OPENVINO_VERSION}
 ENV LD_LIBRARY_PATH $INTEL_OPENVINO_DIR/runtime/lib/intel64:$INTEL_OPENVINO_DIR/runtime/3rdparty/tbb/lib:/usr/local/openblas/lib:$LD_LIBRARY_PATH
-ENV InferenceEngine_DIR $INTEL_OPENVINO_DIR/runtime/cmake
-ENV ngraph_DIR $INTEL_OPENVINO_DIR/runtime/cmake
+ENV OpenVINO_DIR $INTEL_OPENVINO_DIR/runtime/cmake
 ENV IE_PLUGINS_PATH $INTEL_OPENVINO_DIR/runtime/lib/intel64
 ENV DEBIAN_FRONTEND=noninteractive
 
 RUN cd /opt && mkdir -p intel && cd intel && \
-    wget https://storage.openvinotoolkit.org/repositories/openvino/packages/2023.0/linux/l_openvino_toolkit_ubuntu20_2023.0.0.10926.b4452d56304_x86_64.tgz && \
-    tar xzf l_openvino_toolkit_ubuntu20_2023.0.0.10926.b4452d56304_x86_64.tgz && rm -rf l_openvino_toolkit_ubuntu20_2023.0.0.10926.b4452d56304_x86_64.tgz && \
-    mv l_openvino_toolkit_ubuntu20_2023.0.0.10926.b4452d56304_x86_64 openvino_2023.0.0 && \
+    wget https://storage.openvinotoolkit.org/repositories/openvino/packages/2024.0/linux/l_openvino_toolkit_ubuntu20_2024.0.0.14509.34caeefd078_x86_64.tgz && \
+    tar xzf l_openvino_toolkit_ubuntu20_2024.0.0.14509.34caeefd078_x86_64.tgz && rm -rf l_openvino_toolkit_ubuntu20_2024.0.0.14509.34caeefd078_x86_64.tgz && \
+    mv l_openvino_toolkit_ubuntu20_2024.0.0.14509.34caeefd078_x86_64 openvino_2024.0.0 && \
     cd $INTEL_OPENVINO_DIR/install_dependencies && ./install_openvino_dependencies.sh -y
 
 WORKDIR /root
diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin
index 21b09b2d8978e..a26bf88fbbdf6 100644
--- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin
+++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin
@@ -4,29 +4,15 @@
 # --------------------------------------------------------------
 # Dockerfile to run ONNXRuntime with TensorRT installed from provided binaries
 
-FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu20.04
+# Build base image with required system packages
+FROM nvidia/cuda:12.3.1-devel-ubuntu20.04 AS base
 
+# The local directory into which to build and install CMake
+ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code
 
-# ONNX Runtime Variables
-ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
-ARG ONNXRUNTIME_BRANCH=main
-ARG CMAKE_CUDA_ARCHITECTURES=37;50;52;60;61;70;75;80
-
-# Must provide version numbers used to build the name of the tar file containing TensorRT binaries.
-# See: https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html#installing-tar
-ARG TAR_TRT_VERSION
-ARG TAR_CUDA_VERSION
-ARG TAR_CUDNN_VERSION
-
-# Directory containing TensorRT tar.gz installation package
-ARG TRT_BINS_DIR=.
-
-ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/code/cmake-3.27.3-linux-x86_64/bin:/opt/miniconda/bin:${PATH}
-
+ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.27.3-linux-x86_64/bin:/opt/miniconda/bin:${PATH}
 ENV DEBIAN_FRONTEND=noninteractive
 
-COPY ${TRT_BINS_DIR}/TensorRT-${TAR_TRT_VERSION}.Linux.x86_64-gnu.cuda-${TAR_CUDA_VERSION}.cudnn${TAR_CUDNN_VERSION}.tar.gz /TensorRT-${TAR_TRT_VERSION}.tar.gz
-
 RUN apt-get update &&\
     apt-get install -y sudo git bash unattended-upgrades wget
 RUN unattended-upgrade
@@ -44,22 +30,77 @@ RUN apt-get install -y --no-install-recommends \
 RUN pip install --upgrade pip 
 RUN pip install setuptools>=68.2.2
 
+# Install cuDNN v9
+RUN apt-get -y install cudnn9-cuda-12
+
+# Install TensorRT
+# Must provide version numbers used to build the name of the tar file containing TensorRT binaries.
+# See: https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html#installing-tar
+ARG TAR_TRT_VERSION
+ARG TAR_CUDA_VERSION
+
+# Directory containing TensorRT tar.gz installation package
+ARG TRT_BINS_DIR=.
+COPY ${TRT_BINS_DIR}/TensorRT-${TAR_TRT_VERSION}.Linux.x86_64-gnu.cuda-${TAR_CUDA_VERSION}.tar.gz /TensorRT-${TAR_TRT_VERSION}.tar.gz
+
 # Install TensorRT from tar.gz
 RUN tar -xzvf /TensorRT-${TAR_TRT_VERSION}.tar.gz
 
 RUN cd /TensorRT-${TAR_TRT_VERSION}/python &&\
-    python3 -m pip install tensorrt-${TAR_TRT_VERSION}-cp38-none-linux_x86_64.whl
+    python3 -m pip install tensorrt*cp38*.whl
 
 RUN cp -r /TensorRT-${TAR_TRT_VERSION}/lib/* /usr/lib/x86_64-linux-gnu/
 RUN cp /TensorRT-${TAR_TRT_VERSION}/include/* /usr/local/include/
 RUN cp /TensorRT-${TAR_TRT_VERSION}/bin/* /usr/local/bin/
 
-WORKDIR /code
+# Install Valgrind
+RUN apt-get install -y valgrind
+
+# Build final image from base. Builds ORT.
+FROM base AS final
+ARG BUILD_USER=onnxruntimedev
+ARG BUILD_UID=1000
+RUN adduser --gecos 'onnxruntime Build User' --disabled-password $BUILD_USER --uid $BUILD_UID
+USER $BUILD_USER
+
+# ONNX Runtime arguments
+
+# URL of the GitHub repo from which to clone ORT.
+ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
+
+# The local directory into which to clone ORT.
+ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code
+
+# The git branch of ORT to checkout and build.
+ARG ONNXRUNTIME_BRANCH=main
+
+# Optional. The specific commit to pull and build from. If not set, the latest commit is used.
+ARG ONNXRUNTIME_COMMIT_ID
+
+# The CUDA architecture(s) to build for.
+ARG CMAKE_CUDA_ARCHITECTURES=75
 
 # Prepare onnxruntime repository & build onnxruntime with TensorRT
+WORKDIR ${ONNXRUNTIME_LOCAL_CODE_DIR}
+
+# Clone the ORT repository at the specified branch and install common dependencies.
 RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\
-    /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh &&\
-    cd onnxruntime &&\
-    /bin/sh build.sh --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' &&\
-    pip install /code/onnxruntime/build/Linux/Release/dist/*.whl &&\
-    cd .. 
+    /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh
+
+WORKDIR ${ONNXRUNTIME_LOCAL_CODE_DIR}/onnxruntime
+
+# Reset to a specific commit if specified by build args.
+RUN if [ -z "$ONNXRUNTIME_COMMIT_ID" ] ; then echo "Building branch ${ONNXRUNTIME_BRANCH}" ;\
+    else echo "Building branch ${ONNXRUNTIME_BRANCH} @ commit ${ONNXRUNTIME_COMMIT_ID}" &&\
+    git reset --hard ${ONNXRUNTIME_COMMIT_ID} && git submodule update --recursive ; fi
+
+# Build ORT
+ENV CUDA_MODULE_LOADING="LAZY"
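+# Optional extra flags forwarded to build.sh below (for example, a TensorRT parser selection flag); empty by default.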
+ARG PARSER_CONFIG=""
+RUN /bin/sh build.sh ${PARSER_CONFIG} --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"'
+
+# Switch back to root so the remaining CI steps can run with the required permissions.
+USER root
+
+# Install the ORT wheel
+RUN pip install ${ONNXRUNTIME_LOCAL_CODE_DIR}/onnxruntime/build/Linux/Release/dist/*.whl
\ No newline at end of file
diff --git a/tools/ci_build/github/linux/docker/Dockerfile_manylinux2014_openvino_multipython b/tools/ci_build/github/linux/docker/Dockerfile_manylinux2014_openvino_multipython
deleted file mode 100644
index bc0b412773286..0000000000000
--- a/tools/ci_build/github/linux/docker/Dockerfile_manylinux2014_openvino_multipython
+++ /dev/null
@@ -1,83 +0,0 @@
-FROM quay.io/pypa/manylinux2014_x86_64:latest
-
-ENV PATH /opt/rh/devtoolset-10/root/usr/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
-ADD scripts /tmp/scripts
-RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts
-
-ARG PYTHON_VER_PATH="cp38-cp38"
-ARG PYTHON_VERSION="3.8"
-ARG BUILD_UID=1001
-ARG BUILD_USER=onnxruntimedev
-ARG OV_DEVICE_PRECISION="CPU_FP32"
-ARG ENABLE_TRAINING=true
-ARG ORT_BRANCH="rel-1.13.1"
-ARG OV_VERSION="2022.2.0"
-RUN adduser --uid $BUILD_UID $BUILD_USER
-WORKDIR /home/$BUILD_USER
-ENV PYTHON_EXE="/opt/python/$PYTHON_VER_PATH/bin/python$PYTHON_VERSION"
-
-RUN yum -y install wget git
-
-# libusb1.0.22
-RUN cd /home/ && wget https://github.com/libusb/libusb/archive/v1.0.22.zip && \
-    unzip v1.0.22.zip && rm -rf v1.0.22.zip && cd  /home/libusb-1.0.22 && \
-# bootstrap steps
-    ./bootstrap.sh && \
-    ./configure --disable-udev --enable-shared && \
-    make -j4 && \
-# configure libusb1.0.22
-    cd /home/libusb-1.0.22/libusb && \
-    /bin/mkdir -p '/usr/local/lib' && \
-    /bin/bash ../libtool   --mode=install /usr/bin/install -c   libusb-1.0.la '/usr/local/lib' && \
-    /bin/mkdir -p '/usr/local/include/libusb-1.0' && \
-    /usr/bin/install -c -m 644 libusb.h '/usr/local/include/libusb-1.0' && \
-    /bin/mkdir -p '/usr/local/lib/pkgconfig'
-
-RUN ${PYTHON_EXE} -m pip install onnx numpy wheel
-USER $BUILD_USER
-RUN cd $WORKDIR && git clone https://github.com/openvinotoolkit/openvino.git && \
-    cd openvino && \
-    git checkout $OV_VERSION && \
-    git submodule init && \
-    git submodule update --recursive
-
-RUN cd $WORKDIR && cd openvino && mkdir build && cd build && \
-    cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS=-D_GLIBCXX_USE_CXX11_ABI=0 -DENABLE_PYTHON=ON -DPYTHON_EXECUTABLE=$PYTHON_EXE -DCMAKE_INSTALL_PREFIX=/home/onnxruntimedev/openvino_$OV_VERSION && \
-    make -j8 && make install
-
-ENV INTEL_OPENVINO_DIR /home/onnxruntimedev/openvino_$OV_VERSION
-ENV LD_LIBRARY_PATH $INTEL_OPENVINO_DIR/runtime/lib/intel64:$INTEL_OPENVINO_DIR/runtime/3rdparty/tbb/lib:/usr/local/openblas/lib:$LD_LIBRARY_PATH
-ENV TBB_LIBS $INTEL_OPENVINO_DIR/runtime/3rdparty/tbb/lib
-ENV InferenceEngine_DIR $INTEL_OPENVINO_DIR/runtime/cmake
-ENV ngraph_DIR $INTEL_OPENVINO_DIR/runtime/cmake
-ENV IE_PLUGINS_PATH $INTEL_OPENVINO_DIR/runtime/lib/intel64
-ENV OPENVINO_MANYLINUX 1
-
-RUN cd $WORKDIR && \
-    git clone --recursive -b $ORT_BRANCH https://github.com/intel/onnxruntime.git
-RUN cd onnxruntime/onnxruntime/core/providers/openvino && mkdir scripts
-
-RUN cp ${IE_PLUGINS_PATH}/libopenvino.so /home/onnxruntimedev/onnxruntime/onnxruntime/core/providers/openvino/scripts/
-RUN cp ${IE_PLUGINS_PATH}/libopenvino_c.so /home/onnxruntimedev/onnxruntime/onnxruntime/core/providers/openvino/scripts/
-RUN cp ${IE_PLUGINS_PATH}/libopenvino_onnx_frontend.so /home/onnxruntimedev/onnxruntime/onnxruntime/core/providers/openvino/scripts/
-RUN cp ${IE_PLUGINS_PATH}/libopenvino_intel_cpu_plugin.so /home/onnxruntimedev/onnxruntime/onnxruntime/core/providers/openvino/scripts/
-RUN cp ${IE_PLUGINS_PATH}/libopenvino_intel_gpu_plugin.so /home/onnxruntimedev/onnxruntime/onnxruntime/core/providers/openvino/scripts/
-RUN cp ${IE_PLUGINS_PATH}/libopenvino_intel_myriad_plugin.so /home/onnxruntimedev/onnxruntime/onnxruntime/core/providers/openvino/scripts/
-RUN cp ${IE_PLUGINS_PATH}/libopenvino_hetero_plugin.so /home/onnxruntimedev/onnxruntime/onnxruntime/core/providers/openvino/scripts/
-RUN cp ${IE_PLUGINS_PATH}/libopenvino_auto_plugin.so /home/onnxruntimedev/onnxruntime/onnxruntime/core/providers/openvino/scripts/
-RUN cp ${IE_PLUGINS_PATH}/plugins.xml /home/onnxruntimedev/onnxruntime/onnxruntime/core/providers/openvino/scripts/
-RUN cp ${IE_PLUGINS_PATH}/usb-ma2x8x.mvcmd /home/onnxruntimedev/onnxruntime/onnxruntime/core/providers/openvino/scripts/
-RUN cp ${TBB_LIBS}/libtbb.so /home/onnxruntimedev/onnxruntime/onnxruntime/core/providers/openvino/scripts/
-RUN cp ${TBB_LIBS}/libtbb.so.2 /home/onnxruntimedev/onnxruntime/onnxruntime/core/providers/openvino/scripts/
-RUN cp ${TBB_LIBS}/libtbbmalloc.so /home/onnxruntimedev/onnxruntime/onnxruntime/core/providers/openvino/scripts/
-RUN cp ${TBB_LIBS}/libtbbmalloc.so.2 /home/onnxruntimedev/onnxruntime/onnxruntime/core/providers/openvino/scripts/
-RUN cd /home/onnxruntimedev/onnxruntime && git pull
-RUN if $ENABLE_TRAINING; then \
-        ${PYTHON_EXE} ./onnxruntime/tools/ci_build/build.py \
-        --build_dir ./onnxruntime/build --use_openvino $(OV_DEVICE_PRECISION) --build_shared_lib \
-        --config Release --build_wheel --skip_tests --enable_training ; \
-    else \
-        ${PYTHON_EXE} ./onnxruntime/tools/ci_build/build.py \
-        --build_dir ./onnxruntime/build --use_openvino $(OV_DEVICE_PRECISION) --build_shared_lib \
-        --config Release --build_wheel --skip_tests ;\
-    fi
diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh
index da8a45e00cc90..39c15338aeddb 100755
--- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh
+++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh
@@ -31,8 +31,7 @@ cd /usr/local/
 echo "Cloning ONNX Script"
 git clone --recursive https://github.com/microsoft/onnxscript.git
 cd onnxscript
-/opt/python/cp39-cp39/bin/python3.9 -m pip install -r requirements-dev.txt
-/opt/python/cp39-cp39/bin/python3.9 setup.py install
+/opt/python/cp39-cp39/bin/python3.9 -m pip install .
 cd ~ && /opt/python/cp39-cp39/bin/python3.9 -c "import onnxscript; print(f'Installed ONNX Script: {onnxscript.__version__}')"
 
 cd /usr/local
diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt
index 94f52f476579b..886f19388d01e 100644
--- a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt
+++ b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt
@@ -10,3 +10,4 @@ protobuf==4.21.12
 sympy==1.12
 flatbuffers
 neural-compressor>=2.2.1
+triton
diff --git a/tools/ci_build/github/linux/ort_minimal/check_build_binary_size.py b/tools/ci_build/github/linux/ort_minimal/check_build_binary_size.py
index ea4a3fd32b18b..40debff3b2fef 100644
--- a/tools/ci_build/github/linux/ort_minimal/check_build_binary_size.py
+++ b/tools/ci_build/github/linux/ort_minimal/check_build_binary_size.py
@@ -31,9 +31,7 @@ def _check_binary_size(path, readelf, threshold, os_str, arch, build_config):
 
     if threshold is not None and sections_total > threshold:
         raise RuntimeError(
-            "Sections total size for {} of {} exceeds threshold of {} by {}. On-disk size={}".format(
-                path, sections_total, threshold, sections_total - threshold, ondisk_size
-            )
+            f"Sections total size for {path} of {sections_total} exceeds threshold of {threshold} by {sections_total - threshold}. On-disk size={ondisk_size}"
         )
 
 
diff --git a/tools/ci_build/github/linux/upload_code_coverage_data.sh b/tools/ci_build/github/linux/upload_code_coverage_data.sh
index 2f63e4c2fe087..cba54a421d511 100755
--- a/tools/ci_build/github/linux/upload_code_coverage_data.sh
+++ b/tools/ci_build/github/linux/upload_code_coverage_data.sh
@@ -2,5 +2,5 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
 set -x -e
-/usr/bin/env python3 -m pip install -r $BUILD_SOURCESDIRECTORY/tools/ci_build/github/windows/post_to_dashboard/requirements.txt
+/usr/bin/env python3 -m pip install --user -r $BUILD_SOURCESDIRECTORY/tools/ci_build/github/windows/post_to_dashboard/requirements.txt
 $BUILD_SOURCESDIRECTORY/tools/ci_build/github/windows/post_code_coverage_to_dashboard.py --commit_hash=$BUILD_SOURCEVERSION --report_file $1 --report_url $2 --branch $BUILD_SOURCEBRANCHNAME --arch $3 --os $4 --build_config $5
\ No newline at end of file
diff --git a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile
index 4767c74afd28f..496b57b417fbd 100644
--- a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile
+++ b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile
@@ -112,7 +112,7 @@ RUN pip install \
     cerberus \
     sympy \
     h5py \
-    datasets==1.9.0 \
+    datasets==2.17.0 \
     requests \
     sacrebleu==1.5.1 \
     sacremoses \
@@ -131,7 +131,7 @@ RUN pip install \
 # Install migraphx
 RUN apt update && apt install -y migraphx
 
-ENV ORTMODULE_ONNX_OPSET_VERSION=15
+ENV ORTMODULE_ONNX_OPSET_VERSION=17
 
 ARG BUILD_UID=1001
 ARG BUILD_USER=onnxruntimedev
diff --git a/tools/ci_build/github/windows/post_binary_sizes_to_dashboard.py b/tools/ci_build/github/windows/post_binary_sizes_to_dashboard.py
index acca4fb13c45a..a9667fe4d0654 100644
--- a/tools/ci_build/github/windows/post_binary_sizes_to_dashboard.py
+++ b/tools/ci_build/github/windows/post_binary_sizes_to_dashboard.py
@@ -49,7 +49,7 @@ def get_binary_sizes(size_data_file):
                 break
             linedata = line.strip().split(",")
             tablerow = {}
-            for i in range(0, len(headers)):
+            for i in range(len(headers)):
                 if headers[i] == "size":
                     tablerow[headers[i]] = int(linedata[i])
                 else:
diff --git a/tools/ci_build/op_registration_utils.py b/tools/ci_build/op_registration_utils.py
index 3fd01253a3e37..811ce424eae10 100644
--- a/tools/ci_build/op_registration_utils.py
+++ b/tools/ci_build/op_registration_utils.py
@@ -104,14 +104,12 @@ def process_registration(
         :param end_version: End version or None if unversioned registration
         :param type: Type or types used in registration, if this is a typed registration
         """
-        pass
 
     def process_other_line(self, line):
         """
         Process a line that does not contain a kernel registration
         :param line: Original line
         """
-        pass
 
     def ok(self):
         """
diff --git a/tools/ci_build/op_registration_validator.py b/tools/ci_build/op_registration_validator.py
index 5c7edfa88a48b..d92050a31f967 100644
--- a/tools/ci_build/op_registration_validator.py
+++ b/tools/ci_build/op_registration_validator.py
@@ -45,7 +45,7 @@ def domain_and_op_str(self):
 
 
 def _log_registration_error(r: RegistrationInfo, message: str):
-    log.error("Invalid registration for {}. {}\n{}".format(r.domain_and_op_str(), message, "".join(r.lines)))
+    log.error("Invalid registration for %s. %s\n%s", r.domain_and_op_str(), message, "".join(r.lines))
 
 
 class RegistrationValidator(op_registration_utils.RegistrationProcessor):
diff --git a/tools/doc/rename_folders.py b/tools/doc/rename_folders.py
index cc64775ae158d..90d800f2a4498 100644
--- a/tools/doc/rename_folders.py
+++ b/tools/doc/rename_folders.py
@@ -3,6 +3,7 @@
 This extension does not publish any folder starting with `_`.
 These folders need to be renamed.
 """
+
 import os
 import re
 
diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py
index 09fe99d36cc34..31c920c6e4438 100644
--- a/tools/nuget/generate_nuspec_for_native_nuget.py
+++ b/tools/nuget/generate_nuspec_for_native_nuget.py
@@ -324,10 +324,12 @@ def generate_metadata(line_list, args):
     generate_owners(metadata_list, "Microsoft")
     generate_description(metadata_list, args.package_name)
     generate_copyright(metadata_list, "\xc2\xa9 " + "Microsoft Corporation. All rights reserved.")
-    generate_tags(
-        metadata_list, "ONNX ONNX Runtime Machine Learning"
-    ) if "Microsoft.ML.OnnxRuntime.Training." in args.package_name else generate_tags(
-        metadata_list, "native ONNX ONNXRuntime-Training Learning-on-The-Edge On-Device-Training MachineLearning"
+    (
+        generate_tags(metadata_list, "ONNX ONNX Runtime Machine Learning")
+        if "Microsoft.ML.OnnxRuntime.Training." in args.package_name
+        else generate_tags(
+            metadata_list, "native ONNX ONNXRuntime-Training Learning-on-The-Edge On-Device-Training MachineLearning"
+        )
     )
     generate_icon(metadata_list, "ORT_icon_for_light_bg.png")
     generate_license(metadata_list)
@@ -732,7 +734,7 @@ def generate_files(line_list, args):
         )
 
     if args.execution_provider == "openvino":
-        openvino_path = get_env_var("INTEL_OPENVINO_DIR")
+        get_env_var("INTEL_OPENVINO_DIR")
         files_list.append(
             "<file src="
             + '"'
@@ -750,32 +752,6 @@ def generate_files(line_list, args):
             + '\\native" />'
         )
 
-        if is_windows():
-            dll_list_path = os.path.join(openvino_path, "runtime\\bin\\intel64\\Release\\")
-            tbb_list_path = os.path.join(openvino_path, "runtime\\3rdparty\\tbb\\bin\\")
-
-            for dll_element in os.listdir(dll_list_path):
-                if dll_element.endswith("dll"):
-                    files_list.append(
-                        "<file src="
-                        + '"'
-                        + os.path.join(dll_list_path, dll_element)
-                        + runtimes_target
-                        + args.target_architecture
-                        + '\\native" />'
-                    )
-
-            for tbb_element in os.listdir(tbb_list_path):
-                if tbb_element.endswith("dll"):
-                    files_list.append(
-                        "<file src="
-                        + '"'
-                        + os.path.join(tbb_list_path, tbb_element)
-                        + runtimes_target
-                        + args.target_architecture
-                        + '\\native" />'
-                    )
-
     if args.execution_provider == "cuda" or is_cuda_gpu_win_sub_package and not is_ado_packaging_build:
         files_list.append(
             "<file src="
diff --git a/tools/python/dump_ort_model.py b/tools/python/dump_ort_model.py
index 2177c42f5bc35..b9e3bfa0d3bcd 100644
--- a/tools/python/dump_ort_model.py
+++ b/tools/python/dump_ort_model.py
@@ -29,10 +29,10 @@ def __init__(self, model_path: str):
 
     def _dump_initializers(self, graph: fbs.Graph):
         print("Initializers:")
-        for idx in range(0, graph.InitializersLength()):
+        for idx in range(graph.InitializersLength()):
             tensor = graph.Initializers(idx)
             dims = []
-            for dim in range(0, tensor.DimsLength()):
+            for dim in range(tensor.DimsLength()):
                 dims.append(tensor.Dims(dim))
 
             print(f"{tensor.Name().decode()} data_type={tensor.DataType()} dims={dims}")
@@ -40,7 +40,7 @@ def _dump_initializers(self, graph: fbs.Graph):
 
     def _dump_nodeargs(self, graph: fbs.Graph):
         print("NodeArgs:")
-        for idx in range(0, graph.NodeArgsLength()):
+        for idx in range(graph.NodeArgsLength()):
             node_arg = graph.NodeArgs(idx)
             type = node_arg.Type()
             if not type:
@@ -57,7 +57,7 @@ def _dump_nodeargs(self, graph: fbs.Graph):
                 shape = tensor_type_and_shape.Shape()
                 if shape:
                     dims = []
-                    for dim in range(0, shape.DimLength()):
+                    for dim in range(shape.DimLength()):
                         d = shape.Dim(dim).Value()
                         if d.DimType() == fbs.DimensionValueType.DimensionValueType.VALUE:
                             dims.append(str(d.DimValue()))
@@ -76,8 +76,8 @@ def _dump_node(self, node: fbs.Node):
         domain = node.Domain().decode() or "ai.onnx"  # empty domain defaults to ai.onnx
         since_version = node.SinceVersion()
 
-        inputs = [node.Inputs(i).decode() for i in range(0, node.InputsLength())]
-        outputs = [node.Outputs(i).decode() for i in range(0, node.OutputsLength())]
+        inputs = [node.Inputs(i).decode() for i in range(node.InputsLength())]
+        outputs = [node.Outputs(i).decode() for i in range(node.OutputsLength())]
         print(
             f"{node.Index()}:{node.Name().decode()}({domain}:{optype}:{since_version}) "
             f'inputs=[{",".join(inputs)}] outputs=[{",".join(outputs)}]'
@@ -91,12 +91,12 @@ def _dump_graph(self, graph: fbs.Graph):
         self._dump_initializers(graph)
         self._dump_nodeargs(graph)
         print("Nodes:")
-        for i in range(0, graph.NodesLength()):
+        for i in range(graph.NodesLength()):
             node = graph.Nodes(i)
             self._dump_node(node)
 
             # Read all the attributes
-            for j in range(0, node.AttributesLength()):
+            for j in range(node.AttributesLength()):
                 attr = node.Attributes(j)
                 attr_type = attr.Type()
                 if attr_type == fbs.AttributeType.AttributeType.GRAPH:
@@ -107,7 +107,7 @@ def _dump_graph(self, graph: fbs.Graph):
                     # the ONNX spec doesn't currently define any operators that have multiple graphs in an attribute
                     # so entering this 'elif' isn't currently possible
                     print(f"## Subgraphs for {node.OpType().decode()}.{attr.Name().decode()} ##")
-                    for k in range(0, attr.GraphsLength()):
+                    for k in range(attr.GraphsLength()):
                         print(f"## Subgraph {k} ##")
                         self._dump_graph(attr.Graphs(k))
                         print(f"## End Subgraph {k} ##")
diff --git a/tools/python/find_optimizer_opset_version_updates_required.py b/tools/python/find_optimizer_opset_version_updates_required.py
index 8a5e57b51e38d..b46f7e4a54d9c 100644
--- a/tools/python/find_optimizer_opset_version_updates_required.py
+++ b/tools/python/find_optimizer_opset_version_updates_required.py
@@ -199,9 +199,7 @@ def find_potential_issues(root_dir, op_to_opset):
                 latest = op_to_opset[op]
                 if int(latest) != int(last_version):
                     log.warning(
-                        "Newer opset found for {}. Latest:{} Optimizer support ends at {}. File:{}".format(
-                            op, latest, last_version, file
-                        )
+                        f"Newer opset found for {op}. Latest:{latest} Optimizer support ends at {last_version}. File:{file}"
                     )
             else:
                 log.error(f"Failed to find version information for {op}. File:{file}")
diff --git a/tools/python/gen_contrib_doc.py b/tools/python/gen_contrib_doc.py
index accab96bd3593..ab9421b395326 100644
--- a/tools/python/gen_contrib_doc.py
+++ b/tools/python/gen_contrib_doc.py
@@ -359,11 +359,7 @@ def main(output_path: str, domain_filter: [str]):
 
             for _, namemap in supportmap:
                 for n, schema, versions in namemap:  # noqa: B007
-                    s = '  * {}<a href="#{}">{}</a>\n'.format(
-                        support_level_str(schema.support_level),
-                        format_name_with_domain(domain, n),
-                        format_name_with_domain(domain, n),
-                    )
+                    s = f'  * {support_level_str(schema.support_level)}<a href="#{format_name_with_domain(domain, n)}">{format_name_with_domain(domain, n)}</a>\n'
                     fout.write(s)
 
         fout.write("\n")
diff --git a/tools/python/run_CIs_for_branch.py b/tools/python/run_CIs_for_branch.py
new file mode 100644
index 0000000000000..975ea2b988d75
--- /dev/null
+++ b/tools/python/run_CIs_for_branch.py
@@ -0,0 +1,153 @@
+#!/usr/bin/env python3
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import argparse
+import json
+import os
+import subprocess
+import sys
+import typing
+
+from run_CIs_for_external_pr import get_pipeline_names
+from util.platform_helpers import is_windows
+
+
+class DefaultArgsRawHelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
+    pass
+
+
+def _parse_args():
+    parser = argparse.ArgumentParser(
+        os.path.basename(__file__),
+        formatter_class=DefaultArgsRawHelpFormatter,
+        description="""Run the CIs used to validate PRs for the specified branch.
+
+        If not specified, the branch will be inferred (if possible) by running `git branch --show-current`.
+
+        If specified, the `--include` filter is applied first, followed by any `--exclude` filter.
+        `--include` and `--exclude` can be specified multiple times to accumulate values to include/exclude.
+
+        Requires the Azure CLI with DevOps extension to be installed.
+          Azure CLI: https://learn.microsoft.com/en-us/cli/azure/install-azure-cli
+          DevOps extension: https://github.com/Azure/azure-devops-cli-extension
+
+        Configuration:
+          Login: `az login`
+          Configure ORT repo as default:
+            `az devops configure --defaults organization=https://dev.azure.com/onnxruntime project=onnxruntime`
+
+        Example usage:
+          List all CIs
+            `python run_CIs_for_branch.py --dry-run my/BranchName`
+          Run all CIs
+            `python run_CIs_for_branch.py my/BranchName`
+          Run only Linux CIs
+            `python run_CIs_for_branch.py --include linux my/BranchName`
+          Exclude training CIs
+            `python run_CIs_for_branch.py --exclude training my/BranchName`
+          Run non-training Linux CIs
+            `python run_CIs_for_branch.py --include linux --exclude training my/BranchName`
+        """,
+    )
+
+    current_branch = None
+    get_branch_result = subprocess.run(["git", "branch", "--show-current"], capture_output=True, text=True, check=False)
+    if get_branch_result.returncode == 0:
+        current_branch = get_branch_result.stdout.strip()
+
+    parser.add_argument(
+        "-i", "--include", action="append", type=str, help="Include CIs that match this string. Case insensitive."
+    )
+    parser.add_argument(
+        "-e", "--exclude", action="append", type=str, help="Exclude CIs that match this string. Case insensitive."
+    )
+    parser.add_argument("--dry-run", action="store_true", help="Print selected CIs but do not run them.")
+    parser.add_argument(
+        "branch",
+        type=str,
+        nargs="?",
+        default=current_branch,
+        help="Specify the branch to run. Default is current branch if available.",
+    )
+
+    args = parser.parse_args()
+    if not args.branch:
+        raise ValueError("Branch was unable to be inferred and must be specified")
+
+    return args
+
+
+def _run_az_pipelines_command(command: typing.List[str]):
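+    # Thin wrapper over the Azure DevOps CLI; exits the script if the `az pipelines` call fails.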
+    try:
+        az = "az.cmd" if is_windows() else "az"
+        az_output = subprocess.run([az, "pipelines", *command], capture_output=True, text=True, check=True)
+    except subprocess.CalledProcessError as cpe:
+        print(cpe)
+        print(cpe.stderr)
+        sys.exit(-1)
+
+    return az_output
+
+
+def main():
+    args = _parse_args()
+    branch = args.branch
+
+    # To debug available pipelines:
+    # az_out = az_pipelines = _run_az_pipelines_command(["list"])
+    # pipeline_info = json.loads(az_out.stdout)
+    # print(pipeline_info)
+
+    pipelines = get_pipeline_names()
+    pipelines_to_run = []
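+    # Apply the --include filter first: keep pipelines whose name contains any include value (case-insensitive substring match).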
+    if args.include:
+        values = [i.lower().strip() for i in args.include]
+        for p in pipelines:
+            include = False
+            for value in values:
+                if value in p.lower():
+                    include = True
+                    break
+
+            if include:
+                print(f"Including {p}")
+                pipelines_to_run.append(p)
+    else:
+        pipelines_to_run = pipelines
+
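+    # Then apply the --exclude filter: drop any remaining pipeline whose name contains an exclude value.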
+    if args.exclude:
+        values = [e.lower().strip() for e in args.exclude]
+        cur_pipelines = pipelines_to_run
+        pipelines_to_run = []
+        for p in cur_pipelines:
+            exclude = False
+            for value in values:
+                if value in p.lower():
+                    exclude = True
+                    break
+
+            if exclude:
+                print(f"Excluding {p}")
+            else:
+                pipelines_to_run.append(p)
+
+    print(f"Pipelines to run for {args.branch}:")
+    for p in pipelines_to_run:
+        print(f"\t{p}")
+
+    if args.dry_run:
+        sys.exit(0)
+
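+    # Queue each selected pipeline for the branch and print a link to its build results page.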
+    for pipeline in pipelines_to_run:
+        az_out = _run_az_pipelines_command(["run", "--branch", branch, "--name", pipeline])
+        run_output = json.loads(az_out.stdout)
+        if "id" in run_output:
+            build_url = f"https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId={run_output['id']}"
+            print(f"{pipeline} build results: {build_url}&view=results")
+        else:
+            raise ValueError("Build id was not found in az output:\n" + run_output)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tools/python/run_CIs_for_external_pr.py b/tools/python/run_CIs_for_external_pr.py
index 7a77839c4a4e7..dcafe898b3bdf 100644
--- a/tools/python/run_CIs_for_external_pr.py
+++ b/tools/python/run_CIs_for_external_pr.py
@@ -3,13 +3,54 @@
 # Licensed under the MIT License.
 
 import argparse
+import json
 import os
 import subprocess
 import sys
 import typing
 
 
-def parse_args():
+def get_pipeline_names():
+    # Current pipelines. These change semi-frequently and may need updating.
+    # There is no easy way to get the list of "required" pipelines using `azp` before they are run,
+    # so we need to maintain this list manually.
+    # NOTE: This list is also used by run_CIs_for_branch.py
+    pipelines = [
+        # windows
+        "Windows ARM64 QNN CI Pipeline",
+        "Windows x64 QNN CI Pipeline",
+        "Windows CPU CI Pipeline",
+        "Windows GPU CI Pipeline",
+        "Windows GPU TensorRT CI Pipeline",
+        "ONNX Runtime Web CI Pipeline",
+        # linux
+        "Linux CPU CI Pipeline",
+        "Linux CPU Minimal Build E2E CI Pipeline",
+        "Linux GPU CI Pipeline",
+        "Linux GPU TensorRT CI Pipeline",
+        "Linux OpenVINO CI Pipeline",
+        "Linux QNN CI Pipeline",
+        # mac
+        "MacOS CI Pipeline",
+        # training
+        "orttraining-amd-gpu-ci-pipeline",
+        "orttraining-linux-ci-pipeline",
+        "orttraining-linux-gpu-ci-pipeline",
+        "orttraining-ortmodule-distributed",
+        # checks
+        "onnxruntime-binary-size-checks-ci-pipeline",
+        # big models
+        "Big Models",
+        # not currently required, but running ensures we're hitting all mobile platforms
+        "Android CI Pipeline",
+        "iOS CI Pipeline",
+        "ONNX Runtime React Native CI Pipeline",
+    ]
+
+    return pipelines
+
+
+def _parse_args():
     parser = argparse.ArgumentParser(
         os.path.basename(__file__),
         formatter_class=argparse.RawDescriptionHelpFormatter,
@@ -25,7 +66,7 @@ def parse_args():
     return args
 
 
-def run_gh_pr_command(command: typing.List[str], check=True):
+def run_gh_pr_command(command: typing.List[str], check: bool = True):
     try:
         return subprocess.run(["gh", "pr", *command], capture_output=True, text=True, check=check)
     except subprocess.CalledProcessError as cpe:
@@ -35,23 +76,25 @@ def run_gh_pr_command(command: typing.List[str], check=True):
 
 
 def main():
-    args = parse_args()
+    args = _parse_args()
     pr_id = args.pr
 
     # validate PR
-    gh_out = run_gh_pr_command(["view", pr_id])
-    info = gh_out.stdout.split("\n")
-    for line in info:
-        pieces = line.split("\t")
-        if len(pieces) != 2:
-            continue
-
-        if pieces[0] == "state:":
-            if pieces[1] != "OPEN":
-                print(f"PR {pr_id} is not OPEN. Currently in state {pieces[1]}.")
-                sys.exit(-1)
-
-    print("Check passed pipelines")
+    print("Checking PR is open")
+    gh_out = run_gh_pr_command(["view", "--json", "state", pr_id])
+    info = json.loads(gh_out.stdout)
+    if "state" not in info:
+        print(f"Could not get current state from `gh pr view` response of\n{gh_out.stdout}")
+        sys.exit(-1)
+
+    if info["state"] != "OPEN":
+        print(f"PR {pr_id} is not OPEN. Currently in state {info['state']}.")
+        sys.exit(0)
+
+    # `gh pr checks` reports the status of CIs that have already run; we use it to skip pipelines that have
+    # already passed, so it's fine for the initial response to contain no info.
+    # `gh pr checks` exits with a non-zero code when any pipeline has failed, so we set `check` to False.
+    print("Checking for pipelines that have passed.")
     gh_out = run_gh_pr_command(["checks", pr_id, "--required"], check=False)
     # output format is a tab separated list of columns:
     # (pipeline name) "\t" (status) "\t" (ran time) "\t" (url)
@@ -61,52 +104,21 @@ def main():
         if len(columns) == 4 and columns[1] == "pass"
     ]
 
-    print("Adding azp run commands")
-
-    # Current pipelines. These change semi-frequently and may need updating.
-    #
-    # Note: there is no easy way to get the list for azp "required" pipelines before they starts.
-    #       we need to maintain this list manually.
-    #
-    pipelines = [
-        # windows
-        "Windows ARM64 QNN CI Pipeline",
-        "Windows x64 QNN CI Pipeline",
-        "Windows CPU CI Pipeline",
-        "Windows GPU CI Pipeline",
-        "Windows GPU TensorRT CI Pipeline",
-        "ONNX Runtime Web CI Pipeline",
-        # linux
-        "Linux CPU CI Pipeline",
-        "Linux CPU Minimal Build E2E CI Pipeline",
-        "Linux GPU CI Pipeline",
-        "Linux GPU TensorRT CI Pipeline",
-        "Linux OpenVINO CI Pipeline",
-        "Linux QNN CI Pipeline",
-        # mac
-        "MacOS CI Pipeline",
-        # training
-        "orttraining-amd-gpu-ci-pipeline",
-        "orttraining-linux-ci-pipeline",
-        "orttraining-linux-gpu-ci-pipeline",
-        "orttraining-ortmodule-distributed",
-        # checks
-        "onnxruntime-python-checks-ci-pipeline",
-        "onnxruntime-binary-size-checks-ci-pipeline",
-        # not currently required, but running ensures we're hitting all mobile platforms
-        "Android CI Pipeline",
-        "iOS CI Pipeline",
-        "ONNX Runtime React Native CI Pipeline",
-    ]
+    pipelines = get_pipeline_names()
 
     # remove pipelines that have already run successfully
     pipelines = [p for p in pipelines if p not in checked_pipelines]
 
+    print("Pipelines to run:")
+    for p in pipelines:
+        print("\t" + p)
+
     # azp run is limited to 10 pipelines at a time
     max_pipelines_per_comment = 10
     start = 0
     num_pipelines = len(pipelines)
 
+    print("Adding azp run commands")
     while start < num_pipelines:
         end = start + max_pipelines_per_comment
         if end > num_pipelines:
diff --git a/tools/python/util/convert_onnx_models_to_ort.py b/tools/python/util/convert_onnx_models_to_ort.py
index 18bba78661796..08e840092bc22 100644
--- a/tools/python/util/convert_onnx_models_to_ort.py
+++ b/tools/python/util/convert_onnx_models_to_ort.py
@@ -302,9 +302,7 @@ def convert_onnx_models_to_ort(
 
     for optimization_style in optimization_styles:
         print(
-            "Converting models with optimization style '{}' and level '{}'".format(
-                optimization_style.name, optimization_level_str
-            )
+            f"Converting models with optimization style '{optimization_style.name}' and level '{optimization_level_str}'"
         )
 
         converted_models = _convert(
@@ -330,9 +328,9 @@ def convert_onnx_models_to_ort(
                 )
                 session_options_config_entries_for_second_conversion = session_options_config_entries.copy()
                 # Limit the optimizations to those that can run in a model with runtime optimizations.
-                session_options_config_entries_for_second_conversion[
-                    "optimization.minimal_build_optimizations"
-                ] = "apply"
+                session_options_config_entries_for_second_conversion["optimization.minimal_build_optimizations"] = (
+                    "apply"
+                )
 
                 print(
                     "Converting models again without runtime optimizations to generate a complete config file. "
@@ -351,9 +349,7 @@ def convert_onnx_models_to_ort(
                 )
 
             print(
-                "Generating config file from ORT format models with optimization style '{}' and level '{}'".format(
-                    optimization_style.name, optimization_level_str
-                )
+                f"Generating config file from ORT format models with optimization style '{optimization_style.name}' and level '{optimization_level_str}'"
             )
 
             config_file = _create_config_file_path(
diff --git a/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py b/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py
index f8cc34e04afa0..548d4a8ba6c45 100644
--- a/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py
+++ b/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py
@@ -230,7 +230,7 @@ def run_check_with_model(
     if unsupported_ops:
         logger.info("Unsupported operators:")
         for entry in sorted(unsupported_ops):
-            logger.info("  " + entry)
+            logger.info("  " + entry)  # noqa: G003
 
     if unsupported:
         logger.info("\nModel is not supported by the pre-built package due to unsupported types and/or operators.")
diff --git a/tools/python/util/ort_format_model/operator_type_usage_processors.py b/tools/python/util/ort_format_model/operator_type_usage_processors.py
index 22d7dff3e13b2..598549c42b60a 100644
--- a/tools/python/util/ort_format_model/operator_type_usage_processors.py
+++ b/tools/python/util/ort_format_model/operator_type_usage_processors.py
@@ -92,7 +92,6 @@ def to_config_entry(self):
         Generate a configuration file entry in JSON format with the required types for the operator.
         :return: JSON string with required type information.
         """
-        pass
 
     @abstractmethod
     def from_config_entry(self, entry: str):
@@ -101,7 +100,6 @@ def from_config_entry(self, entry: str):
         NOTE: Any existing type information should be cleared prior to re-creating from a config file entry.
         :param entry: Configuration file entry
         """
-        pass
 
 
 class DefaultTypeUsageProcessor(TypeUsageProcessor):
@@ -182,9 +180,7 @@ def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
             # Don't know of any ops where the number of outputs changed across versions, so require a valid length
             if o >= node.OutputsLength():
                 raise RuntimeError(
-                    "Node has {} outputs. Tracker for {} incorrectly configured as it requires {}.".format(
-                        node.OutputsLength(), self.name, o
-                    )
+                    f"Node has {node.OutputsLength()} outputs. Tracker for {self.name} incorrectly configured as it requires {o}."
                 )
 
             type_str = value_name_to_typestr(node.Outputs(o), value_name_to_typeinfo)
@@ -514,7 +510,6 @@ def is_typed_registration_needed(self, domain: str, optype: str, type_registrati
         :param type_registration_str: Type string from kernel registration
         :return: True is required. False if not.
         """
-        pass
 
     @abstractmethod
     def get_cpp_entries(self):
@@ -522,7 +517,6 @@ def get_cpp_entries(self):
         Get the C++ code that specifies the operator types to enable.
         :return: List of strings. One line of C++ code per entry.
         """
-        pass
 
 
 class OperatorTypeUsageManager:
@@ -644,9 +638,7 @@ def __init__(self, globally_allowed_types: typing.Set[str]):
 
         if not globally_allowed_types.issubset(self._valid_allowed_types):
             raise ValueError(
-                "Globally allowed types must all be valid. Invalid types: {}".format(
-                    sorted(globally_allowed_types - self._valid_allowed_types)
-                )
+                f"Globally allowed types must all be valid. Invalid types: {sorted(globally_allowed_types - self._valid_allowed_types)}"
             )
 
         self._globally_allowed_types = globally_allowed_types
diff --git a/tools/python/util/ort_format_model/ort_model_processor.py b/tools/python/util/ort_format_model/ort_model_processor.py
index d3a07efe92aa5..b20f3a0cfd97d 100644
--- a/tools/python/util/ort_format_model/ort_model_processor.py
+++ b/tools/python/util/ort_format_model/ort_model_processor.py
@@ -35,7 +35,7 @@ def _setup_type_info(graph: fbs.Graph, outer_scope_value_typeinfo={}):  # noqa:
         :return: Dictionary of NodeArg name to TypeInfo
         """
         value_name_to_typeinfo = outer_scope_value_typeinfo.copy()
-        for j in range(0, graph.NodeArgsLength()):
+        for j in range(graph.NodeArgsLength()):
             n = graph.NodeArgs(j)
             value_name_to_typeinfo[n.Name()] = n.Type()  # TypeInfo for this NodeArg's name
 
@@ -57,7 +57,7 @@ def _process_graph(self, graph: fbs.Graph, outer_scope_value_typeinfo: dict):
         # Merge the TypeInfo for all values in this level of the graph with the outer scope value TypeInfo.
         value_name_to_typeinfo = OrtFormatModelProcessor._setup_type_info(graph, outer_scope_value_typeinfo)
 
-        for i in range(0, graph.NodesLength()):
+        for i in range(graph.NodesLength()):
             node = graph.Nodes(i)
 
             optype = node.OpType().decode()
@@ -69,7 +69,7 @@ def _process_graph(self, graph: fbs.Graph, outer_scope_value_typeinfo: dict):
                 self._op_type_processors.process_node(node, value_name_to_typeinfo)
 
             # Read all the attributes
-            for j in range(0, node.AttributesLength()):
+            for j in range(node.AttributesLength()):
                 attr = node.Attributes(j)
                 attr_type = attr.Type()
                 if attr_type == fbs.AttributeType.AttributeType.GRAPH:
@@ -77,7 +77,7 @@ def _process_graph(self, graph: fbs.Graph, outer_scope_value_typeinfo: dict):
                 elif attr_type == fbs.AttributeType.AttributeType.GRAPHS:
                     # the ONNX spec doesn't currently define any operators that have multiple graphs in an attribute
                     # so entering this 'elif' isn't currently possible
-                    for k in range(0, attr.GraphsLength()):
+                    for k in range(attr.GraphsLength()):
                         self._process_graph(attr.Graphs(k), value_name_to_typeinfo)
 
     def process(self):
diff --git a/winml/lib/Api/HardwareCoreEnumerator.cpp b/winml/lib/Api/HardwareCoreEnumerator.cpp
index a89ac561f8860..d04e276347170 100644
--- a/winml/lib/Api/HardwareCoreEnumerator.cpp
+++ b/winml/lib/Api/HardwareCoreEnumerator.cpp
@@ -14,7 +14,7 @@ struct LogicalProcessorInformation {
 
 struct CoreCounter {
   uint32_t PhysicalCores = 0;
-  uint32_t SocDieCores = 0;
+  uint32_t Num2CacheCores = 0;
 };
 
 static LogicalProcessorInformation GetLogicalProcessorInfos(LOGICAL_PROCESSOR_RELATIONSHIP relationship) {
@@ -75,7 +75,7 @@ static CoreCounter GetNumberOPhysicalAndEngineeringCores() {
     read += currentProcessorInfo->Size;
   }
 
-  cores.SocDieCores = CountSetBits(dwLevel2GroupMask & ~dwLevel3GroupMask);
+  cores.Num2CacheCores = CountSetBits(dwLevel2GroupMask & ~dwLevel3GroupMask);
   return cores;
 }
 
@@ -83,8 +83,27 @@ uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() {
   // # of physical cores = # of P cores + # of E Cores + # of Soc Cores.
   // # of logical cores = # of P cores x 2 (if hyper threading is enabled) + # of E cores + # of Soc Cores.
   auto cores = GetNumberOPhysicalAndEngineeringCores();
-  // We want to use the number of physical cores, but exclude soc cores
-  return cores.PhysicalCores - cores.SocDieCores;
+
+#if !defined(_M_ARM64EC) && !defined(_M_ARM64) && !defined(__aarch64__)
+  const int kVendorID_Intel[3] = {0x756e6547, 0x6c65746e, 0x49656e69};  // "GenuntelineI"
+  int regs_leaf0[4];
+  int regs_leaf7[4];
+  __cpuid(regs_leaf0, 0);
+  __cpuid(regs_leaf7, 0x7);
+
+  auto isIntel = (kVendorID_Intel[0] == regs_leaf0[1]) && (kVendorID_Intel[1] == regs_leaf0[2]) &&
+    (kVendorID_Intel[2] == regs_leaf0[3]);
+
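+  // CPUID leaf 7 reports the hybrid architecture flag in EDX bit 15 (set on processors mixing P-cores and E-cores).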
+  auto isHybrid = (regs_leaf7[3] & (1 << 15));
+
+  if (isIntel && isHybrid) {
+    // We want to use the number of physical cores, but exclude soc cores
+    // On Intel Hybrid processors, numSocCores == cores.Num2CacheCores
+    return cores.PhysicalCores - cores.Num2CacheCores;
+  }
+#endif
+
+  return cores.PhysicalCores;
 }
 
 }  // namespace WINMLP
diff --git a/winml/test/model/model_tests.cpp b/winml/test/model/model_tests.cpp
index f40f08ad2696d..27d74d7d6b034 100644
--- a/winml/test/model/model_tests.cpp
+++ b/winml/test/model/model_tests.cpp
@@ -118,7 +118,7 @@ TEST_P(ModelTest, Run) {
   LearningModelDevice device = nullptr;
   LearningModelSession session = nullptr;
   LearningModelBinding binding = nullptr;
-  WINML_EXPECT_NO_THROW(model = LearningModel::LoadFromFilePath(m_testCase->GetModelUrl()));
+  WINML_EXPECT_NO_THROW(model = LearningModel::LoadFromFilePath(m_testCase->GetModelUrl().native()));
   WINML_EXPECT_NO_THROW(device = LearningModelDevice(m_deviceKind));
   WINML_EXPECT_NO_THROW(session = LearningModelSession(model, device));
   for (size_t i = 0; i < m_testCase->GetDataCount(); i++) {