diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 1740140094df2..e60da111afd36 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -138,7 +138,8 @@ Do not modify directly.*
|||[7, 8]|**T** = tensor(double), tensor(float)
**T1** = tensor(bool)|
|GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)|
|||[12, 15]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)|
-|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float)
**T2** = tensor(float)|
+|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|20+|**T1** = tensor(double), tensor(float)
**T2** = tensor(double), tensor(float)|
+|||[16, 19]|**T1** = tensor(float)
**T2** = tensor(float)|
|HammingWindow|*in* size:**T1**
*out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|HannWindow|*in* size:**T1**
*out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|HardSigmoid|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float)|
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
index 2ca3b1cdf817e..4553e7ee18913 100644
--- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
+++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -798,7 +798,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 18, If);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, RoiAlign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, double, RoiAlign);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, GridSample);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 19, float, GridSample);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 17, ScatterND);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, string, Where);
@@ -960,6 +960,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Sh
// Opset 20
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, GridSample);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, GridSample);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, AffineGrid);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, AffineGrid);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN);
@@ -2183,8 +2185,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
RoiAlign)>,
BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -2401,6 +2403,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
// Opset 20
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cpu/tensor/grid_sample.cc b/onnxruntime/core/providers/cpu/tensor/grid_sample.cc
index c58a7d8337114..a83ba378d7f1e 100644
--- a/onnxruntime/core/providers/cpu/tensor/grid_sample.cc
+++ b/onnxruntime/core/providers/cpu/tensor/grid_sample.cc
@@ -11,17 +11,23 @@
namespace onnxruntime {
-#define REGISTER_KERNEL_TYPED(T) \
- ONNX_CPU_OPERATOR_TYPED_KERNEL( \
- GridSample, \
- 16, \
- T, \
- KernelDefBuilder() \
- .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \
- .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \
- GridSample);
+#define REGISTER_KERNEL_TYPED(T) \
+ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(GridSample, kOnnxDomain, 16, 19, T, kCpuExecutionProvider, \
+ KernelDefBuilder() \
+ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \
+ GridSample);
+
+#define REGISTER_KERNEL_TYPED_20(T) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX(GridSample, kOnnxDomain, 20, T, kCpuExecutionProvider, \
+ KernelDefBuilder() \
+ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \
+ GridSample);
REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED_20(float)
+REGISTER_KERNEL_TYPED_20(double)
// Restore normalized location to actual image location
// When align_corners is true:
@@ -44,16 +50,15 @@ T GsDenormalize(T n, int64_t length, bool align_corners) {
}
// Reflect by the near border till within the borders
-// Use float for borders to avoid potential issues with integer T
template
-T GsReflect(T x, float x_min, float x_max) {
- float dx = {};
- float fx = static_cast(x);
- float range = x_max - x_min;
+T GsReflect(T x, T x_min, T x_max) {
+ T dx = {};
+ T fx = static_cast(x);
+ T range = x_max - x_min;
if (fx < x_min) {
dx = x_min - fx;
int n = static_cast(dx / range);
- float r = dx - n * range;
+ T r = dx - n * range;
if (n % 2 == 0) {
fx = x_min + r;
} else {
@@ -62,7 +67,7 @@ T GsReflect(T x, float x_min, float x_max) {
} else if (fx > x_max) {
dx = fx - x_max;
int n = static_cast(dx / range);
- float r = dx - n * range;
+ T r = dx - n * range;
if (n % 2 == 0) {
fx = x_max - r;
} else {
@@ -75,9 +80,9 @@ T GsReflect(T x, float x_min, float x_max) {
// Calculate cubic convolution interpolation coefficients
// ROBERT G. KEYS https://ieeexplore.ieee.org/document/1163711
-// Use float to avoid potential issues with integer T
-void GsGetCubicCoeffs(float x, float coeffs[4]) {
- constexpr float cubic_alpha = -0.75f;
+template
+void GsGetCubicCoeffs(T x, T coeffs[4]) {
+ constexpr T cubic_alpha = -0.75f;
x = std::abs(x);
coeffs[0] = ((cubic_alpha * (x + 1) - 5 * cubic_alpha) * (x + 1) + 8 * cubic_alpha) * (x + 1) - 4 * cubic_alpha;
coeffs[1] = ((cubic_alpha + 2) * x - (cubic_alpha + 3)) * x * x + 1;
@@ -86,9 +91,9 @@ void GsGetCubicCoeffs(float x, float coeffs[4]) {
}
template
-T GsBicubicInterpolate(T p[4][4], float x, float y) {
- float v[4] = {};
- float coeffs[4] = {};
+T GsBicubicInterpolate(T p[4][4], T x, T y) {
+ T v[4] = {};
+ T coeffs[4] = {};
GsGetCubicCoeffs(x, coeffs);
for (int64_t i = 0; i < 4; i++) {
v[i] = coeffs[0] * p[i][0] + coeffs[1] * p[i][1] + coeffs[2] * p[i][2] + coeffs[3] * p[i][3];
@@ -98,7 +103,7 @@ T GsBicubicInterpolate(T p[4][4], float x, float y) {
}
template
-T GridSample::PixelAtGrid(const T* image, int64_t r, int64_t c, int64_t H, int64_t W, float border[/* 4 */]) const {
+T GridSample::PixelAtGrid(const T* image, int64_t r, int64_t c, int64_t H, int64_t W, T border[/* 4 */]) const {
T pixel = {}; // default 0
if (padding_mode_ == Zeros) {
if (c >= 0 && c < W && r >= 0 && r < H) {
@@ -116,6 +121,27 @@ T GridSample::PixelAtGrid(const T* image, int64_t r, int64_t c, int64_t H, in
return pixel;
}
+template
+T GridSample::PixelAtGrid3D(const T* image, int64_t d, int64_t h, int64_t w, int64_t D, int64_t H, int64_t W, T border[/* 6 */]) const {
+ T pixel = {}; // default 0
+ if (padding_mode_ == Zeros) {
+ if (w >= 0 && w < W && h >= 0 && h < H && d >= 0 && d < D) {
+ pixel = image[d * H * W + h * W + w];
+ }
+ } else if (padding_mode_ == Border) {
+ w = std::clamp(w, 0, W - 1);
+ h = std::clamp(h, 0, H - 1);
+ d = std::clamp(d, 0, D - 1);
+ pixel = image[d * H * W + h * W + w];
+ } else { // (padding_mode_ == Reflection)
+ w = static_cast(GsReflect(static_cast(w), border[0], border[3]));
+ h = static_cast(GsReflect(static_cast(h), border[1], border[4]));
+ d = static_cast(GsReflect(static_cast(d), border[2], border[5]));
+ pixel = image[d * H * W + h * W + w];
+ }
+ return pixel;
+}
+
// When grid sampling, padding is applied before interpolation.
// For instance, in bilinear mode and zeros padding-mode, pixel p at actual
// image location (-0.5, -0.5)
@@ -134,113 +160,203 @@ Status GridSample::Compute(OpKernelContext* context) const {
const auto& input_dims = input->Shape();
const auto& grid_dims = grid->Shape();
- if (input_dims.NumDimensions() != 4 || grid_dims.NumDimensions() != 4) {
- return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Only 4-D tensor is supported");
- }
+ int64_t data_dims = input_dims.NumDimensions() - 2;
+ ORT_ENFORCE(static_cast(grid_dims.NumDimensions()) == data_dims + 2,
+ "grid dimensions must be ", data_dims + 2, "for input dimension of ", data_dims);
+
+ ORT_ENFORCE(grid_dims[grid_dims.NumDimensions() - 1] == data_dims,
+ "Last dimension of grid: ", grid_dims[grid_dims.NumDimensions() - 1], ", expect ", data_dims);
+
+ ORT_ENFORCE(input_dims.NumDimensions() == 4 || input_dims.NumDimensions() == 5, "Only 4-D or 5-D tensor is supported");
auto N = input_dims[0];
auto C = input_dims[1];
- auto H_in = input_dims[2];
- auto W_in = input_dims[3];
- auto H_out = grid_dims[1];
- auto W_out = grid_dims[2];
ORT_ENFORCE(grid_dims[0] == N, "Grid batch size ", grid_dims[0], " does not match input batch size ", N);
- ORT_ENFORCE(grid_dims[3] == 2, "Last dimension of grid: ", grid_dims[3], ", expect 2");
- TensorShape Y_shape = {N, C, H_out, W_out};
- auto& Y = *context->Output(0, Y_shape);
- // Return early if the output tensor is going to be of size 0
- if (Y.Shape().Size() == 0) {
- return Status::OK();
+ if (input_dims.NumDimensions() == 5) {
+ ORT_ENFORCE(mode_ != Cubic, "Only support GridSample Cubic mode in 4-D cases.");
}
- // Force float here to avoid possible issue in integer T case
- float x_min = -0.5f;
- float x_max = W_in - 0.5f;
- float y_min = -0.5f;
- float y_max = H_in - 0.5f;
-
- if (align_corners_) {
- x_min = 0.f;
- x_max = W_in - 1.f;
- y_min = 0.f;
- y_max = H_in - 1.f;
- }
- float border[] = {x_min, y_min, x_max, y_max}; // l-t-r-b
-
- concurrency::ThreadPool* tp = H_out * W_out > 64 ? context->GetOperatorThreadPool() : nullptr;
- for (int64_t n = 0; n < N; n++) {
- const T* grid_data = grid->Data() + n * (H_out * W_out) * 2;
- concurrency::ThreadPool::TrySimpleParallelFor(
- tp, onnxruntime::narrow(C),
- [&](std::ptrdiff_t c) {
- const T* X_data = input->Data() + (n * C + c) * (H_in * W_in);
- T* Y_data = Y.MutableData() + (n * C + c) * (H_out * W_out);
-
- for (int64_t oy = 0; oy < H_out; oy++) {
- for (int64_t ox = 0; ox < W_out; ox++) {
- const T* gridpoint = grid_data + (oy * W_out + ox) * 2;
- T* Y_gridpoint = Y_data + oy * W_out + ox;
- auto nx = gridpoint[0]; // normalized location
- auto ny = gridpoint[1];
- auto x = GsDenormalize(nx, W_in, align_corners_); // actual location
- auto y = GsDenormalize(ny, H_in, align_corners_);
-
- if (mode_ == Nearest) {
- x = static_cast(std::nearbyintf(static_cast(x)));
- y = static_cast(std::nearbyintf(static_cast(y)));
- }
+ if (data_dims == 2) {
+ // 2-D sampling: input is (N, C, H_in, W_in) and grid is (N, H_out, W_out, 2).
+ auto H_in = input_dims[2];
+ auto W_in = input_dims[3];
+ auto H_out = grid_dims[1];
+ auto W_out = grid_dims[2];
+ TensorShape Y_shape = {N, C, H_out, W_out};
+ auto& Y = *context->Output(0, Y_shape);
+ // Return early if the output tensor is going to be of size 0
+ if (Y.Shape().Size() == 0) {
+ return Status::OK();
+ }
- if (x < x_min || x > x_max || y < y_min || y > y_max) { // out of bound
- if (padding_mode_ == Border) {
- // use original border in both align_corner cases
- x = std::clamp(x, static_cast(0), static_cast(W_in - 1));
- y = std::clamp(y, static_cast(0), static_cast(H_in - 1));
- } else if (padding_mode_ == Reflection) {
- x = GsReflect(x, x_min, x_max);
- y = GsReflect(y, y_min, y_max);
- }
- } // out of bound
+ T x_min = -0.5f;
+ T x_max = W_in - 0.5f;
+ T y_min = -0.5f;
+ T y_max = H_in - 0.5f;
- if (mode_ == Nearest) {
- // x, y are integers in all padding modes
- *Y_gridpoint = PixelAtGrid(X_data, static_cast(y), static_cast(x), H_in, W_in, border);
- continue;
- }
+ if (align_corners_) {
+ x_min = 0.f;
+ x_max = W_in - 1.f;
+ y_min = 0.f;
+ y_max = H_in - 1.f;
+ }
+ T border[] = {x_min, y_min, x_max, y_max}; // l-t-r-b
- if (mode_ == Bilinear) {
- int64_t x1 = static_cast(std::floor(x));
- int64_t y1 = static_cast(std::floor(y));
- int64_t x2 = x1 + 1;
- int64_t y2 = y1 + 1;
-
- T p11 = PixelAtGrid(X_data, y1, x1, H_in, W_in, border);
- T p12 = PixelAtGrid(X_data, y1, x2, H_in, W_in, border);
- T p21 = PixelAtGrid(X_data, y2, x1, H_in, W_in, border);
- T p22 = PixelAtGrid(X_data, y2, x2, H_in, W_in, border);
-
- T dx2 = static_cast(x2) - x;
- T dx1 = x - static_cast(x1);
- T dy2 = static_cast(y2) - y;
- T dy1 = y - static_cast(y1);
- *Y_gridpoint = dy2 * (dx2 * p11 + dx1 * p12) + dy1 * (dx2 * p21 + dx1 * p22);
+ concurrency::ThreadPool* tp = H_out * W_out > 64 ? context->GetOperatorThreadPool() : nullptr;
+ for (int64_t n = 0; n < N; n++) {
+ const T* grid_data = grid->Data() + n * (H_out * W_out) * 2;
+ concurrency::ThreadPool::TrySimpleParallelFor(
+ tp, onnxruntime::narrow(C),
+ [&](std::ptrdiff_t c) {
+ const T* X_data = input->Data() + (n * C + c) * (H_in * W_in);
+ T* Y_data = Y.MutableData() + (n * C + c) * (H_out * W_out);
+
+ for (int64_t oy = 0; oy < H_out; oy++) {
+ for (int64_t ox = 0; ox < W_out; ox++) {
+ const T* gridpoint = grid_data + (oy * W_out + ox) * 2;
+ T* Y_gridpoint = Y_data + oy * W_out + ox;
+ auto nx = gridpoint[0]; // normalized location
+ auto ny = gridpoint[1];
+ auto x = GsDenormalize(nx, W_in, align_corners_); // actual location
+ auto y = GsDenormalize(ny, H_in, align_corners_);
+
+ if (mode_ == Nearest) {
+ x = static_cast(std::nearbyint(static_cast(x)));
+ y = static_cast(std::nearbyint(static_cast(y)));
+ // x, y are integers in all padding modes
+ *Y_gridpoint = PixelAtGrid(X_data, static_cast(y), static_cast(x), H_in, W_in, border);
+ } else if (mode_ == Linear) {
+ int64_t x1 = static_cast(std::floor(x));
+ int64_t y1 = static_cast(std::floor(y));
+ int64_t x2 = x1 + 1;
+ int64_t y2 = y1 + 1;
+
+ T p11 = PixelAtGrid(X_data, y1, x1, H_in, W_in, border);
+ T p12 = PixelAtGrid(X_data, y1, x2, H_in, W_in, border);
+ T p21 = PixelAtGrid(X_data, y2, x1, H_in, W_in, border);
+ T p22 = PixelAtGrid(X_data, y2, x2, H_in, W_in, border);
+
+ T dx2 = static_cast(x2) - x;
+ T dx1 = x - static_cast(x1);
+ T dy2 = static_cast(y2) - y;
+ T dy1 = y - static_cast(y1);
+ *Y_gridpoint = dy2 * (dx2 * p11 + dx1 * p12) + dy1 * (dx2 * p21 + dx1 * p22);
+ } else if (mode_ == Cubic) {
+ int64_t x0 = static_cast(std::floor(x)) - 1; // top-left corner of the bbox
+ int64_t y0 = static_cast(std::floor(y)) - 1;
+
+ T p[4][4] = {}; // [H][W]
+ for (int64_t h = 0; h < 4; h++) {
+ for (int64_t w = 0; w < 4; w++) {
+ p[h][w] = PixelAtGrid(X_data, h + y0, w + x0, H_in, W_in, border);
+ }
+ }
+ T dx = static_cast(x - x0 - 1);
+ T dy = static_cast(y - y0 - 1);
+ *Y_gridpoint = GsBicubicInterpolate(p, dx, dy);
+ }
}
- if (mode_ == Bicubic) {
- int64_t x0 = static_cast(std::floor(x)) - 1; // top-left corner of the bbox
- int64_t y0 = static_cast(std::floor(y)) - 1;
- T p[4][4] = {}; // [H][W]
- for (int64_t h = 0; h < 4; h++) {
- for (int64_t w = 0; w < 4; w++) {
- p[h][w] = PixelAtGrid(X_data, h + y0, w + x0, H_in, W_in, border);
+ }
+ });
+ }
+ } else if (data_dims == 3) {
+ // 3-D (volumetric) sampling: input is (N, C, D_in, H_in, W_in) and grid is (N, D_out, H_out, W_out, 3).
+ auto D_in = input_dims[2];
+ auto H_in = input_dims[3];
+ auto W_in = input_dims[4];
+ auto D_out = grid_dims[1];
+ auto H_out = grid_dims[2];
+ auto W_out = grid_dims[3];
+ TensorShape Y_shape = {N, C, D_out, H_out, W_out};
+ auto& Y = *context->Output(0, Y_shape);
+ // Return early if the output tensor is going to be of size 0
+ if (Y.Shape().Size() == 0) {
+ return Status::OK();
+ }
+
+ T x_min = -0.5f;
+ T x_max = W_in - 0.5f;
+ T y_min = -0.5f;
+ T y_max = H_in - 0.5f;
+ T z_min = -0.5f;
+ T z_max = D_in - 0.5f;
+
+ if (align_corners_) {
+ x_min = 0.f;
+ x_max = W_in - 1.f;
+ y_min = 0.f;
+ y_max = H_in - 1.f;
+ z_min = 0.f;
+ z_max = D_in - 1.f;
+ }
+ T border[] = {x_min, y_min, z_min, x_max, y_max, z_max};
+
+ concurrency::ThreadPool* tp = D_out * H_out * W_out > 64 ? context->GetOperatorThreadPool() : nullptr;
+ for (int64_t n = 0; n < N; n++) {
+ const T* grid_data = grid->Data() + n * (D_out * H_out * W_out) * 3;
+ concurrency::ThreadPool::TrySimpleParallelFor(
+ tp, onnxruntime::narrow(C),
+ [&](std::ptrdiff_t c) {
+ const T* X_data = input->Data() + (n * C + c) * (D_in * H_in * W_in);
+ T* Y_data = Y.MutableData() + (n * C + c) * (D_out * H_out * W_out);
+
+ for (int64_t oz = 0; oz < D_out; oz++) {
+ for (int64_t oy = 0; oy < H_out; oy++) {
+ for (int64_t ox = 0; ox < W_out; ox++) {
+ const T* gridpoint = grid_data + (oz * H_out * W_out + oy * W_out + ox) * 3;
+ T* Y_gridpoint = Y_data + oz * H_out * W_out + oy * W_out + ox;
+ auto nx = gridpoint[0]; // normalized location
+ auto ny = gridpoint[1];
+ auto nz = gridpoint[2];
+ auto x = GsDenormalize(nx, W_in, align_corners_); // actual location
+ auto y = GsDenormalize(ny, H_in, align_corners_);
+ auto z = GsDenormalize(nz, D_in, align_corners_);
+
+ if (mode_ == Nearest) {
+ x = static_cast(std::nearbyint(static_cast(x)));
+ y = static_cast(std::nearbyint(static_cast(y)));
+ z = static_cast(std::nearbyint(static_cast(z)));
+
+ // x, y, z are integers in all padding modes
+ *Y_gridpoint = PixelAtGrid3D(X_data, static_cast(z), static_cast(y), static_cast(x),
+ D_in, H_in, W_in, border);
+ } else if (mode_ == Linear) {
+ int64_t x1 = static_cast(std::floor(x));
+ int64_t y1 = static_cast(std::floor(y));
+ int64_t z1 = static_cast(std::floor(z));
+ int64_t x2 = x1 + 1;
+ int64_t y2 = y1 + 1;
+ int64_t z2 = z1 + 1;
+
+ T dx2 = static_cast(x2) - x;
+ T dx1 = x - static_cast(x1);
+ T dy2 = static_cast(y2) - y;
+ T dy1 = y - static_cast(y1);
+ T dz2 = static_cast(z2) - z;
+ T dz1 = z - static_cast(z1);
+
+ T p111 = PixelAtGrid3D(X_data, z1, y1, x1, D_in, H_in, W_in, border);
+ T p112 = PixelAtGrid3D(X_data, z1, y1, x2, D_in, H_in, W_in, border);
+ T p121 = PixelAtGrid3D(X_data, z1, y2, x1, D_in, H_in, W_in, border);
+ T p122 = PixelAtGrid3D(X_data, z1, y2, x2, D_in, H_in, W_in, border);
+ T Y_gridpoint_z1 = dy2 * (dx2 * p111 + dx1 * p112) + dy1 * (dx2 * p121 + dx1 * p122);
+
+ T p211 = PixelAtGrid3D(X_data, z2, y1, x1, D_in, H_in, W_in, border);
+ T p212 = PixelAtGrid3D(X_data, z2, y1, x2, D_in, H_in, W_in, border);
+ T p221 = PixelAtGrid3D(X_data, z2, y2, x1, D_in, H_in, W_in, border);
+ T p222 = PixelAtGrid3D(X_data, z2, y2, x2, D_in, H_in, W_in, border);
+ T Y_gridpoint_z2 = dy2 * (dx2 * p211 + dx1 * p212) + dy1 * (dx2 * p221 + dx1 * p222);
+ *Y_gridpoint = dz2 * Y_gridpoint_z1 + dz1 * Y_gridpoint_z2;
}
}
- T dx = static_cast(x - x0 - 1);
- T dy = static_cast(y - y0 - 1);
- *Y_gridpoint = GsBicubicInterpolate(p, static_cast(dx), static_cast(dy));
}
}
- }
- });
+ });
+ }
+ } else {
+ // Should be unreachable: the rank checks above only admit 4-D or 5-D input.
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Only support GirdSample in 4-D or 5-D cases.");
}
return Status::OK();
}
diff --git a/onnxruntime/core/providers/cpu/tensor/grid_sample.h b/onnxruntime/core/providers/cpu/tensor/grid_sample.h
index 2dd828b3ae3f1..dee0c4701ee21 100644
--- a/onnxruntime/core/providers/cpu/tensor/grid_sample.h
+++ b/onnxruntime/core/providers/cpu/tensor/grid_sample.h
@@ -15,37 +15,52 @@ template
class GridSample final : public OpKernel {
public:
explicit GridSample(const OpKernelInfo& info) : OpKernel(info) {
- std::string mode_str = info.GetAttrOrDefault("mode", "bilinear");
- std::string padding_mode_str = info.GetAttrOrDefault("padding_mode", "zeros");
- align_corners_ = static_cast(info.GetAttrOrDefault("align_corners", 0));
- ORT_ENFORCE(mode_str == "bilinear" || mode_str == "nearest" || mode_str == "bicubic",
- "mode \"", mode_str, "\" not supported, expect bilinear, nearest or bicubic");
- ORT_ENFORCE(padding_mode_str == "zeros" || padding_mode_str == "border" || padding_mode_str == "reflection",
- "padding_mode \"", padding_mode_str, "\" not supported, expect zeros, border or reflection");
- if (mode_str == "bicubic") {
- mode_ = Bicubic;
- } else if (mode_str == "nearest") {
- mode_ = Nearest;
+ int start_version = info.node().SinceVersion();
+ if (start_version >= 20) {
+ std::string mode_str = info.GetAttrOrDefault("mode", "linear");
+ if (mode_str == "cubic") {
+ mode_ = Cubic;
+ } else if (mode_str == "nearest") {
+ mode_ = Nearest;
+ } else if (mode_str == "linear") {
+ mode_ = Linear;
+ } else {
+ ORT_THROW("mode \"", mode_str, "\" not supported, expect linear, nearest or cubic");
+ }
} else {
- mode_ = Bilinear;
+ std::string mode_str = info.GetAttrOrDefault("mode", "bilinear");
+ if (mode_str == "bicubic") {
+ mode_ = Cubic;
+ } else if (mode_str == "nearest") {
+ mode_ = Nearest;
+ } else if (mode_str == "bilinear") {
+ mode_ = Linear;
+ } else {
+ ORT_THROW("mode \"", mode_str, "\" not supported, expect bilinear, nearest or bicubic");
+ }
}
+
+ std::string padding_mode_str = info.GetAttrOrDefault("padding_mode", "zeros");
+ align_corners_ = static_cast(info.GetAttrOrDefault("align_corners", 0));
if (padding_mode_str == "reflection") {
padding_mode_ = Reflection;
} else if (padding_mode_str == "border") {
padding_mode_ = Border;
- } else {
+ } else if (padding_mode_str == "zeros") {
padding_mode_ = Zeros;
+ } else {
+ ORT_THROW("padding_mode \"", padding_mode_str, "\" not supported, expect zeros, border or reflection");
}
}
Status Compute(OpKernelContext* context) const override;
private:
- enum GridSampleInterpolationMode {
- Bilinear,
+ typedef enum {
+ Linear,
+ Cubic,
Nearest,
- Bicubic
- };
+ } GridSampleInterpolationMode;
enum GridSamplePaddingMode {
Zeros,
@@ -53,9 +68,10 @@ class GridSample final : public OpKernel {
Reflection
};
- T PixelAtGrid(const T* image, int64_t r, int64_t c, int64_t H, int64_t W, float border[/* 4 */]) const;
+ T PixelAtGrid(const T* image, int64_t r, int64_t c, int64_t H, int64_t W, T border[/* 4 */]) const;
+ T PixelAtGrid3D(const T* image, int64_t d, int64_t h, int64_t w, int64_t D, int64_t H, int64_t W, T border[/* 6 */]) const;
- GridSampleInterpolationMode mode_{Bilinear};
+ GridSampleInterpolationMode mode_{Linear};
GridSamplePaddingMode padding_mode_{Zeros};
bool align_corners_{0};
};
diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc
index 47c3798721679..636c0bbfa94e9 100644
--- a/onnxruntime/test/onnx/TestCase.cc
+++ b/onnxruntime/test/onnx/TestCase.cc
@@ -944,8 +944,6 @@ std::unique_ptr> GetBrokenTests(const std::string& provider
{"simple_rnn_batchwise", "type error", {}},
{"mod_float_mixed_sign_example", "fmod attribute must be true for floating point types", {}},
{"col2im_pads", "result mismatch", {"opset18"}},
- {"gridsample_volumetric_nearest_align_corners_0", "result differs", {}},
- {"gridsample_volumetric_nearest_align_corners_1", "result differs", {}},
{"reduce_l1_empty_set", "unknown version", {}},
{"reduce_l1_empty_set_expanded", "unknown version", {}},
{"reduce_l2_empty_set", "unknown version", {}},
@@ -1351,6 +1349,8 @@ std::unique_ptr> GetBrokenTests(const std::string& provider
broken_tests->insert({"sce_sum_log_prob", "result differs"});
broken_tests->insert({"sce_sum_log_prob_expanded", "result differs"});
broken_tests->insert({"gridsample_reflection_padding", "result differs"});
+ broken_tests->insert({"gridsample_volumetric_nearest_align_corners_0", "unknown version"});
+ broken_tests->insert({"gridsample_volumetric_nearest_align_corners_1", "unknown version"});
broken_tests->insert({"spacetodepth", "result differs"});
}
diff --git a/onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py b/onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py
index 22bad6f1be534..7dcd6484a5688 100644
--- a/onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py
+++ b/onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py
@@ -1,3 +1,9 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+# This code is used to generate the test cases for the AffineGrid operator
+# in onnxruntime/test/providers/cpu/tensor/affine_grid_test.cc
+
import argparse
import numpy as np
diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc
new file mode 100644
index 0000000000000..0f097622abff0
--- /dev/null
+++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc
@@ -0,0 +1,1019 @@
+#include "gtest/gtest.h"
+#include "test/providers/provider_test_utils.h"
+#include "test/util/include/default_providers.h"
+
+#include <initializer_list>
+
+namespace onnxruntime {
+namespace test {
+// DO NOT edit following tests. They are generated by:
+// onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py
+TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_align_corners) {  // GridSample: mode=nearest, padding_mode=zeros, align_corners=1
+ OpTester test("GridSample", 16);  // 16 is presumably the opset version (matches the test name) -- confirm against OpTester
+ std::string mode = "nearest";
+ std::string padding_mode = "zeros";
+ int64_t align_corners = 1;
+ std::initializer_list<int64_t> X_shape{2, 2, 3, 2};  // dims of input X
+ std::initializer_list<float> X_data{-1.125840f, -1.152360f, -0.250579f, -0.433879f, 0.848710f, 0.692009f, -0.316013f, -2.115219f, 0.468096f, -0.157712f, 1.443660f, 0.266049f, 0.166455f, 0.874382f, -0.143474f, -0.111609f, 0.931827f, 1.259009f, 2.004981f, 0.053737f, 0.618057f, -0.412802f, -0.841065f, -2.316042f};
+ std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};  // dims of input Grid
+ std::initializer_list<float> Grid_data{0.063110f, -0.615220f, 0.203022f, -1.120434f, -0.867079f, -0.618636f, 0.757125f, 0.703586f, -0.532194f, -0.043299f, 0.767473f, 1.192960f, 0.476259f, 0.162111f, 0.804584f, -0.706563f, 0.223613f, -0.930367f, -0.831703f, -0.619900f, 0.542968f, 0.482592f, -0.710823f, 0.362529f};
+ std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};  // dims of expected output Y
+ std::initializer_list<float> Y_data{-1.152360f, -1.152360f, -1.125840f, 0.692009f, -0.250579f, 0.692009f, -2.115219f, -2.115219f, -0.316013f, 0.266049f, 0.468096f, 0.266049f, -0.111609f, 0.874382f, 0.874382f, 0.166455f, -0.111609f, -0.143474f, -0.412802f, 0.053737f, 0.053737f, 2.004981f, -0.412802f, 0.618057f};
+ test.AddInput("X", X_shape, X_data);
+ test.AddInput("Grid", Grid_shape, Grid_data);
+ test.AddAttribute("mode", mode);
+ test.AddAttribute("padding_mode", padding_mode);
+ test.AddAttribute("align_corners", align_corners);
+ test.AddOutput("Y", Y_shape, Y_data);  // expected values for output Y
+ test.ConfigEp(DefaultCpuExecutionProvider())  // configure the default CPU execution provider, then run
+ .RunWithConfig();
+}
+
+TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners) {  // GridSample: mode=nearest, padding_mode=zeros, align_corners=0
+ OpTester test("GridSample", 16);  // 16 is presumably the opset version (matches the test name) -- confirm against OpTester
+ std::string mode = "nearest";
+ std::string padding_mode = "zeros";
+ int64_t align_corners = 0;
+ std::initializer_list<int64_t> X_shape{2, 2, 3, 2};  // dims of input X
+ std::initializer_list<float> X_data{-0.569248f, 0.919971f, 1.110816f, 1.289874f, -1.478174f, 2.567233f, -0.473120f, 0.335551f, -0.003304f, -0.534441f, 1.168688f, 0.394503f, 1.941462f, 0.791498f, -0.020252f, -0.437170f, -1.535287f, -0.412679f, 0.966303f, 1.624783f, -0.365619f, -1.302440f, 0.099403f, 0.441822f};
+ std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};  // dims of input Grid
+ std::initializer_list<float> Grid_data{-1.143118f, -0.021569f, -0.903671f, -0.925628f, -0.066120f, 0.180174f, -0.491436f, 0.712053f, -0.730247f, 1.088844f, 0.822360f, -1.011940f, -0.298661f, 0.054147f, 0.175081f, 0.284609f, 0.470914f, 0.071880f, -0.585515f, 0.567827f, -1.151099f, -0.711248f, -0.300396f, -0.584536f};
+ std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};  // dims of expected output Y
+ std::initializer_list<float> Y_data{0.000000f, -0.569248f, 1.110816f, -1.478174f, 0.000000f, 0.000000f, 0.000000f, -0.473120f, -0.003304f, 1.168688f, 0.000000f, 0.000000f, -0.020252f, -0.437170f, -0.437170f, -1.535287f, 0.000000f, 1.941462f, -0.365619f, -1.302440f, -1.302440f, 0.099403f, 0.000000f, 0.966303f};
+ test.AddInput("X", X_shape, X_data);
+ test.AddInput("Grid", Grid_shape, Grid_data);
+ test.AddAttribute("mode", mode);
+ test.AddAttribute("padding_mode", padding_mode);
+ test.AddAttribute("align_corners", align_corners);
+ test.AddOutput("Y", Y_shape, Y_data);  // expected values for output Y
+ test.ConfigEp(DefaultCpuExecutionProvider())  // configure the default CPU execution provider, then run
+ .RunWithConfig();
+}
+
+TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_align_corners) {  // GridSample: mode=nearest, padding_mode=border, align_corners=1
+ OpTester test("GridSample", 16);  // 16 is presumably the opset version (matches the test name) -- confirm against OpTester
+ std::string mode = "nearest";
+ std::string padding_mode = "border";
+ int64_t align_corners = 1;
+ std::initializer_list<int64_t> X_shape{2, 2, 3, 2};  // dims of input X
+ std::initializer_list<float> X_data{-0.883376f, -0.418913f, -0.804826f, 0.565610f, 0.610365f, 0.466884f, 1.950657f, -1.063099f, -0.829367f, -1.407257f, 1.626847f, 0.172273f, -1.611502f, -0.479448f, -0.143351f, -0.317295f, 0.573655f, 0.997931f, 0.543609f, 0.078804f, 0.862860f, -0.019490f, 0.991047f, -0.777735f};
+ std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};  // dims of input Grid
+ std::initializer_list<float> Grid_data{-1.080070f, -0.080985f, 1.055303f, -0.489470f, 1.083604f, 0.434584f, -1.082953f, 0.759237f, -0.138473f, -0.535688f, 0.959584f, -0.969714f, 0.128766f, -0.251242f, 0.856935f, 0.334973f, 0.576606f, 0.423791f, -0.288570f, -0.252367f, -0.988898f, 0.650213f, 0.952774f, 0.821070f};
+ std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};  // dims of expected output Y
+ std::initializer_list<float> Y_data{-0.804826f, 0.565610f, 0.565610f, 0.610365f, -0.883376f, -0.418913f, -0.829367f, -1.407257f, -1.407257f, 1.626847f, 1.950657f, -1.063099f, -0.317295f, -0.317295f, -0.317295f, -0.143351f, 0.573655f, 0.997931f, -0.019490f, -0.019490f, -0.019490f, 0.862860f, 0.991047f, -0.777735f};
+ test.AddInput("X", X_shape, X_data);
+ test.AddInput("Grid", Grid_shape, Grid_data);
+ test.AddAttribute("mode", mode);
+ test.AddAttribute("padding_mode", padding_mode);
+ test.AddAttribute("align_corners", align_corners);
+ test.AddOutput("Y", Y_shape, Y_data);  // expected values for output Y
+ test.ConfigEp(DefaultCpuExecutionProvider())  // configure the default CPU execution provider, then run
+ .RunWithConfig();
+}
+
+TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_no_align_corners) {  // GridSample: mode=nearest, padding_mode=border, align_corners=0
+ OpTester test("GridSample", 16);  // 16 is presumably the opset version (matches the test name) -- confirm against OpTester
+ std::string mode = "nearest";
+ std::string padding_mode = "border";
+ int64_t align_corners = 0;
+ std::initializer_list<int64_t> X_shape{2, 2, 3, 2};  // dims of input X
+ std::initializer_list<float> X_data{-0.559630f, 0.533472f, 0.406887f, 0.394587f, 0.171511f, 0.876045f, -0.287087f, 1.021640f, 0.438649f, -0.010704f, 1.338354f, -0.279405f, -0.551834f, -2.889061f, -1.509981f, 1.024115f, 0.195393f, -0.737109f, 1.700101f, 0.346216f, 0.971125f, 1.450250f, -0.051909f, -0.628431f};
+ std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};  // dims of input Grid
+ std::initializer_list<float> Grid_data{0.149807f, 1.074831f, 0.734055f, -0.758657f, 0.538205f, -0.848275f, -0.508590f, 0.352947f, 0.396231f, 0.900274f, -0.386299f, 0.001921f, 0.617788f, -1.160511f, 0.867577f, -0.992307f, 0.016539f, -0.204020f, -0.632008f, 0.158605f, 0.992302f, -0.350783f, -0.712433f, -0.443807f};
+ std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};  // dims of expected output Y
+ std::initializer_list<float> Y_data{0.876045f, 0.533472f, 0.533472f, 0.171511f, 0.876045f, 0.406887f, -0.279405f, 1.021640f, 1.021640f, 1.338354f, -0.279405f, 0.438649f, -2.889061f, -2.889061f, 1.024115f, -1.509981f, -2.889061f, -0.551834f, 0.346216f, 0.346216f, 1.450250f, 0.971125f, 0.346216f, 1.700101f};
+ test.AddInput("X", X_shape, X_data);
+ test.AddInput("Grid", Grid_shape, Grid_data);
+ test.AddAttribute("mode", mode);
+ test.AddAttribute("padding_mode", padding_mode);
+ test.AddAttribute("align_corners", align_corners);
+ test.AddOutput("Y", Y_shape, Y_data);  // expected values for output Y
+ test.ConfigEp(DefaultCpuExecutionProvider())  // configure the default CPU execution provider, then run
+ .RunWithConfig();
+}
+
+TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_align_corners) {  // GridSample: mode=nearest, padding_mode=reflection, align_corners=1
+ OpTester test("GridSample", 16);  // 16 is presumably the opset version (matches the test name) -- confirm against OpTester
+ std::string mode = "nearest";
+ std::string padding_mode = "reflection";
+ int64_t align_corners = 1;
+ std::initializer_list<int64_t> X_shape{2, 2, 3, 2};  // dims of input X
+ std::initializer_list<float> X_data{-0.039373f, -0.801472f, -0.495544f, -0.361514f, 0.585113f, -1.156007f, -0.143365f, -0.194741f, -0.906885f, -0.591838f, 0.150785f, -1.041149f, -0.720534f, -2.214754f, -0.683730f, 0.516358f, 0.792848f, 0.083228f, 0.422800f, -1.868747f, -1.105713f, 0.143731f, 0.583597f, 1.348155f};
+ std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};  // dims of input Grid
+ std::initializer_list<float> Grid_data{0.829854f, -0.893309f, 0.491599f, -0.403504f, -0.578962f, 0.215574f, -0.623348f, 0.276486f, 0.235657f, -0.890987f, 0.199798f, 0.511115f, 0.474997f, -0.151054f, -0.983745f, -0.184985f, 0.416769f, -0.437853f, 0.455497f, 0.799155f, -0.626582f, 0.011834f, 0.496199f, 0.094053f};
+ std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};  // dims of expected output Y
+ std::initializer_list<float> Y_data{-0.801472f, -0.361514f, -0.495544f, -0.495544f, -0.801472f, -1.156007f, -0.194741f, -0.591838f, -0.906885f, -0.906885f, -0.194741f, -1.041149f, 0.516358f, -0.683730f, 0.516358f, 0.083228f, -0.683730f, 0.516358f, 0.143731f, -1.105713f, 0.143731f, 1.348155f, -1.105713f, 0.143731f};
+ test.AddInput("X", X_shape, X_data);
+ test.AddInput("Grid", Grid_shape, Grid_data);
+ test.AddAttribute("mode", mode);
+ test.AddAttribute("padding_mode", padding_mode);
+ test.AddAttribute("align_corners", align_corners);
+ test.AddOutput("Y", Y_shape, Y_data);  // expected values for output Y
+ test.ConfigEp(DefaultCpuExecutionProvider())  // configure the default CPU execution provider, then run
+ .RunWithConfig();
+}
+
+TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_corners) {  // GridSample: mode=nearest, padding_mode=reflection, align_corners=0
+ OpTester test("GridSample", 16);  // 16 is presumably the opset version (matches the test name) -- confirm against OpTester
+ std::string mode = "nearest";
+ std::string padding_mode = "reflection";
+ int64_t align_corners = 0;
+ std::initializer_list<int64_t> X_shape{2, 2, 3, 2};  // dims of input X
+ std::initializer_list<float> X_data{-0.129230f, -0.054595f, 0.408347f, 1.126366f, 1.935057f, 1.007685f, 1.004642f, -0.433520f, -0.562711f, -0.832754f, -1.395545f, -0.399295f, -0.309940f, -0.056062f, 0.517413f, -1.596237f, 0.356960f, -2.297482f, -0.871083f, -1.674028f, 0.563055f, -1.435067f, 0.719400f, -1.370747f};
+ std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};  // dims of input Grid
+ std::initializer_list<float> Grid_data{-0.811910f, -1.183845f, -0.963667f, 0.947364f, 0.649243f, 1.125859f, 0.961345f, -1.071655f, -0.818917f, -0.193899f, -0.779319f, 0.833276f, -0.907209f, -0.585482f, -1.159310f, -0.681295f, 0.986973f, 0.982512f, 0.859005f, 0.926553f, 1.067024f, -0.307276f, 0.528003f, 1.069117f};
+ std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};  // dims of expected output Y
+ std::initializer_list<float> Y_data{-0.129230f, 1.935057f, 1.007685f, -0.054595f, 0.408347f, 1.935057f, 1.004642f, -1.395545f, -0.399295f, -0.433520f, -0.562711f, -1.395545f, -0.309940f, -0.309940f, -2.297482f, -2.297482f, -1.596237f, -2.297482f, -0.871083f, -0.871083f, -1.370747f, -1.370747f, -1.435067f, -1.370747f};
+ test.AddInput("X", X_shape, X_data);
+ test.AddInput("Grid", Grid_shape, Grid_data);
+ test.AddAttribute("mode", mode);
+ test.AddAttribute("padding_mode", padding_mode);
+ test.AddAttribute("align_corners", align_corners);
+ test.AddOutput("Y", Y_shape, Y_data);  // expected values for output Y
+ test.ConfigEp(DefaultCpuExecutionProvider())  // configure the default CPU execution provider, then run
+ .RunWithConfig();
+}
+
+TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) {  // GridSample: mode=bilinear, padding_mode=zeros, align_corners=1
+ OpTester test("GridSample", 16);  // 16 is presumably the opset version (matches the test name) -- confirm against OpTester
+ std::string mode = "bilinear";
+ std::string padding_mode = "zeros";
+ int64_t align_corners = 1;
+ std::initializer_list<int64_t> X_shape{2, 2, 3, 2};  // dims of input X
+ std::initializer_list<float> X_data{0.294201f, 0.797322f, 1.264215f, 0.935492f, 0.545464f, -1.537389f, 0.312439f, 0.740060f, -0.575326f, -1.432532f, -0.666175f, 1.017438f, -2.241368f, 0.437349f, -0.555362f, -0.057943f, 0.658583f, 0.992938f, -0.206548f, -0.244841f, -0.380599f, 1.131112f, -0.090205f, -0.897900f};
+ std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};  // dims of input Grid
+ std::initializer_list<float> Grid_data{0.595248f, -1.096726f, -0.214731f, -0.891773f, -0.512023f, 0.432352f, -0.852156f, 0.446072f, 1.018534f, 0.078706f, -0.799785f, -0.429942f, 0.262037f, -0.914782f, 0.596172f, -1.089444f, -1.153552f, -1.165993f, -0.243436f, 0.806920f, -1.135775f, 0.997425f, -0.480027f, 0.351461f};
+ std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};  // dims of expected output Y
+ std::initializer_list<float> Y_data{0.628229f, 0.561377f, 0.688215f, 0.861459f, 0.733996f, 0.850061f, 0.590307f, 0.329661f, -0.555725f, -0.595435f, -1.228216f, -0.224152f, -0.524667f, -0.094262f, -1.725798f, 0.562584f, 0.610959f, -0.014286f, -0.162194f, -0.215901f, -0.159037f, -0.282404f, -0.084779f, -0.097448f};
+ test.AddInput("X", X_shape, X_data);
+ test.AddInput("Grid", Grid_shape, Grid_data);
+ test.AddAttribute("mode", mode);
+ test.AddAttribute("padding_mode", padding_mode);
+ test.AddAttribute("align_corners", align_corners);
+ test.AddOutput("Y", Y_shape, Y_data);  // expected values for output Y
+ test.ConfigEp(DefaultCpuExecutionProvider())  // configure the default CPU execution provider, then run
+ .RunWithConfig();
+}
+
+TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) {  // GridSample: mode=bilinear, padding_mode=zeros, align_corners=0
+ OpTester test("GridSample", 16);  // 16 is presumably the opset version (matches the test name) -- confirm against OpTester
+ std::string mode = "bilinear";
+ std::string padding_mode = "zeros";
+ int64_t align_corners = 0;
+ std::initializer_list<int64_t> X_shape{2, 2, 3, 2};  // dims of input X
+ std::initializer_list<float> X_data{-1.199109f, -0.025686f, 1.802375f, -1.059653f, 3.402826f, -0.568670f, -0.475489f, 1.743163f, 1.060884f, -0.015953f, 1.275653f, 0.009457f, -0.369450f, 1.218198f, 0.255044f, 0.273993f, 1.404381f, 1.082878f, 0.788966f, -0.137615f, 0.122478f, -1.076701f, -0.650897f, -1.619658f};
+ std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};  // dims of input Grid
+ std::initializer_list<float> Grid_data{0.038587f, -0.371014f, -0.260918f, 0.159481f, 0.594851f, -0.840708f, 1.007133f, -0.130476f, -1.005535f, -0.649269f, 1.061781f, 1.097433f, -1.111536f, 0.846358f, 0.601391f, 0.710302f, 1.015835f, -0.646740f, 0.378931f, 0.491080f, -0.354592f, 0.401584f, -0.345256f, 0.741914f};
+ std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};  // dims of expected output Y
+ std::initializer_list<float> Y_data{-0.199899f, 1.437523f, -0.017180f, -0.422530f, -0.554188f, -0.088180f, 0.613663f, 0.843979f, 1.165913f, 0.161823f, -0.215288f, 0.001466f, 0.398506f, 0.909392f, 0.576145f, 0.897902f, 0.920312f, 1.201733f, -0.184698f, -1.360176f, -0.080218f, -1.352020f, -0.497572f, -0.710420f};
+ test.AddInput("X", X_shape, X_data);
+ test.AddInput("Grid", Grid_shape, Grid_data);
+ test.AddAttribute("mode", mode);
+ test.AddAttribute("padding_mode", padding_mode);
+ test.AddAttribute("align_corners", align_corners);
+ test.AddOutput("Y", Y_shape, Y_data);  // expected values for output Y
+ test.ConfigEp(DefaultCpuExecutionProvider())  // configure the default CPU execution provider, then run
+ .RunWithConfig();
+}
+
+TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) {  // GridSample: mode=bilinear, padding_mode=border, align_corners=1
+ OpTester test("GridSample", 16);  // 16 is presumably the opset version (matches the test name) -- confirm against OpTester
+ std::string mode = "bilinear";
+ std::string padding_mode = "border";
+ int64_t align_corners = 1;
+ std::initializer_list<int64_t> X_shape{2, 2, 3, 2};  // dims of input X
+ std::initializer_list<float> X_data{-0.546073f, -0.630178f, -0.634650f, 0.974665f, 0.209843f, 0.029890f, 1.709235f, -0.725759f, -0.876951f, 0.522287f, 0.462005f, -1.329269f, -0.295974f, 1.371414f, 0.973846f, 0.765543f, -0.403897f, -0.326279f, 0.748218f, -0.195299f, 0.676756f, -0.080633f, 0.158123f, 0.099984f};
+ std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};  // dims of input Grid
+ std::initializer_list<float> Grid_data{1.182462f, -0.759228f, 0.230068f, -0.103567f, -0.252788f, -0.268017f, 0.762529f, 0.057356f, -1.168338f, -0.708432f, -0.409080f, 0.603860f, -0.776560f, 1.131504f, -0.267275f, -0.215474f, 0.940270f, 0.603129f, 1.017745f, 0.694133f, -0.364025f, -0.796167f, -0.089284f, 0.993165f};
+ std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};  // dims of expected output Y
+ std::initializer_list<float> Y_data{-0.243777f, 0.256440f, -0.179228f, 0.741578f, -0.571899f, 0.031558f, -0.425264f, 0.007242f, -0.044977f, 0.271677f, 0.955187f, -0.224230f, -0.395226f, 0.771988f, 0.108104f, 0.007673f, 0.371491f, -0.360026f, 0.151628f, 0.399982f, 0.038327f, 0.044739f, 0.445689f, 0.133017f};
+ test.AddInput("X", X_shape, X_data);
+ test.AddInput("Grid", Grid_shape, Grid_data);
+ test.AddAttribute("mode", mode);
+ test.AddAttribute("padding_mode", padding_mode);
+ test.AddAttribute("align_corners", align_corners);
+ test.AddOutput("Y", Y_shape, Y_data);  // expected values for output Y
+ test.ConfigEp(DefaultCpuExecutionProvider())  // configure the default CPU execution provider, then run
+ .RunWithConfig();
+}
+
+TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) {  // GridSample: mode=bilinear, padding_mode=border, align_corners=0
+ OpTester test("GridSample", 16);  // 16 is presumably the opset version (matches the test name) -- confirm against OpTester
+ std::string mode = "bilinear";
+ std::string padding_mode = "border";
+ int64_t align_corners = 0;
+ std::initializer_list<int64_t> X_shape{2, 2, 3, 2};  // dims of input X
+ std::initializer_list<float> X_data{-0.873307f, 0.004261f, -1.257887f, -1.084466f, 0.752979f, 0.323648f, -0.275010f, 1.305612f, -0.009480f, -0.831312f, -0.556290f, 2.070567f, 0.710039f, -0.146461f, -0.746745f, 0.725842f, 0.403461f, 0.234374f, 0.173281f, 1.724145f, -0.408946f, 0.782749f, -1.520847f, -0.314686f};
+ std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};  // dims of input Grid
+ std::initializer_list<float> Grid_data{0.605180f, 0.169896f, 1.021029f, 0.161312f, -0.555188f, 1.135200f, 0.284017f, -1.170817f, -0.341630f, -0.817401f, 1.052104f, -0.198175f, -1.093830f, -0.075436f, 0.753615f, 0.311761f, 0.379445f, 0.111448f, 0.447382f, -0.292382f, -0.477360f, -1.121650f, -0.904004f, 0.520083f};
+ std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};  // dims of expected output Y
+ std::initializer_list<float> Y_data{-0.725617f, -0.743749f, 0.752979f, -0.185279f, -0.734326f, -0.760828f, -0.091786f, -0.129152f, -0.556290f, 0.964224f, -0.024687f, -0.196084f, -0.581904f, 0.496011f, 0.499240f, 0.319537f, 0.690648f, 0.150559f, -0.343065f, 0.269544f, 0.455333f, 1.124628f, 0.208392f, -1.276367f};
+ test.AddInput("X", X_shape, X_data);
+ test.AddInput("Grid", Grid_shape, Grid_data);
+ test.AddAttribute("mode", mode);
+ test.AddAttribute("padding_mode", padding_mode);
+ test.AddAttribute("align_corners", align_corners);
+ test.AddOutput("Y", Y_shape, Y_data);  // expected values for output Y
+ test.ConfigEp(DefaultCpuExecutionProvider())  // configure the default CPU execution provider, then run
+ .RunWithConfig();
+}
+
+TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corners) {  // GridSample: mode=bilinear, padding_mode=reflection, align_corners=1
+ OpTester test("GridSample", 16);  // 16 is presumably the opset version (matches the test name) -- confirm against OpTester
+ std::string mode = "bilinear";
+ std::string padding_mode = "reflection";
+ int64_t align_corners = 1;
+ std::initializer_list<int64_t> X_shape{2, 2, 3, 2};  // dims of input X
+ std::initializer_list<float> X_data{0.540757f, -0.947807f, 0.202144f, -0.350748f, 0.545005f, 1.541211f, 0.600239f, -0.338015f, -1.080823f, -1.391537f, -0.352570f, 1.560770f, -0.822488f, -2.140920f, 0.099553f, -0.697505f, 0.665352f, -2.256198f, -1.002236f, -1.395144f, 0.415783f, 0.268104f, -0.151752f, 0.794042f};
+ std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};  // dims of input Grid
+ std::initializer_list<float> Grid_data{1.051960f, -0.798975f, -0.129852f, -0.064453f, 0.535452f, 0.820411f, -0.190205f, -0.994177f, 0.594591f, 0.358958f, 0.482039f, -0.740241f, 0.772315f, 1.136586f, 0.104126f, -1.120858f, 0.842388f, -0.889742f, 0.275846f, 0.174381f, -0.561644f, 0.417835f, -1.073319f, 0.273311f};
+ std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};  // dims of expected output Y
+ std::initializer_list<float> Y_data{-0.793997f, -0.042818f, 1.034663f, -0.061725f, 0.327743f, -0.470152f, -0.528701f, -1.125254f, 0.678924f, 0.212033f, -0.430627f, -0.410903f, -1.743740f, -1.404122f, -1.882401f, -0.546577f, -0.033295f, 0.203686f, 0.631537f, -1.031405f, -1.182924f, 0.344248f, 0.246420f, 0.266212f};
+ test.AddInput("X", X_shape, X_data);
+ test.AddInput("Grid", Grid_shape, Grid_data);
+ test.AddAttribute("mode", mode);
+ test.AddAttribute("padding_mode", padding_mode);
+ test.AddAttribute("align_corners", align_corners);
+ test.AddOutput("Y", Y_shape, Y_data);  // expected values for output Y
+ test.ConfigEp(DefaultCpuExecutionProvider())  // configure the default CPU execution provider, then run
+ .RunWithConfig();
+}
+
+TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_corners) {  // GridSample: mode=bilinear, padding_mode=reflection, align_corners=0
+ OpTester test("GridSample", 16);  // 16 is presumably the opset version (matches the test name) -- confirm against OpTester
+ std::string mode = "bilinear";
+ std::string padding_mode = "reflection";
+ int64_t align_corners = 0;
+ std::initializer_list<int64_t> X_shape{2, 2, 3, 2};  // dims of input X
+ std::initializer_list<float> X_data{0.584178f, 1.050431f, 1.285579f, -1.616520f, -0.768962f, -1.220462f, 0.573128f, 0.699197f, -1.654887f, 0.493267f, -0.615042f, 1.311865f, 0.788249f, -1.232951f, 0.454381f, -1.436621f, 0.711631f, 0.554599f, -0.807529f, 1.680131f, 0.597634f, -0.238890f, -0.345997f, 1.770104f};
+ std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};  // dims of input Grid
+ std::initializer_list<float> Grid_data{0.564800f, 1.031186f, 0.795913f, -0.629473f, -0.131544f, -0.377622f, -0.964948f, 0.000496f, 0.902922f, 1.011019f, 0.111961f, 0.272548f, -0.519506f, 0.905811f, -0.499330f, -0.833583f, 0.184792f, 0.719262f, -1.081910f, 1.084761f, 0.431677f, -0.840735f, -0.258489f, 1.041096f};
+ std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};  // dims of expected output Y
+ std::initializer_list<float> Y_data{-1.220462f, 0.901641f, 0.521980f, 1.284051f, -1.220462f, -0.717235f, 1.311865f, 0.687708f, -0.023386f, -1.654114f, 1.311865f, 0.029458f, 0.711631f, 0.786895f, 0.604097f, 0.711631f, -1.094857f, 0.673706f, -0.345997f, -0.805863f, 1.103092f, -0.345997f, 1.510167f, 0.165064f};
+ test.AddInput("X", X_shape, X_data);
+ test.AddInput("Grid", Grid_shape, Grid_data);
+ test.AddAttribute("mode", mode);
+ test.AddAttribute("padding_mode", padding_mode);
+ test.AddAttribute("align_corners", align_corners);
+ test.AddOutput("Y", Y_shape, Y_data);  // expected values for output Y
+ test.ConfigEp(DefaultCpuExecutionProvider())  // configure the default CPU execution provider, then run
+ .RunWithConfig();
+}
+
+TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) {  // GridSample: mode=bicubic, padding_mode=zeros, align_corners=1
+ OpTester test("GridSample", 16);  // 16 is presumably the opset version (matches the test name) -- confirm against OpTester
+ std::string mode = "bicubic";
+ std::string padding_mode = "zeros";
+ int64_t align_corners = 1;
+ std::initializer_list<int64_t> X_shape{2, 2, 3, 2};  // dims of input X
+ std::initializer_list<float> X_data{0.497417f, 0.268522f, 1.476879f, 0.354795f, 1.624709f, 0.593423f, -1.725412f, -0.622016f, -0.466707f, -0.319962f, 0.701868f, 0.494252f, -0.630165f, 0.548236f, 1.042740f, 0.253800f, -2.667303f, 1.379165f, -0.519418f, 0.672783f, -0.005627f, -0.180192f, -0.018395f, 0.998084f};
+ std::initializer_list<int64_t> Grid_shape{2, 3, 2, 2};  // dims of input Grid
+ std::initializer_list<float> Grid_data{0.213755f, 0.141747f, -0.562622f, -0.414594f, 0.325025f, -0.834438f, 0.197995f, 0.519270f, -0.472884f, 0.996769f, -0.078973f, 0.544455f, 1.188368f, -0.366802f, 0.652090f, -0.343235f, -0.175288f, -0.203365f, -0.007455f, -0.453322f, 0.281264f, 0.045216f, 0.760668f, -0.242886f};
+ std::initializer_list<int64_t> Y_shape{2, 2, 3, 2};  // dims of expected output Y
+ std::initializer_list<float> Y_data{1.007407f, 1.068583f, 0.492134f, 1.222040f, 1.576835f, 1.464183f, -0.238652f, -1.242164f, -1.156880f, 0.279082f, 0.744912f, 0.338287f, 0.215322f, 0.388598f, 0.866571f, 0.556826f, 0.608617f, 0.326312f, 0.044527f, -0.028766f, -0.136528f, -0.084880f, -0.121429f, -0.105516f};
+ test.AddInput("X", X_shape, X_data);
+ test.AddInput("Grid", Grid_shape, Grid_data);
+ test.AddAttribute("mode", mode);
+ test.AddAttribute("padding_mode", padding_mode);
+ test.AddAttribute("align_corners", align_corners);
+ test.AddOutput("Y", Y_shape, Y_data);  // expected values for output Y
+ test.ConfigEp(DefaultCpuExecutionProvider())  // configure the default CPU execution provider, then run
+ .RunWithConfig();
+}
+
+TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) {
+ OpTester test("GridSample", 16);
+ std::string mode = "bicubic";
+ std::string padding_mode = "zeros";
+ int64_t align_corners = 0;
+ std::initializer_list X_shape{2, 2, 3, 2};
+ std::initializer_list X_data{-1.065470f, 0.402578f, -0.405242f, -0.583366f, -0.258523f, -0.605559f, -0.188242f, 0.959607f, 1.189619f, -0.179522f, -1.823240f, -0.051351f, -1.636092f, -2.510569f, -1.238273f, -0.929619f, -0.058536f, 0.772879f, 0.468944f, 0.259886f, 0.757624f, -2.041813f, -0.552378f, 0.626977f};
+ std::initializer_list Grid_shape{2, 3, 2, 2};
+ std::initializer_list Grid_data{-1.199809f, 0.061445f, -0.035546f, 0.180524f, 0.919500f, 1.166411f, -0.711939f, -0.074825f, -0.480808f, -1.105975f, -0.873191f, 1.126273f, 0.699673f, 0.644581f, 0.666892f, -0.953375f, 0.126023f, 1.116858f, -0.669703f, 1.067513f, 0.315406f, 0.844252f, -0.514065f, 0.553221f};
+ std::initializer_list