diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 1740140094df2..e60da111afd36 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -138,7 +138,8 @@ Do not modify directly.* |||[7, 8]|**T** = tensor(double), tensor(float)
**T1** = tensor(bool)| |GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| |||[12, 15]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| -|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float)
**T2** = tensor(float)| +|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|20+|**T1** = tensor(double), tensor(float)
**T2** = tensor(double), tensor(float)| +|||[16, 19]|**T1** = tensor(float)
**T2** = tensor(float)| |HammingWindow|*in* size:**T1**
*out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |HannWindow|*in* size:**T1**
*out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |HardSigmoid|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 2ca3b1cdf817e..4553e7ee18913 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -798,7 +798,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 18, If); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, RoiAlign); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, double, RoiAlign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, GridSample); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 19, float, GridSample); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 17, ScatterND); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, string, Where); @@ -960,6 +960,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Sh // Opset 20 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, GridSample); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, GridSample); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, AffineGrid); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, AffineGrid); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN); @@ -2183,8 +2185,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { RoiAlign)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2401,6 +2403,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { // Opset 20 BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/tensor/grid_sample.cc b/onnxruntime/core/providers/cpu/tensor/grid_sample.cc index c58a7d8337114..a83ba378d7f1e 100644 --- a/onnxruntime/core/providers/cpu/tensor/grid_sample.cc +++ b/onnxruntime/core/providers/cpu/tensor/grid_sample.cc @@ -11,17 +11,23 @@ namespace onnxruntime { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_CPU_OPERATOR_TYPED_KERNEL( \ - GridSample, \ - 16, \ - T, \ - KernelDefBuilder() \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ - GridSample); +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(GridSample, kOnnxDomain, 16, 19, T, kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + GridSample); + +#define REGISTER_KERNEL_TYPED_20(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX(GridSample, kOnnxDomain, 20, T, kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + GridSample); REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED_20(float) +REGISTER_KERNEL_TYPED_20(double) // Restore normalized location to actual image location // When align_corners is true: @@ -44,16 +50,15 @@ T GsDenormalize(T n, int64_t length, bool align_corners) { } // Reflect by the near border till within the borders -// Use float for borders to avoid potential issues with integer T template -T GsReflect(T x, float x_min, float x_max) { - float dx = {}; - float fx = static_cast(x); - float range = x_max - x_min; +T GsReflect(T x, T x_min, T x_max) { + T dx = {}; + T fx = static_cast(x); + T range = x_max - x_min; if (fx < x_min) { dx = x_min - fx; int n = static_cast(dx / range); - float r = dx - n * range; + T r = dx - n * range; if (n % 2 == 0) { fx = x_min + r; } else { @@ -62,7 +67,7 @@ T GsReflect(T x, float x_min, float x_max) { } else if (fx > x_max) { dx = fx - x_max; int n = static_cast(dx / range); - float r = dx - n * range; + T r = dx - n * range; if (n % 2 == 0) { fx = x_max - r; } else { @@ -75,9 +80,9 @@ T GsReflect(T x, float x_min, float x_max) { // Calculate cubic convolution interpolation coefficients // ROBERT G. KEYS https://ieeexplore.ieee.org/document/1163711 -// Use float to avoid potential issues with integer T -void GsGetCubicCoeffs(float x, float coeffs[4]) { - constexpr float cubic_alpha = -0.75f; +template +void GsGetCubicCoeffs(T x, T coeffs[4]) { + constexpr T cubic_alpha = -0.75f; x = std::abs(x); coeffs[0] = ((cubic_alpha * (x + 1) - 5 * cubic_alpha) * (x + 1) + 8 * cubic_alpha) * (x + 1) - 4 * cubic_alpha; coeffs[1] = ((cubic_alpha + 2) * x - (cubic_alpha + 3)) * x * x + 1; @@ -86,9 +91,9 @@ void GsGetCubicCoeffs(float x, float coeffs[4]) { } template -T GsBicubicInterpolate(T p[4][4], float x, float y) { - float v[4] = {}; - float coeffs[4] = {}; +T GsBicubicInterpolate(T p[4][4], T x, T y) { + T v[4] = {}; + T coeffs[4] = {}; GsGetCubicCoeffs(x, coeffs); for (int64_t i = 0; i < 4; i++) { v[i] = coeffs[0] * p[i][0] + coeffs[1] * p[i][1] + coeffs[2] * p[i][2] + coeffs[3] * p[i][3]; @@ -98,7 +103,7 @@ T GsBicubicInterpolate(T p[4][4], float x, float y) { } template -T GridSample::PixelAtGrid(const T* image, int64_t r, int64_t c, int64_t H, int64_t W, float border[/* 4 */]) const { +T GridSample::PixelAtGrid(const T* image, int64_t r, int64_t c, int64_t H, int64_t W, T border[/* 4 */]) const { T pixel = {}; // default 0 if (padding_mode_ == Zeros) { if (c >= 0 && c < W && r >= 0 && r < H) { @@ -116,6 +121,27 @@ T GridSample::PixelAtGrid(const T* image, int64_t r, int64_t c, int64_t H, in return pixel; } +template +T GridSample::PixelAtGrid3D(const T* image, int64_t d, int64_t h, int64_t w, int64_t D, int64_t H, int64_t W, T border[/* 6 */]) const { + T pixel = {}; // default 0 + if (padding_mode_ == Zeros) { + if (w >= 0 && w < W && h >= 0 && h < H && d >= 0 && d < D) { + pixel = image[d * H * W + h * W + w]; + } + } else if (padding_mode_ == Border) { + w = std::clamp(w, 0, W - 1); + h = std::clamp(h, 0, H - 1); + d = std::clamp(d, 0, D - 1); + pixel = image[d * H * W + h * W + w]; + } else { // (padding_mode_ == Reflection) + w = static_cast(GsReflect(static_cast(w), border[0], border[3])); + h = static_cast(GsReflect(static_cast(h), border[1], border[4])); + d = static_cast(GsReflect(static_cast(d), border[2], border[5])); + pixel = image[d * H * W + h * W + w]; + } + return pixel; +} + // When grid sampling, padding is applied before interpolation. // For instance, in bilinear mode and zeros padding-mode, pixel p at actual // image location (-0.5, -0.5) @@ -134,113 +160,203 @@ Status GridSample::Compute(OpKernelContext* context) const { const auto& input_dims = input->Shape(); const auto& grid_dims = grid->Shape(); - if (input_dims.NumDimensions() != 4 || grid_dims.NumDimensions() != 4) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Only 4-D tensor is supported"); - } + int64_t data_dims = input_dims.NumDimensions() - 2; + ORT_ENFORCE(static_cast(grid_dims.NumDimensions()) == data_dims + 2, + "grid dimensions must be ", data_dims + 2, "for input dimension of ", data_dims); + + ORT_ENFORCE(grid_dims[grid_dims.NumDimensions() - 1] == data_dims, + "Last dimension of grid: ", grid_dims[grid_dims.NumDimensions() - 1], ", expect ", data_dims); + + ORT_ENFORCE(input_dims.NumDimensions() == 4 || input_dims.NumDimensions() == 5, "Only 4-D or 5-D tensor is supported"); auto N = input_dims[0]; auto C = input_dims[1]; - auto H_in = input_dims[2]; - auto W_in = input_dims[3]; - auto H_out = grid_dims[1]; - auto W_out = grid_dims[2]; ORT_ENFORCE(grid_dims[0] == N, "Grid batch size ", grid_dims[0], " does not match input batch size ", N); - ORT_ENFORCE(grid_dims[3] == 2, "Last dimension of grid: ", grid_dims[3], ", expect 2"); - TensorShape Y_shape = {N, C, H_out, W_out}; - auto& Y = *context->Output(0, Y_shape); - // Return early if the output tensor is going to be of size 0 - if (Y.Shape().Size() == 0) { - return Status::OK(); + if (input_dims.NumDimensions() == 5) { + ORT_ENFORCE(mode_ != Cubic, "Only support GridSample Cubic mode in 4-D cases."); } - // Force float here to avoid possible issue in integer T case - float x_min = -0.5f; - float x_max = W_in - 0.5f; - float y_min = -0.5f; - float y_max = H_in - 0.5f; - - if (align_corners_) { - x_min = 0.f; - x_max = W_in - 1.f; - y_min = 0.f; - y_max = H_in - 1.f; - } - float border[] = {x_min, y_min, x_max, y_max}; // l-t-r-b - - concurrency::ThreadPool* tp = H_out * W_out > 64 ? context->GetOperatorThreadPool() : nullptr; - for (int64_t n = 0; n < N; n++) { - const T* grid_data = grid->Data() + n * (H_out * W_out) * 2; - concurrency::ThreadPool::TrySimpleParallelFor( - tp, onnxruntime::narrow(C), - [&](std::ptrdiff_t c) { - const T* X_data = input->Data() + (n * C + c) * (H_in * W_in); - T* Y_data = Y.MutableData() + (n * C + c) * (H_out * W_out); - - for (int64_t oy = 0; oy < H_out; oy++) { - for (int64_t ox = 0; ox < W_out; ox++) { - const T* gridpoint = grid_data + (oy * W_out + ox) * 2; - T* Y_gridpoint = Y_data + oy * W_out + ox; - auto nx = gridpoint[0]; // normalized location - auto ny = gridpoint[1]; - auto x = GsDenormalize(nx, W_in, align_corners_); // actual location - auto y = GsDenormalize(ny, H_in, align_corners_); - - if (mode_ == Nearest) { - x = static_cast(std::nearbyintf(static_cast(x))); - y = static_cast(std::nearbyintf(static_cast(y))); - } + if (data_dims == 2) { + // sample 2d; + auto H_in = input_dims[2]; + auto W_in = input_dims[3]; + auto H_out = grid_dims[1]; + auto W_out = grid_dims[2]; + TensorShape Y_shape = {N, C, H_out, W_out}; + auto& Y = *context->Output(0, Y_shape); + // Return early if the output tensor is going to be of size 0 + if (Y.Shape().Size() == 0) { + return Status::OK(); + } - if (x < x_min || x > x_max || y < y_min || y > y_max) { // out of bound - if (padding_mode_ == Border) { - // use original border in both align_corner cases - x = std::clamp(x, static_cast(0), static_cast(W_in - 1)); - y = std::clamp(y, static_cast(0), static_cast(H_in - 1)); - } else if (padding_mode_ == Reflection) { - x = GsReflect(x, x_min, x_max); - y = GsReflect(y, y_min, y_max); - } - } // out of bound + T x_min = -0.5f; + T x_max = W_in - 0.5f; + T y_min = -0.5f; + T y_max = H_in - 0.5f; - if (mode_ == Nearest) { - // x, y are integers in all padding modes - *Y_gridpoint = PixelAtGrid(X_data, static_cast(y), static_cast(x), H_in, W_in, border); - continue; - } + if (align_corners_) { + x_min = 0.f; + x_max = W_in - 1.f; + y_min = 0.f; + y_max = H_in - 1.f; + } + T border[] = {x_min, y_min, x_max, y_max}; // l-t-r-b - if (mode_ == Bilinear) { - int64_t x1 = static_cast(std::floor(x)); - int64_t y1 = static_cast(std::floor(y)); - int64_t x2 = x1 + 1; - int64_t y2 = y1 + 1; - - T p11 = PixelAtGrid(X_data, y1, x1, H_in, W_in, border); - T p12 = PixelAtGrid(X_data, y1, x2, H_in, W_in, border); - T p21 = PixelAtGrid(X_data, y2, x1, H_in, W_in, border); - T p22 = PixelAtGrid(X_data, y2, x2, H_in, W_in, border); - - T dx2 = static_cast(x2) - x; - T dx1 = x - static_cast(x1); - T dy2 = static_cast(y2) - y; - T dy1 = y - static_cast(y1); - *Y_gridpoint = dy2 * (dx2 * p11 + dx1 * p12) + dy1 * (dx2 * p21 + dx1 * p22); + concurrency::ThreadPool* tp = H_out * W_out > 64 ? context->GetOperatorThreadPool() : nullptr; + for (int64_t n = 0; n < N; n++) { + const T* grid_data = grid->Data() + n * (H_out * W_out) * 2; + concurrency::ThreadPool::TrySimpleParallelFor( + tp, onnxruntime::narrow(C), + [&](std::ptrdiff_t c) { + const T* X_data = input->Data() + (n * C + c) * (H_in * W_in); + T* Y_data = Y.MutableData() + (n * C + c) * (H_out * W_out); + + for (int64_t oy = 0; oy < H_out; oy++) { + for (int64_t ox = 0; ox < W_out; ox++) { + const T* gridpoint = grid_data + (oy * W_out + ox) * 2; + T* Y_gridpoint = Y_data + oy * W_out + ox; + auto nx = gridpoint[0]; // normalized location + auto ny = gridpoint[1]; + auto x = GsDenormalize(nx, W_in, align_corners_); // actual location + auto y = GsDenormalize(ny, H_in, align_corners_); + + if (mode_ == Nearest) { + x = static_cast(std::nearbyint(static_cast(x))); + y = static_cast(std::nearbyint(static_cast(y))); + // x, y are integers in all padding modes + *Y_gridpoint = PixelAtGrid(X_data, static_cast(y), static_cast(x), H_in, W_in, border); + } else if (mode_ == Linear) { + int64_t x1 = static_cast(std::floor(x)); + int64_t y1 = static_cast(std::floor(y)); + int64_t x2 = x1 + 1; + int64_t y2 = y1 + 1; + + T p11 = PixelAtGrid(X_data, y1, x1, H_in, W_in, border); + T p12 = PixelAtGrid(X_data, y1, x2, H_in, W_in, border); + T p21 = PixelAtGrid(X_data, y2, x1, H_in, W_in, border); + T p22 = PixelAtGrid(X_data, y2, x2, H_in, W_in, border); + + T dx2 = static_cast(x2) - x; + T dx1 = x - static_cast(x1); + T dy2 = static_cast(y2) - y; + T dy1 = y - static_cast(y1); + *Y_gridpoint = dy2 * (dx2 * p11 + dx1 * p12) + dy1 * (dx2 * p21 + dx1 * p22); + } else if (mode_ == Cubic) { + int64_t x0 = static_cast(std::floor(x)) - 1; // top-left corner of the bbox + int64_t y0 = static_cast(std::floor(y)) - 1; + + T p[4][4] = {}; // [H][W] + for (int64_t h = 0; h < 4; h++) { + for (int64_t w = 0; w < 4; w++) { + p[h][w] = PixelAtGrid(X_data, h + y0, w + x0, H_in, W_in, border); + } + } + T dx = static_cast(x - x0 - 1); + T dy = static_cast(y - y0 - 1); + *Y_gridpoint = GsBicubicInterpolate(p, dx, dy); + } } - if (mode_ == Bicubic) { - int64_t x0 = static_cast(std::floor(x)) - 1; // top-left corner of the bbox - int64_t y0 = static_cast(std::floor(y)) - 1; - T p[4][4] = {}; // [H][W] - for (int64_t h = 0; h < 4; h++) { - for (int64_t w = 0; w < 4; w++) { - p[h][w] = PixelAtGrid(X_data, h + y0, w + x0, H_in, W_in, border); + } + }); + } + } else if (data_dims == 3) { + // sample 3d; + auto D_in = input_dims[2]; + auto H_in = input_dims[3]; + auto W_in = input_dims[4]; + auto D_out = grid_dims[1]; + auto H_out = grid_dims[2]; + auto W_out = grid_dims[3]; + TensorShape Y_shape = {N, C, D_out, H_out, W_out}; + auto& Y = *context->Output(0, Y_shape); + // Return early if the output tensor is going to be of size 0 + if (Y.Shape().Size() == 0) { + return Status::OK(); + } + + T x_min = -0.5f; + T x_max = W_in - 0.5f; + T y_min = -0.5f; + T y_max = H_in - 0.5f; + T z_min = -0.5f; + T z_max = D_in - 0.5f; + + if (align_corners_) { + x_min = 0.f; + x_max = W_in - 1.f; + y_min = 0.f; + y_max = H_in - 1.f; + z_min = 0.f; + z_max = D_in - 1.f; + } + T border[] = {x_min, y_min, z_min, x_max, y_max, z_max}; + + concurrency::ThreadPool* tp = D_out * H_out * W_out > 64 ? context->GetOperatorThreadPool() : nullptr; + for (int64_t n = 0; n < N; n++) { + const T* grid_data = grid->Data() + n * (D_out * H_out * W_out) * 3; + concurrency::ThreadPool::TrySimpleParallelFor( + tp, onnxruntime::narrow(C), + [&](std::ptrdiff_t c) { + const T* X_data = input->Data() + (n * C + c) * (D_in * H_in * W_in); + T* Y_data = Y.MutableData() + (n * C + c) * (D_out * H_out * W_out); + + for (int64_t oz = 0; oz < D_out; oz++) { + for (int64_t oy = 0; oy < H_out; oy++) { + for (int64_t ox = 0; ox < W_out; ox++) { + const T* gridpoint = grid_data + (oz * H_out * W_out + oy * W_out + ox) * 3; + T* Y_gridpoint = Y_data + oz * H_out * W_out + oy * W_out + ox; + auto nx = gridpoint[0]; // normalized location + auto ny = gridpoint[1]; + auto nz = gridpoint[2]; + auto x = GsDenormalize(nx, W_in, align_corners_); // actual location + auto y = GsDenormalize(ny, H_in, align_corners_); + auto z = GsDenormalize(nz, D_in, align_corners_); + + if (mode_ == Nearest) { + x = static_cast(std::nearbyint(static_cast(x))); + y = static_cast(std::nearbyint(static_cast(y))); + z = static_cast(std::nearbyint(static_cast(z))); + + // x, y are integers in all padding modes + *Y_gridpoint = PixelAtGrid3D(X_data, static_cast(z), static_cast(y), static_cast(x), + D_in, H_in, W_in, border); + } else if (mode_ == Linear) { + int64_t x1 = static_cast(std::floor(x)); + int64_t y1 = static_cast(std::floor(y)); + int64_t z1 = static_cast(std::floor(z)); + int64_t x2 = x1 + 1; + int64_t y2 = y1 + 1; + int64_t z2 = z1 + 1; + + T dx2 = static_cast(x2) - x; + T dx1 = x - static_cast(x1); + T dy2 = static_cast(y2) - y; + T dy1 = y - static_cast(y1); + T dz2 = static_cast(z2) - z; + T dz1 = z - static_cast(z1); + + T p111 = PixelAtGrid3D(X_data, z1, y1, x1, D_in, H_in, W_in, border); + T p112 = PixelAtGrid3D(X_data, z1, y1, x2, D_in, H_in, W_in, border); + T p121 = PixelAtGrid3D(X_data, z1, y2, x1, D_in, H_in, W_in, border); + T p122 = PixelAtGrid3D(X_data, z1, y2, x2, D_in, H_in, W_in, border); + T Y_gridpoint_z1 = dy2 * (dx2 * p111 + dx1 * p112) + dy1 * (dx2 * p121 + dx1 * p122); + + T p211 = PixelAtGrid3D(X_data, z2, y1, x1, D_in, H_in, W_in, border); + T p212 = PixelAtGrid3D(X_data, z2, y1, x2, D_in, H_in, W_in, border); + T p221 = PixelAtGrid3D(X_data, z2, y2, x1, D_in, H_in, W_in, border); + T p222 = PixelAtGrid3D(X_data, z2, y2, x2, D_in, H_in, W_in, border); + T Y_gridpoint_z2 = dy2 * (dx2 * p211 + dx1 * p212) + dy1 * (dx2 * p221 + dx1 * p222); + *Y_gridpoint = dz2 * Y_gridpoint_z1 + dz1 * Y_gridpoint_z2; } } - T dx = static_cast(x - x0 - 1); - T dy = static_cast(y - y0 - 1); - *Y_gridpoint = GsBicubicInterpolate(p, static_cast(dx), static_cast(dy)); } } - } - }); + }); + } + } else { + // shall not reach here due to above checks + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Only support GirdSample in 4-D or 5-D cases."); } return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/tensor/grid_sample.h b/onnxruntime/core/providers/cpu/tensor/grid_sample.h index 2dd828b3ae3f1..dee0c4701ee21 100644 --- a/onnxruntime/core/providers/cpu/tensor/grid_sample.h +++ b/onnxruntime/core/providers/cpu/tensor/grid_sample.h @@ -15,37 +15,52 @@ template class GridSample final : public OpKernel { public: explicit GridSample(const OpKernelInfo& info) : OpKernel(info) { - std::string mode_str = info.GetAttrOrDefault("mode", "bilinear"); - std::string padding_mode_str = info.GetAttrOrDefault("padding_mode", "zeros"); - align_corners_ = static_cast(info.GetAttrOrDefault("align_corners", 0)); - ORT_ENFORCE(mode_str == "bilinear" || mode_str == "nearest" || mode_str == "bicubic", - "mode \"", mode_str, "\" not supported, expect bilinear, nearest or bicubic"); - ORT_ENFORCE(padding_mode_str == "zeros" || padding_mode_str == "border" || padding_mode_str == "reflection", - "padding_mode \"", padding_mode_str, "\" not supported, expect zeros, border or reflection"); - if (mode_str == "bicubic") { - mode_ = Bicubic; - } else if (mode_str == "nearest") { - mode_ = Nearest; + int start_version = info.node().SinceVersion(); + if (start_version >= 20) { + std::string mode_str = info.GetAttrOrDefault("mode", "linear"); + if (mode_str == "cubic") { + mode_ = Cubic; + } else if (mode_str == "nearest") { + mode_ = Nearest; + } else if (mode_str == "linear") { + mode_ = Linear; + } else { + ORT_THROW("mode \"", mode_str, "\" not supported, expect linear, nearest or cubic"); + } } else { - mode_ = Bilinear; + std::string mode_str = info.GetAttrOrDefault("mode", "bilinear"); + if (mode_str == "bicubic") { + mode_ = Cubic; + } else if (mode_str == "nearest") { + mode_ = Nearest; + } else if (mode_str == "bilinear") { + mode_ = Linear; + } else { + ORT_THROW("mode \"", mode_str, "\" not supported, expect bilinear, nearest or bicubic"); + } } + + std::string padding_mode_str = info.GetAttrOrDefault("padding_mode", "zeros"); + align_corners_ = static_cast(info.GetAttrOrDefault("align_corners", 0)); if (padding_mode_str == "reflection") { padding_mode_ = Reflection; } else if (padding_mode_str == "border") { padding_mode_ = Border; - } else { + } else if (padding_mode_str == "zeros") { padding_mode_ = Zeros; + } else { + ORT_THROW("padding_mode \"", padding_mode_str, "\" not supported, expect zeros, border or reflection"); } } Status Compute(OpKernelContext* context) const override; private: - enum GridSampleInterpolationMode { - Bilinear, + typedef enum { + Linear, + Cubic, Nearest, - Bicubic - }; + } GridSampleInterpolationMode; enum GridSamplePaddingMode { Zeros, @@ -53,9 +68,10 @@ class GridSample final : public OpKernel { Reflection }; - T PixelAtGrid(const T* image, int64_t r, int64_t c, int64_t H, int64_t W, float border[/* 4 */]) const; + T PixelAtGrid(const T* image, int64_t r, int64_t c, int64_t H, int64_t W, T border[/* 4 */]) const; + T PixelAtGrid3D(const T* image, int64_t d, int64_t h, int64_t w, int64_t D, int64_t H, int64_t W, T border[/* 6 */]) const; - GridSampleInterpolationMode mode_{Bilinear}; + GridSampleInterpolationMode mode_{Linear}; GridSamplePaddingMode padding_mode_{Zeros}; bool align_corners_{0}; }; diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index 47c3798721679..636c0bbfa94e9 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -944,8 +944,6 @@ std::unique_ptr> GetBrokenTests(const std::string& provider {"simple_rnn_batchwise", "type error", {}}, {"mod_float_mixed_sign_example", "fmod attribute must be true for floating point types", {}}, {"col2im_pads", "result mismatch", {"opset18"}}, - {"gridsample_volumetric_nearest_align_corners_0", "result differs", {}}, - {"gridsample_volumetric_nearest_align_corners_1", "result differs", {}}, {"reduce_l1_empty_set", "unknown version", {}}, {"reduce_l1_empty_set_expanded", "unknown version", {}}, {"reduce_l2_empty_set", "unknown version", {}}, @@ -1351,6 +1349,8 @@ std::unique_ptr> GetBrokenTests(const std::string& provider broken_tests->insert({"sce_sum_log_prob", "result differs"}); broken_tests->insert({"sce_sum_log_prob_expanded", "result differs"}); broken_tests->insert({"gridsample_reflection_padding", "result differs"}); + broken_tests->insert({"gridsample_volumetric_nearest_align_corners_0", "unknown version"}); + broken_tests->insert({"gridsample_volumetric_nearest_align_corners_1", "unknown version"}); broken_tests->insert({"spacetodepth", "result differs"}); } diff --git a/onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py b/onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py index 22bad6f1be534..7dcd6484a5688 100644 --- a/onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py +++ b/onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py @@ -1,3 +1,9 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +# This code is used to generate the test cases for the AffineGrid operator +# in onnxruntime/test/providers/cpu/tensor/affine_grid_test.cc + import argparse import numpy as np diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc new file mode 100644 index 0000000000000..0f097622abff0 --- /dev/null +++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc @@ -0,0 +1,1019 @@ +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" + +#include + +namespace onnxruntime { +namespace test { +// DO NOT edit following tests. They are generated by: +// onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py +TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_align_corners) { + OpTester test("GridSample", 16); + std::string mode = "nearest"; + std::string padding_mode = "zeros"; + int64_t align_corners = 1; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{-1.125840f, -1.152360f, -0.250579f, -0.433879f, 0.848710f, 0.692009f, -0.316013f, -2.115219f, 0.468096f, -0.157712f, 1.443660f, 0.266049f, 0.166455f, 0.874382f, -0.143474f, -0.111609f, 0.931827f, 1.259009f, 2.004981f, 0.053737f, 0.618057f, -0.412802f, -0.841065f, -2.316042f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{0.063110f, -0.615220f, 0.203022f, -1.120434f, -0.867079f, -0.618636f, 0.757125f, 0.703586f, -0.532194f, -0.043299f, 0.767473f, 1.192960f, 0.476259f, 0.162111f, 0.804584f, -0.706563f, 0.223613f, -0.930367f, -0.831703f, -0.619900f, 0.542968f, 0.482592f, -0.710823f, 0.362529f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{-1.152360f, -1.152360f, -1.125840f, 0.692009f, -0.250579f, 0.692009f, -2.115219f, -2.115219f, -0.316013f, 0.266049f, 0.468096f, 0.266049f, -0.111609f, 0.874382f, 0.874382f, 0.166455f, -0.111609f, -0.143474f, -0.412802f, 0.053737f, 0.053737f, 2.004981f, -0.412802f, 0.618057f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners) { + OpTester test("GridSample", 16); + std::string mode = "nearest"; + std::string padding_mode = "zeros"; + int64_t align_corners = 0; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{-0.569248f, 0.919971f, 1.110816f, 1.289874f, -1.478174f, 2.567233f, -0.473120f, 0.335551f, -0.003304f, -0.534441f, 1.168688f, 0.394503f, 1.941462f, 0.791498f, -0.020252f, -0.437170f, -1.535287f, -0.412679f, 0.966303f, 1.624783f, -0.365619f, -1.302440f, 0.099403f, 0.441822f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{-1.143118f, -0.021569f, -0.903671f, -0.925628f, -0.066120f, 0.180174f, -0.491436f, 0.712053f, -0.730247f, 1.088844f, 0.822360f, -1.011940f, -0.298661f, 0.054147f, 0.175081f, 0.284609f, 0.470914f, 0.071880f, -0.585515f, 0.567827f, -1.151099f, -0.711248f, -0.300396f, -0.584536f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{0.000000f, -0.569248f, 1.110816f, -1.478174f, 0.000000f, 0.000000f, 0.000000f, -0.473120f, -0.003304f, 1.168688f, 0.000000f, 0.000000f, -0.020252f, -0.437170f, -0.437170f, -1.535287f, 0.000000f, 1.941462f, -0.365619f, -1.302440f, -1.302440f, 0.099403f, 0.000000f, 0.966303f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_align_corners) { + OpTester test("GridSample", 16); + std::string mode = "nearest"; + std::string padding_mode = "border"; + int64_t align_corners = 1; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{-0.883376f, -0.418913f, -0.804826f, 0.565610f, 0.610365f, 0.466884f, 1.950657f, -1.063099f, -0.829367f, -1.407257f, 1.626847f, 0.172273f, -1.611502f, -0.479448f, -0.143351f, -0.317295f, 0.573655f, 0.997931f, 0.543609f, 0.078804f, 0.862860f, -0.019490f, 0.991047f, -0.777735f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{-1.080070f, -0.080985f, 1.055303f, -0.489470f, 1.083604f, 0.434584f, -1.082953f, 0.759237f, -0.138473f, -0.535688f, 0.959584f, -0.969714f, 0.128766f, -0.251242f, 0.856935f, 0.334973f, 0.576606f, 0.423791f, -0.288570f, -0.252367f, -0.988898f, 0.650213f, 0.952774f, 0.821070f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{-0.804826f, 0.565610f, 0.565610f, 0.610365f, -0.883376f, -0.418913f, -0.829367f, -1.407257f, -1.407257f, 1.626847f, 1.950657f, -1.063099f, -0.317295f, -0.317295f, -0.317295f, -0.143351f, 0.573655f, 0.997931f, -0.019490f, -0.019490f, -0.019490f, 0.862860f, 0.991047f, -0.777735f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_no_align_corners) { + OpTester test("GridSample", 16); + std::string mode = "nearest"; + std::string padding_mode = "border"; + int64_t align_corners = 0; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{-0.559630f, 0.533472f, 0.406887f, 0.394587f, 0.171511f, 0.876045f, -0.287087f, 1.021640f, 0.438649f, -0.010704f, 1.338354f, -0.279405f, -0.551834f, -2.889061f, -1.509981f, 1.024115f, 0.195393f, -0.737109f, 1.700101f, 0.346216f, 0.971125f, 1.450250f, -0.051909f, -0.628431f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{0.149807f, 1.074831f, 0.734055f, -0.758657f, 0.538205f, -0.848275f, -0.508590f, 0.352947f, 0.396231f, 0.900274f, -0.386299f, 0.001921f, 0.617788f, -1.160511f, 0.867577f, -0.992307f, 0.016539f, -0.204020f, -0.632008f, 0.158605f, 0.992302f, -0.350783f, -0.712433f, -0.443807f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{0.876045f, 0.533472f, 0.533472f, 0.171511f, 0.876045f, 0.406887f, -0.279405f, 1.021640f, 1.021640f, 1.338354f, -0.279405f, 0.438649f, -2.889061f, -2.889061f, 1.024115f, -1.509981f, -2.889061f, -0.551834f, 0.346216f, 0.346216f, 1.450250f, 0.971125f, 0.346216f, 1.700101f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_align_corners) { + OpTester test("GridSample", 16); + std::string mode = "nearest"; + std::string padding_mode = "reflection"; + int64_t align_corners = 1; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{-0.039373f, -0.801472f, -0.495544f, -0.361514f, 0.585113f, -1.156007f, -0.143365f, -0.194741f, -0.906885f, -0.591838f, 0.150785f, -1.041149f, -0.720534f, -2.214754f, -0.683730f, 0.516358f, 0.792848f, 0.083228f, 0.422800f, -1.868747f, -1.105713f, 0.143731f, 0.583597f, 1.348155f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{0.829854f, -0.893309f, 0.491599f, -0.403504f, -0.578962f, 0.215574f, -0.623348f, 0.276486f, 0.235657f, -0.890987f, 0.199798f, 0.511115f, 0.474997f, -0.151054f, -0.983745f, -0.184985f, 0.416769f, -0.437853f, 0.455497f, 0.799155f, -0.626582f, 0.011834f, 0.496199f, 0.094053f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{-0.801472f, -0.361514f, -0.495544f, -0.495544f, -0.801472f, -1.156007f, -0.194741f, -0.591838f, -0.906885f, -0.906885f, -0.194741f, -1.041149f, 0.516358f, -0.683730f, 0.516358f, 0.083228f, -0.683730f, 0.516358f, 0.143731f, -1.105713f, 0.143731f, 1.348155f, -1.105713f, 0.143731f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_corners) { + OpTester test("GridSample", 16); + std::string mode = "nearest"; + std::string padding_mode = "reflection"; + int64_t align_corners = 0; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{-0.129230f, -0.054595f, 0.408347f, 1.126366f, 1.935057f, 1.007685f, 1.004642f, -0.433520f, -0.562711f, -0.832754f, -1.395545f, -0.399295f, -0.309940f, -0.056062f, 0.517413f, -1.596237f, 0.356960f, -2.297482f, -0.871083f, -1.674028f, 0.563055f, -1.435067f, 0.719400f, -1.370747f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{-0.811910f, -1.183845f, -0.963667f, 0.947364f, 0.649243f, 1.125859f, 0.961345f, -1.071655f, -0.818917f, -0.193899f, -0.779319f, 0.833276f, -0.907209f, -0.585482f, -1.159310f, -0.681295f, 0.986973f, 0.982512f, 0.859005f, 0.926553f, 1.067024f, -0.307276f, 0.528003f, 1.069117f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{-0.129230f, 1.935057f, 1.007685f, -0.054595f, 0.408347f, 1.935057f, 1.004642f, -1.395545f, -0.399295f, -0.433520f, -0.562711f, -1.395545f, -0.309940f, -0.309940f, -2.297482f, -2.297482f, -1.596237f, -2.297482f, -0.871083f, -0.871083f, -1.370747f, -1.370747f, -1.435067f, -1.370747f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) { + OpTester test("GridSample", 16); + std::string mode = "bilinear"; + std::string padding_mode = "zeros"; + int64_t align_corners = 1; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{0.294201f, 0.797322f, 1.264215f, 0.935492f, 0.545464f, -1.537389f, 0.312439f, 0.740060f, -0.575326f, -1.432532f, -0.666175f, 1.017438f, -2.241368f, 0.437349f, -0.555362f, -0.057943f, 0.658583f, 0.992938f, -0.206548f, -0.244841f, -0.380599f, 1.131112f, -0.090205f, -0.897900f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{0.595248f, -1.096726f, -0.214731f, -0.891773f, -0.512023f, 0.432352f, -0.852156f, 0.446072f, 1.018534f, 0.078706f, -0.799785f, -0.429942f, 0.262037f, -0.914782f, 0.596172f, -1.089444f, -1.153552f, -1.165993f, -0.243436f, 0.806920f, -1.135775f, 0.997425f, -0.480027f, 0.351461f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{0.628229f, 0.561377f, 0.688215f, 0.861459f, 0.733996f, 0.850061f, 0.590307f, 0.329661f, -0.555725f, -0.595435f, -1.228216f, -0.224152f, -0.524667f, -0.094262f, -1.725798f, 0.562584f, 0.610959f, -0.014286f, -0.162194f, -0.215901f, -0.159037f, -0.282404f, -0.084779f, -0.097448f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) { + OpTester test("GridSample", 16); + std::string mode = "bilinear"; + std::string padding_mode = "zeros"; + int64_t align_corners = 0; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{-1.199109f, -0.025686f, 1.802375f, -1.059653f, 3.402826f, -0.568670f, -0.475489f, 1.743163f, 1.060884f, -0.015953f, 1.275653f, 0.009457f, -0.369450f, 1.218198f, 0.255044f, 0.273993f, 1.404381f, 1.082878f, 0.788966f, -0.137615f, 0.122478f, -1.076701f, -0.650897f, -1.619658f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{0.038587f, -0.371014f, -0.260918f, 0.159481f, 0.594851f, -0.840708f, 1.007133f, -0.130476f, -1.005535f, -0.649269f, 1.061781f, 1.097433f, -1.111536f, 0.846358f, 0.601391f, 0.710302f, 1.015835f, -0.646740f, 0.378931f, 0.491080f, -0.354592f, 0.401584f, -0.345256f, 0.741914f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{-0.199899f, 1.437523f, -0.017180f, -0.422530f, -0.554188f, -0.088180f, 0.613663f, 0.843979f, 1.165913f, 0.161823f, -0.215288f, 0.001466f, 0.398506f, 0.909392f, 0.576145f, 0.897902f, 0.920312f, 1.201733f, -0.184698f, -1.360176f, -0.080218f, -1.352020f, -0.497572f, -0.710420f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) { + OpTester test("GridSample", 16); + std::string mode = "bilinear"; + std::string padding_mode = "border"; + int64_t align_corners = 1; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{-0.546073f, -0.630178f, -0.634650f, 0.974665f, 0.209843f, 0.029890f, 1.709235f, -0.725759f, -0.876951f, 0.522287f, 0.462005f, -1.329269f, -0.295974f, 1.371414f, 0.973846f, 0.765543f, -0.403897f, -0.326279f, 0.748218f, -0.195299f, 0.676756f, -0.080633f, 0.158123f, 0.099984f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{1.182462f, -0.759228f, 0.230068f, -0.103567f, -0.252788f, -0.268017f, 0.762529f, 0.057356f, -1.168338f, -0.708432f, -0.409080f, 0.603860f, -0.776560f, 1.131504f, -0.267275f, -0.215474f, 0.940270f, 0.603129f, 1.017745f, 0.694133f, -0.364025f, -0.796167f, -0.089284f, 0.993165f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{-0.243777f, 0.256440f, -0.179228f, 0.741578f, -0.571899f, 0.031558f, -0.425264f, 0.007242f, -0.044977f, 0.271677f, 0.955187f, -0.224230f, -0.395226f, 0.771988f, 0.108104f, 0.007673f, 0.371491f, -0.360026f, 0.151628f, 0.399982f, 0.038327f, 0.044739f, 0.445689f, 0.133017f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) { + OpTester test("GridSample", 16); + std::string mode = "bilinear"; + std::string padding_mode = "border"; + int64_t align_corners = 0; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{-0.873307f, 0.004261f, -1.257887f, -1.084466f, 0.752979f, 0.323648f, -0.275010f, 1.305612f, -0.009480f, -0.831312f, -0.556290f, 2.070567f, 0.710039f, -0.146461f, -0.746745f, 0.725842f, 0.403461f, 0.234374f, 0.173281f, 1.724145f, -0.408946f, 0.782749f, -1.520847f, -0.314686f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{0.605180f, 0.169896f, 1.021029f, 0.161312f, -0.555188f, 1.135200f, 0.284017f, -1.170817f, -0.341630f, -0.817401f, 1.052104f, -0.198175f, -1.093830f, -0.075436f, 0.753615f, 0.311761f, 0.379445f, 0.111448f, 0.447382f, -0.292382f, -0.477360f, -1.121650f, -0.904004f, 0.520083f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{-0.725617f, -0.743749f, 0.752979f, -0.185279f, -0.734326f, -0.760828f, -0.091786f, -0.129152f, -0.556290f, 0.964224f, -0.024687f, -0.196084f, -0.581904f, 0.496011f, 0.499240f, 0.319537f, 0.690648f, 0.150559f, -0.343065f, 0.269544f, 0.455333f, 1.124628f, 0.208392f, -1.276367f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corners) { + OpTester test("GridSample", 16); + std::string mode = "bilinear"; + std::string padding_mode = "reflection"; + int64_t align_corners = 1; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{0.540757f, -0.947807f, 0.202144f, -0.350748f, 0.545005f, 1.541211f, 0.600239f, -0.338015f, -1.080823f, -1.391537f, -0.352570f, 1.560770f, -0.822488f, -2.140920f, 0.099553f, -0.697505f, 0.665352f, -2.256198f, -1.002236f, -1.395144f, 0.415783f, 0.268104f, -0.151752f, 0.794042f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{1.051960f, -0.798975f, -0.129852f, -0.064453f, 0.535452f, 0.820411f, -0.190205f, -0.994177f, 0.594591f, 0.358958f, 0.482039f, -0.740241f, 0.772315f, 1.136586f, 0.104126f, -1.120858f, 0.842388f, -0.889742f, 0.275846f, 0.174381f, -0.561644f, 0.417835f, -1.073319f, 0.273311f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{-0.793997f, -0.042818f, 1.034663f, -0.061725f, 0.327743f, -0.470152f, -0.528701f, -1.125254f, 0.678924f, 0.212033f, -0.430627f, -0.410903f, -1.743740f, -1.404122f, -1.882401f, -0.546577f, -0.033295f, 0.203686f, 0.631537f, -1.031405f, -1.182924f, 0.344248f, 0.246420f, 0.266212f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_corners) { + OpTester test("GridSample", 16); + std::string mode = "bilinear"; + std::string padding_mode = "reflection"; + int64_t align_corners = 0; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{0.584178f, 1.050431f, 1.285579f, -1.616520f, -0.768962f, -1.220462f, 0.573128f, 0.699197f, -1.654887f, 0.493267f, -0.615042f, 1.311865f, 0.788249f, -1.232951f, 0.454381f, -1.436621f, 0.711631f, 0.554599f, -0.807529f, 1.680131f, 0.597634f, -0.238890f, -0.345997f, 1.770104f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{0.564800f, 1.031186f, 0.795913f, -0.629473f, -0.131544f, -0.377622f, -0.964948f, 0.000496f, 0.902922f, 1.011019f, 0.111961f, 0.272548f, -0.519506f, 0.905811f, -0.499330f, -0.833583f, 0.184792f, 0.719262f, -1.081910f, 1.084761f, 0.431677f, -0.840735f, -0.258489f, 1.041096f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{-1.220462f, 0.901641f, 0.521980f, 1.284051f, -1.220462f, -0.717235f, 1.311865f, 0.687708f, -0.023386f, -1.654114f, 1.311865f, 0.029458f, 0.711631f, 0.786895f, 0.604097f, 0.711631f, -1.094857f, 0.673706f, -0.345997f, -0.805863f, 1.103092f, -0.345997f, 1.510167f, 0.165064f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) { + OpTester test("GridSample", 16); + std::string mode = "bicubic"; + std::string padding_mode = "zeros"; + int64_t align_corners = 1; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{0.497417f, 0.268522f, 1.476879f, 0.354795f, 1.624709f, 0.593423f, -1.725412f, -0.622016f, -0.466707f, -0.319962f, 0.701868f, 0.494252f, -0.630165f, 0.548236f, 1.042740f, 0.253800f, -2.667303f, 1.379165f, -0.519418f, 0.672783f, -0.005627f, -0.180192f, -0.018395f, 0.998084f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{0.213755f, 0.141747f, -0.562622f, -0.414594f, 0.325025f, -0.834438f, 0.197995f, 0.519270f, -0.472884f, 0.996769f, -0.078973f, 0.544455f, 1.188368f, -0.366802f, 0.652090f, -0.343235f, -0.175288f, -0.203365f, -0.007455f, -0.453322f, 0.281264f, 0.045216f, 0.760668f, -0.242886f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{1.007407f, 1.068583f, 0.492134f, 1.222040f, 1.576835f, 1.464183f, -0.238652f, -1.242164f, -1.156880f, 0.279082f, 0.744912f, 0.338287f, 0.215322f, 0.388598f, 0.866571f, 0.556826f, 0.608617f, 0.326312f, 0.044527f, -0.028766f, -0.136528f, -0.084880f, -0.121429f, -0.105516f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) { + OpTester test("GridSample", 16); + std::string mode = "bicubic"; + std::string padding_mode = "zeros"; + int64_t align_corners = 0; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{-1.065470f, 0.402578f, -0.405242f, -0.583366f, -0.258523f, -0.605559f, -0.188242f, 0.959607f, 1.189619f, -0.179522f, -1.823240f, -0.051351f, -1.636092f, -2.510569f, -1.238273f, -0.929619f, -0.058536f, 0.772879f, 0.468944f, 0.259886f, 0.757624f, -2.041813f, -0.552378f, 0.626977f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{-1.199809f, 0.061445f, -0.035546f, 0.180524f, 0.919500f, 1.166411f, -0.711939f, -0.074825f, -0.480808f, -1.105975f, -0.873191f, 1.126273f, 0.699673f, 0.644581f, 0.666892f, -0.953375f, 0.126023f, 1.116858f, -0.669703f, 1.067513f, 0.315406f, 0.844252f, -0.514065f, 0.553221f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{-0.086429f, -0.590424f, -0.090572f, -0.393926f, -0.379182f, -0.031455f, 0.347836f, 0.182097f, 0.050161f, 1.154870f, -0.134312f, -0.509844f, 0.697346f, -1.440179f, 0.264668f, 0.021389f, 0.729883f, -0.236038f, 0.576661f, 0.348301f, 0.149351f, -0.327477f, 0.607344f, -0.405680f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) { + OpTester test("GridSample", 16); + std::string mode = "bicubic"; + std::string padding_mode = "border"; + int64_t align_corners = 1; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{0.203585f, -1.032829f, 1.130481f, -0.570301f, -2.100938f, 0.389922f, 0.087343f, -0.857360f, 1.193520f, -0.019760f, 0.280285f, 1.811013f, 1.838673f, 0.164184f, 1.436009f, 0.167011f, -1.139939f, -0.029833f, -0.009878f, 0.079750f, 0.216590f, -0.265852f, -0.528116f, -0.451915f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{0.797796f, -1.010726f, 0.868577f, -1.132977f, 0.268082f, -0.786042f, -0.476635f, 0.212483f, -0.471816f, -0.189867f, -1.137389f, -1.131448f, 0.464836f, -0.507934f, -0.730068f, -0.473499f, -0.981082f, -0.959280f, 0.718047f, 0.609891f, 0.159844f, -0.655512f, 0.399241f, 0.053910f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{-0.934180f, -1.004565f, -0.467118f, 0.384839f, 0.792549f, 0.188357f, -0.785741f, -0.871727f, -0.372851f, 0.958270f, 0.751528f, 0.046397f, 0.598629f, 1.686400f, 1.817043f, 0.015806f, 0.866266f, 0.480930f, -0.013358f, 0.152904f, -0.001292f, -0.385043f, 0.030959f, -0.152332f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corners) { + OpTester test("GridSample", 16); + std::string mode = "bicubic"; + std::string padding_mode = "border"; + int64_t align_corners = 0; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{-0.427361f, 0.814325f, -1.412076f, -0.099774f, 0.074936f, 0.590322f, 0.398556f, -0.635891f, -1.081747f, -0.330179f, 0.271759f, -1.089819f, -0.746656f, -0.942538f, -1.251568f, -1.730282f, -0.722323f, 0.525964f, -0.436259f, -0.188952f, -0.499550f, 1.502071f, -0.014112f, 1.194050f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{-0.102021f, -0.935855f, -0.007380f, -0.996053f, -0.258157f, 0.695455f, -0.834420f, -0.808862f, -0.293012f, -0.328961f, 0.203145f, 0.199219f, 0.608516f, -0.826657f, -0.084685f, 0.671149f, 1.037966f, -0.087535f, -0.694344f, 0.344955f, 0.683373f, -0.749700f, -0.696352f, 0.530398f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{0.154701f, 0.273277f, 0.226316f, -0.467055f, -0.820643f, -0.311691f, 0.084699f, -0.052970f, 0.001158f, 0.679701f, -0.467804f, -0.607116f, -0.871407f, -0.210613f, -1.860685f, -1.059387f, -0.902250f, -0.918798f, -0.360562f, 0.476049f, 1.499304f, -0.418396f, -0.298854f, -0.235927f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corners) { + OpTester test("GridSample", 16); + std::string mode = "bicubic"; + std::string padding_mode = "reflection"; + int64_t align_corners = 1; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{-1.084082f, -0.128738f, -0.681077f, -1.309896f, 0.660269f, -1.412063f, 1.834581f, 0.456195f, 0.162801f, -0.638266f, 0.897973f, -0.383653f, 0.297945f, 1.809414f, -0.091298f, 1.092744f, -0.102453f, -1.726535f, -0.484632f, 0.712097f, 1.820312f, -0.852073f, -0.341399f, -0.138106f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{-0.501236f, -0.770480f, -0.140656f, -1.129896f, 0.470370f, 0.885106f, 0.288068f, -0.118568f, 0.594968f, -0.761702f, 1.173892f, -1.193212f, -1.149534f, -0.283562f, 0.980213f, 0.120151f, 0.460855f, -0.879608f, 0.437623f, -0.134092f, 0.480988f, 0.847491f, 0.521616f, -0.102077f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{-0.953278f, -0.722872f, -1.065112f, -1.071529f, -0.344328f, -0.233562f, 1.436462f, 1.232983f, -0.181487f, -0.297043f, 0.464837f, 0.396673f, 0.053896f, 0.733510f, 1.541248f, 1.117701f, -1.352406f, 1.131762f, 1.324986f, -0.882173f, 0.469635f, -0.247133f, -0.196824f, -0.393592f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_corners) { + OpTester test("GridSample", 16); + std::string mode = "bicubic"; + std::string padding_mode = "reflection"; + int64_t align_corners = 0; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{-1.122981f, 0.620969f, -0.876394f, -1.774003f, -0.810376f, -1.475962f, 0.667025f, 0.668804f, -0.748346f, 1.422400f, 0.138469f, -0.165945f, 1.266886f, -0.496157f, 0.158060f, 0.488900f, 0.414476f, 0.419527f, 0.238000f, -0.034674f, 0.229435f, 0.234530f, 0.320846f, 0.703888f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{0.471637f, -0.923628f, -0.909401f, 0.684338f, 0.224360f, 1.092855f, -0.320755f, -0.579618f, -0.111056f, 0.006071f, 0.915173f, -1.195296f, -0.085441f, 0.530823f, -0.660820f, -0.609769f, 0.579921f, -1.149822f, 0.284347f, -0.929024f, 0.596474f, -1.026049f, 0.737766f, -1.135959f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{0.998063f, -0.689213f, -1.266024f, -0.870706f, -1.217616f, 1.292693f, 0.543307f, 0.219521f, -0.255151f, 0.543599f, 0.062982f, 0.527696f, 0.387590f, 1.352544f, -0.758053f, -0.262859f, -0.820496f, -0.934255f, 0.434353f, 0.262797f, -0.092283f, -0.021089f, -0.106052f, -0.119717f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "nearest"; + std::string padding_mode = "zeros"; + int64_t align_corners = 1; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{0.404710f, -0.654932f, 0.052124f, 0.340055f, -0.212416f, 1.562917f, -0.907159f, -1.566185f, 0.596746f, 1.002548f, -0.820504f, 0.509186f, 0.951389f, 0.773736f, -2.144711f, 0.044147f, 1.290612f, 0.664926f, 0.530731f, -0.423196f, -0.388699f, 0.333224f, 0.293744f, -0.157543f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{0.528957f, 0.982925f, -0.033286f, -0.806271f, 0.793837f, -0.411498f, 0.621343f, -0.295724f, 0.510113f, 1.079311f, 1.115827f, -1.092078f, -0.793776f, -0.496160f, -0.765241f, 1.151400f, -0.105983f, -0.796009f, -0.533987f, -0.662838f, 0.489587f, -1.046701f, -1.118884f, -1.182913f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{1.562917f, 0.404710f, 0.340055f, 0.340055f, 1.562917f, -0.654932f, 0.509186f, -0.907159f, 1.002548f, 1.002548f, 0.509186f, -1.566185f, -2.144711f, 1.290612f, 0.951389f, 0.951389f, 0.773736f, 0.951389f, -0.388699f, 0.293744f, 0.530731f, 0.530731f, -0.423196f, 0.530731f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "nearest"; + std::string padding_mode = "zeros"; + int64_t align_corners = 1; + std::initializer_list X_shape{2, 2, 3, 3, 2}; + std::initializer_list X_data{-1.495959f, 0.018231f, 0.345600f, 0.031206f, 0.400390f, 0.425763f, 0.839517f, 1.238945f, 0.523906f, -1.658372f, 0.548335f, -1.398321f, -1.976414f, 1.232491f, -0.545575f, -0.069414f, 0.732245f, -0.150333f, -0.707132f, 0.467497f, 0.278677f, 1.335679f, 1.155313f, -0.056298f, 0.430615f, -0.932645f, -1.505319f, 0.103317f, 1.521579f, 0.365497f, 1.428928f, 0.364333f, 1.683777f, 1.010632f, 0.621895f, 2.284701f, 1.574905f, -0.310514f, 1.495724f, 1.003370f, -1.437482f, 0.043097f, -1.645546f, -1.464643f, 0.350139f, -0.105905f, -0.740495f, 1.157691f, 1.443377f, 0.198399f, -1.105180f, -2.037115f, 2.128767f, -0.204457f, 0.468464f, 1.203629f, -0.362309f, -0.130520f, 1.532353f, 1.547599f, -0.831847f, -1.008509f, 0.023218f, 0.342626f, -0.882915f, 0.560640f, -1.142297f, 1.119107f, 0.385787f, -0.068515f, -0.529550f, -0.233903f}; + std::initializer_list Grid_shape{2, 3, 3, 2, 3}; + std::initializer_list Grid_data{0.812645f, 0.528235f, -0.550793f, -0.856977f, -1.073535f, 0.059526f, 1.163856f, -0.227931f, -0.050518f, -0.872033f, 0.368412f, 0.760780f, -1.183099f, -0.844947f, 0.888849f, 0.284117f, -0.074815f, 0.214510f, -0.182450f, -0.838758f, -1.121316f, 0.789250f, -0.142724f, -0.445665f, -0.309738f, -0.654508f, -0.355420f, -1.030097f, 0.898012f, 0.490011f, -0.605186f, -0.409576f, 0.538365f, -0.444367f, 0.316432f, 0.330410f, -0.755392f, 0.300602f, 0.073421f, 1.048061f, -0.434184f, -0.308482f, 1.033921f, -0.979923f, 0.086698f, 1.156203f, -0.538042f, 1.150419f, 1.064809f, 1.116408f, -0.114508f, 1.085560f, -0.522863f, -0.410766f, 0.453879f, 0.253497f, 0.661531f, 1.140383f, -0.751187f, 0.636872f, 0.401477f, 0.633082f, 0.569007f, -0.448884f, -0.948427f, 0.960462f, -0.684283f, 0.767193f, -1.143172f, -0.207603f, 0.012719f, 0.207628f, 0.096998f, 0.378128f, -0.133613f, 0.293885f, 1.187501f, -0.776462f, -0.065516f, -0.458068f, 1.052916f, 1.027248f, -0.032723f, -0.415959f, -0.741439f, 0.858648f, -0.082636f, 1.130172f, 0.684314f, 1.050365f, 0.949108f, -0.779811f, 0.351243f, -0.497591f, 0.602104f, -0.107892f, 0.103884f, -0.829931f, -1.072471f, 0.451888f, 0.278862f, 0.104235f, 0.815033f, -0.501089f, 0.425977f, -0.660914f, 0.248640f, -0.273958f}; + std::initializer_list Y_shape{2, 2, 3, 3, 2}; + std::initializer_list Y_data{0.425763f, 0.839517f, -1.658372f, -0.545575f, -1.976414f, -1.658372f, -1.495959f, -1.658372f, 0.839517f, 0.548335f, -0.545575f, 0.523906f, 0.523906f, -1.658372f, 1.238945f, 1.232491f, -1.398321f, 1.238945f, -0.056298f, 0.430615f, 0.103317f, 1.683777f, 1.428928f, 0.103317f, -0.707132f, 0.103317f, 0.430615f, 1.521579f, 1.683777f, -1.505319f, -1.505319f, 0.103317f, -0.932645f, 0.364333f, 0.365497f, -0.932645f, -2.037115f, 0.198399f, -0.204457f, 1.443377f, -1.437482f, 0.350139f, -0.105905f, 0.043097f, -1.105180f, -0.105905f, -0.740495f, -0.204457f, -1.464643f, -0.740495f, -0.310514f, -0.105905f, -1.464643f, 0.350139f, -0.068515f, 1.119107f, -0.233903f, -1.142297f, 1.532353f, 0.023218f, 0.342626f, 1.547599f, 0.385787f, 0.342626f, -0.882915f, -0.233903f, -1.008509f, -0.882915f, 1.203629f, 0.342626f, -1.008509f, 0.023218f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "nearest"; + std::string padding_mode = "zeros"; + int64_t align_corners = 0; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{-1.948141f, 1.836740f, -0.418393f, -0.125621f, 1.779137f, -0.028049f, 0.367697f, -0.388847f, -0.939514f, -0.129193f, -0.101240f, -3.087570f, -0.778617f, 1.026859f, 0.624162f, 0.291416f, 0.580998f, -0.185200f, 0.333020f, 0.415896f, 0.011702f, 0.014502f, -0.722870f, -0.201041f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{0.818167f, -0.394078f, 0.627076f, -1.124307f, -0.296864f, -0.244061f, -0.423780f, 0.504000f, -0.546789f, -0.139085f, -0.346504f, -1.126900f, -0.198169f, -1.016972f, 0.699725f, 0.641356f, 1.124151f, -0.402963f, 0.061023f, 0.235069f, 1.197862f, 1.099936f, -0.621047f, -1.021083f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{1.836740f, 0.000000f, -0.418393f, 1.779137f, -0.418393f, 0.000000f, -0.388847f, 0.000000f, -0.939514f, -0.101240f, -0.939514f, 0.000000f, 0.000000f, -0.185200f, 0.000000f, 0.291416f, 0.000000f, 0.000000f, 0.000000f, -0.201041f, 0.000000f, 0.014502f, 0.000000f, 0.000000f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "nearest"; + std::string padding_mode = "zeros"; + int64_t align_corners = 0; + std::initializer_list X_shape{2, 2, 3, 3, 2}; + std::initializer_list X_data{0.317302f, 0.629807f, -0.470444f, 0.215051f, 2.234212f, -1.940229f, 0.577203f, -0.166697f, -0.023467f, -0.451050f, -2.199999f, 1.469197f, -1.758133f, -0.570410f, -1.040355f, -0.627640f, 1.398573f, 0.275127f, -0.333592f, -0.677762f, -0.247167f, -0.290725f, -0.986956f, 0.173983f, -0.971920f, 0.225261f, -0.626680f, 1.660835f, 0.972993f, 0.223424f, 2.283593f, -1.145964f, -0.851223f, -2.052948f, -1.351783f, -0.028922f, 0.394421f, 0.057878f, -0.668671f, -0.088841f, 0.560186f, -0.105506f, 0.277478f, 1.047901f, -0.564728f, -0.287761f, 0.653621f, 0.259766f, 1.629452f, -2.337903f, -0.276703f, 0.258084f, -0.552200f, -0.464470f, -0.412042f, -1.047346f, 0.169468f, 1.334588f, 0.580615f, 1.217562f, -2.487876f, -1.218598f, -0.256617f, 1.397251f, 0.694875f, 0.732315f, 0.574448f, 0.673838f, -1.870634f, -0.855206f, 1.068415f, 0.096061f}; + std::initializer_list Grid_shape{2, 3, 3, 2, 3}; + std::initializer_list Grid_data{0.650046f, -0.680891f, -0.200337f, -1.006178f, -0.676990f, 0.500592f, -1.118072f, -0.684288f, 0.899676f, -0.615418f, -0.499387f, -0.336929f, 0.512951f, -0.787164f, 0.120318f, 0.490083f, -0.087112f, 0.216982f, -0.915417f, 0.542519f, 0.448475f, -0.150519f, -0.992244f, 0.479971f, 0.783050f, -0.209890f, 0.565605f, 0.444791f, -0.479961f, -0.083304f, 1.194526f, 0.005665f, -0.955336f, -0.087514f, 0.596991f, -0.391708f, -0.628420f, 0.988534f, 0.634814f, -0.203871f, 0.061307f, -0.126915f, 0.278599f, 0.042647f, -0.726162f, 0.222329f, 0.031386f, 0.077584f, -0.457305f, 0.307467f, -0.970375f, 0.358708f, 0.650272f, -0.132064f, -0.932160f, -0.004362f, 0.001704f, -1.037046f, -0.848754f, 1.109926f, 0.897382f, 0.665044f, 0.831311f, 0.461956f, 0.675346f, 0.794786f, -0.280329f, -0.152546f, 0.855656f, -0.000432f, -0.780824f, -0.930479f, 0.671131f, 0.993983f, 0.931935f, 0.199703f, 0.828337f, -1.101760f, -0.864556f, -1.154677f, 0.966824f, -0.010858f, -0.552558f, 0.406048f, -0.449199f, -0.769613f, 0.462838f, 0.219719f, -0.859342f, -0.790394f, 0.562644f, 0.912452f, 0.097688f, -0.602742f, 0.579449f, 0.209287f, -1.050575f, -0.777654f, 0.262652f, 0.742529f, -0.385517f, 0.580240f, -0.743175f, 1.148320f, 0.855053f, 0.224769f, 0.533871f, 0.417788f}; + std::initializer_list Y_shape{2, 2, 3, 3, 2}; + std::initializer_list Y_data{-0.166697f, 0.000000f, 0.000000f, 0.317302f, -0.166697f, -0.451050f, 1.398573f, -1.758133f, -0.627640f, -0.166697f, 0.000000f, 2.234212f, 1.398573f, -0.023467f, 0.215051f, -0.451050f, -0.470444f, 1.469197f, 0.225261f, 0.000000f, 0.000000f, -0.333592f, 0.225261f, 1.660835f, -1.351783f, 2.283593f, -2.052948f, 0.225261f, 0.000000f, -0.986956f, -1.351783f, -0.626680f, -0.290725f, 1.660835f, -0.247167f, 0.223424f, -0.564728f, 0.000000f, -0.464470f, -0.464470f, -0.276703f, 0.394421f, -0.464470f, 0.000000f, 0.000000f, 1.629452f, 1.629452f, 0.057878f, 0.259766f, 0.653621f, 0.000000f, -2.337903f, 0.000000f, -0.464470f, -0.256617f, 0.000000f, 0.096061f, 0.096061f, -1.870634f, -0.412042f, 0.096061f, 0.000000f, 0.000000f, 0.574448f, 0.574448f, -1.047346f, 0.732315f, 0.694875f, 0.000000f, 0.673838f, 0.000000f, 0.096061f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "nearest"; + std::string padding_mode = "border"; + int64_t align_corners = 1; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{0.660065f, 0.995767f, -0.226389f, 0.590604f, -2.628610f, 0.444899f, 0.023282f, 0.024018f, -0.584701f, 1.988638f, -0.023379f, 0.711650f, -1.062933f, -0.064113f, 1.178346f, -0.652373f, 1.259795f, 1.508661f, -0.079368f, 0.819443f, 0.836356f, -0.362184f, -1.153828f, -0.561180f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{-0.447651f, -0.521958f, 0.673539f, 0.222645f, 1.010165f, 0.451903f, 0.966699f, -0.966970f, 0.964714f, -0.551345f, -0.321222f, 0.007182f, -0.225038f, 0.237367f, 1.069316f, -0.716982f, 0.370785f, -0.964445f, 0.188419f, 0.988574f, 0.809140f, 1.027635f, 0.649589f, -0.099282f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{0.660065f, 0.590604f, 0.590604f, 0.995767f, 0.995767f, -0.226389f, 0.023282f, 1.988638f, 1.988638f, 0.024018f, 0.024018f, -0.584701f, 1.178346f, -0.064113f, -0.064113f, 1.508661f, 1.508661f, -0.652373f, 0.836356f, 0.819443f, 0.819443f, -0.561180f, -0.561180f, -0.362184f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "nearest"; + std::string padding_mode = "border"; + int64_t align_corners = 1; + std::initializer_list X_shape{2, 2, 3, 3, 2}; + std::initializer_list X_data{-0.920922f, -0.560469f, -2.244605f, -0.061799f, 0.523656f, 0.110097f, -0.944521f, 0.818932f, 1.069286f, 0.611457f, -0.355875f, 1.664810f, 0.116694f, 2.318200f, 0.681699f, -0.792880f, -0.025672f, -0.592222f, 0.229768f, -0.521888f, 0.570937f, -0.029345f, -0.873323f, 1.721509f, 2.011626f, -0.310838f, 1.121670f, 0.778967f, -0.450894f, 1.030269f, 0.166967f, -0.244737f, 0.227200f, -0.416612f, -0.276513f, 0.714623f, 0.908783f, -1.393580f, -0.983675f, -0.366833f, 1.473970f, 0.624368f, -0.607720f, -0.523833f, -0.124702f, -0.766457f, -0.131027f, 2.227047f, 1.399269f, 0.053366f, -0.295771f, -0.283811f, 0.019280f, -0.104450f, -0.574185f, -2.130628f, 0.617878f, -1.728151f, -0.272528f, 1.299354f, -1.109310f, -1.881107f, -1.300843f, -0.765376f, -0.477722f, -1.230664f, -0.495792f, 1.061688f, 1.244247f, -0.550821f, -0.520524f, 1.541448f}; + std::initializer_list Grid_shape{2, 3, 3, 2, 3}; + std::initializer_list Grid_data{-1.189605f, -0.312072f, 0.459409f, 1.033285f, -1.083635f, 0.572921f, -1.138649f, -1.147562f, -0.751493f, -0.158500f, 0.335153f, -0.912613f, 0.924528f, 1.085165f, 0.073832f, 0.976781f, -0.543258f, -0.474714f, -0.154854f, 0.131118f, -0.837104f, -0.960885f, 0.474040f, 0.345992f, 1.173923f, -0.489256f, 0.423768f, -0.484246f, 0.592379f, -0.066474f, 0.889570f, 0.666682f, 0.998817f, 0.616675f, 0.045084f, 1.034127f, -0.704858f, 1.131824f, 1.172625f, 1.146321f, -0.560545f, -0.635830f, 0.075922f, 0.373677f, 0.601953f, 0.488043f, 1.021787f, -0.300648f, -0.393688f, 0.402240f, 0.334401f, -0.699993f, 0.116070f, -0.911100f, -0.352043f, -0.470968f, 1.051900f, -1.080208f, -0.708510f, -1.174356f, 0.302647f, -0.923627f, 0.388249f, -0.833533f, -0.768697f, -0.613051f, 0.180083f, 1.102657f, 1.124055f, -0.090660f, -1.175396f, -0.396450f, -0.457333f, -0.255235f, 0.458506f, 0.603882f, 0.532050f, 0.342802f, -0.485794f, -0.012730f, 0.152721f, -0.612948f, -0.107348f, -0.149795f, -1.133775f, 0.813507f, -0.121323f, -1.037352f, 0.949408f, -0.645689f, 0.424853f, 1.190055f, 0.055551f, 0.345244f, 0.476794f, 0.906949f, -0.368187f, -0.675263f, -0.093908f, 0.938461f, 0.103178f, 0.833774f, -0.008922f, 0.368184f, 0.041727f, 0.032575f, -1.141943f, -1.049081f}; + std::initializer_list Y_shape{2, 2, 3, 3, 2}; + std::initializer_list Y_data{1.069286f, 2.318200f, -0.920922f, -2.244605f, 1.664810f, 0.818932f, -2.244605f, 1.069286f, 0.611457f, -0.355875f, -0.592222f, -0.792880f, -0.025672f, -0.560469f, -0.792880f, 1.664810f, 1.069286f, -2.244605f, 1.121670f, -0.244737f, 0.229768f, 0.570937f, 1.030269f, -0.310838f, 0.570937f, 1.121670f, 0.778967f, -0.450894f, 0.714623f, -0.416612f, -0.276513f, -0.521888f, -0.416612f, 1.030269f, 1.121670f, 0.570937f, -0.295771f, 0.908783f, -0.523833f, 0.908783f, -0.104450f, -0.607720f, -0.124702f, 2.227047f, -0.124702f, -0.124702f, -0.131027f, 1.473970f, 2.227047f, -0.283811f, -0.607720f, -0.283811f, -0.124702f, -1.393580f, 1.244247f, -0.574185f, -1.881107f, -0.574185f, 1.541448f, -1.109310f, -1.300843f, -1.230664f, -1.300843f, -1.300843f, -0.477722f, -0.272528f, -1.230664f, -0.550821f, -1.109310f, -0.550821f, -1.300843f, -2.130628f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_no_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "nearest"; + std::string padding_mode = "border"; + int64_t align_corners = 0; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{0.950589f, -1.656624f, 0.767704f, -0.650720f, -1.404308f, -0.531582f, -0.280854f, 0.344309f, -0.959146f, -0.115645f, 0.515696f, -0.114243f, 1.971614f, 0.274268f, 0.543080f, -1.758563f, 1.771011f, 0.934901f, 0.695798f, 1.905137f, 1.598307f, 1.108385f, 0.156008f, 1.290824f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{0.482490f, -0.910951f, -0.001676f, -0.442514f, 0.580438f, 1.039346f, -0.159076f, -0.603960f, -0.922037f, -0.705026f, 0.346468f, 0.275332f, 0.646235f, -0.178307f, 0.616600f, -1.069108f, 0.322583f, 1.164952f, -1.187638f, -0.622953f, 0.768203f, -0.187618f, -0.639652f, 0.732078f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{-1.656624f, 0.950589f, -0.531582f, 0.950589f, 0.950589f, -0.650720f, 0.344309f, -0.280854f, -0.114243f, -0.280854f, -0.280854f, -0.115645f, -1.758563f, 0.274268f, 0.934901f, 1.971614f, -1.758563f, 1.771011f, 1.108385f, 1.905137f, 1.290824f, 0.695798f, 1.108385f, 0.156008f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_no_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "nearest"; + std::string padding_mode = "border"; + int64_t align_corners = 0; + std::initializer_list X_shape{2, 2, 3, 3, 2}; + std::initializer_list X_data{0.465448f, -0.337086f, -0.870849f, -0.389573f, -0.083941f, 1.306894f, 0.719508f, -0.203690f, -1.143864f, 1.163003f, 0.312170f, -2.008687f, 1.731257f, -0.270431f, 1.095352f, -1.673520f, 0.492743f, 0.521962f, -1.938783f, -0.186813f, -0.836257f, -1.835450f, 0.476500f, -0.123386f, 0.246604f, 1.374159f, -0.158435f, 1.268192f, -0.704226f, -0.195314f, -0.277259f, 0.582961f, -0.340940f, 0.192264f, 0.463124f, -2.719402f, -0.593470f, -1.165777f, 0.566071f, 1.622836f, -0.886798f, 1.874877f, -0.849095f, 0.550185f, 0.604298f, 0.073976f, -0.800372f, -0.097283f, -1.576251f, -0.633278f, -1.776745f, -0.827586f, 0.665697f, 0.884698f, 0.467112f, -0.645219f, -0.510110f, 0.032418f, -1.056009f, -0.206175f, -0.173385f, 0.947787f, 1.937234f, 0.615880f, -0.311580f, 0.770921f, -0.841602f, 1.796220f, 0.479491f, 1.609346f, 1.113868f, -0.453360f}; + std::initializer_list Grid_shape{2, 3, 3, 2, 3}; + std::initializer_list Grid_data{-0.151540f, -0.033291f, -0.597203f, 0.836404f, -0.686848f, -0.485355f, -0.936738f, -1.009057f, 1.065352f, -0.926635f, -0.165670f, -0.347352f, 0.439545f, 0.320963f, -0.919909f, 1.077689f, -1.195359f, 0.118687f, -0.100253f, -0.278089f, 0.817760f, 1.013180f, 0.156316f, -0.423839f, 0.892139f, 0.753924f, 0.215530f, -0.328214f, 0.050592f, 1.069553f, 0.130134f, -0.236478f, -1.015986f, -0.643059f, 0.866682f, -0.042256f, -0.079912f, 0.467233f, -0.789513f, -0.081063f, -0.337505f, 0.627865f, 0.976589f, 0.753489f, 0.894667f, -1.072442f, -0.426020f, 0.142099f, -1.019226f, 0.325527f, -0.786578f, 0.514215f, 0.971223f, -1.026539f, 1.005531f, 0.559922f, -0.791906f, 1.148613f, -1.039306f, -0.807864f, -0.596935f, -0.060766f, 0.215484f, -0.352165f, -1.137417f, -0.138518f, 0.910459f, 0.923925f, 0.600710f, 0.174227f, 0.298169f, -0.925092f, 0.485927f, -1.194283f, -0.495564f, -0.315357f, 0.881199f, -0.034981f, -0.546611f, 0.209651f, -0.995724f, -0.317709f, 0.332343f, -0.079474f, -0.126024f, 0.733410f, -0.911554f, -0.605911f, 1.161566f, 0.238787f, -0.194293f, 0.621583f, 0.721901f, -0.200521f, -0.499850f, -0.196149f, 0.435730f, -0.153196f, 0.698401f, -0.978582f, -0.588758f, 0.914808f, 0.157427f, 0.241646f, 0.394674f, -0.283552f, -0.479889f, 0.344261f}; + std::initializer_list Y_shape{2, 2, 3, 3, 2}; + std::initializer_list Y_data{-0.870849f, -0.337086f, 1.731257f, -0.870849f, -0.389573f, -0.203690f, 1.095352f, -0.389573f, -2.008687f, 1.095352f, -0.389573f, 0.312170f, -0.083941f, 1.731257f, 0.521962f, 0.719508f, -0.870849f, 1.306894f, -0.836257f, -0.186813f, -0.277259f, -0.836257f, -1.835450f, 1.374159f, -0.340940f, -1.835450f, -0.195314f, -0.340940f, -1.835450f, -0.704226f, 0.476500f, -0.277259f, -2.719402f, 0.246604f, -0.836257f, -0.123386f, 1.874877f, -1.165777f, 0.604298f, -0.849095f, 0.884698f, 1.622836f, -1.165777f, -0.800372f, 0.566071f, 0.604298f, -0.886798f, -0.800372f, 0.665697f, -0.849095f, -0.827586f, -1.576251f, -0.827586f, -1.576251f, -0.206175f, -0.645219f, 1.937234f, -0.173385f, -0.453360f, 0.032418f, -0.645219f, -0.311580f, -0.510110f, 1.937234f, -1.056009f, -0.311580f, 1.113868f, -0.173385f, 1.609346f, -0.841602f, 1.609346f, -0.841602f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "nearest"; + std::string padding_mode = "reflection"; + int64_t align_corners = 1; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{0.079043f, 0.407494f, 1.038992f, -0.437542f, 0.991216f, 0.409636f, 1.050403f, -0.687172f, -2.021689f, 0.789633f, 0.538178f, 0.414847f, 2.221617f, -0.254833f, -0.179968f, -0.952356f, -1.213159f, 0.499103f, -0.374865f, 0.441938f, -0.114847f, 0.716887f, 1.059090f, 0.438870f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{0.355147f, -0.222342f, -1.197658f, 0.844060f, 1.188586f, 0.605435f, 1.174232f, 0.327060f, -0.094032f, -0.955794f, -1.048806f, -0.826196f, -0.304468f, 0.698768f, -0.495101f, -0.046607f, -0.016936f, -0.784415f, -0.032484f, 1.158664f, 0.959105f, 0.913943f, -0.118352f, 0.021282f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{-0.437542f, 0.991216f, 0.409636f, -0.437542f, 0.079043f, 0.079043f, 0.789633f, 0.538178f, 0.414847f, 0.789633f, 1.050403f, 1.050403f, -1.213159f, -0.179968f, 2.221617f, -1.213159f, 0.499103f, -0.179968f, 1.059090f, -0.114847f, -0.374865f, 1.059090f, 0.438870f, -0.114847f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "nearest"; + std::string padding_mode = "reflection"; + int64_t align_corners = 1; + std::initializer_list X_shape{2, 2, 3, 3, 2}; + std::initializer_list X_data{0.189379f, 0.825309f, -0.701365f, 0.787800f, -1.102514f, 0.126954f, 1.824453f, -0.144635f, -1.712534f, 0.361739f, -0.462516f, -2.153102f, 0.536963f, 0.581639f, -1.325014f, -1.314673f, -0.524797f, -1.304159f, -1.093757f, -1.703444f, -0.672976f, 0.505303f, 1.497654f, -0.545441f, -1.334648f, 0.474489f, 0.484384f, 0.434399f, -0.733471f, 0.452991f, 0.324606f, -1.307459f, -0.640603f, -0.450100f, 0.772854f, 1.281813f, -0.481714f, 1.224667f, -0.437546f, 0.371986f, -0.320368f, -1.011020f, -1.199298f, 0.213302f, 1.795444f, 0.409271f, 1.328065f, -1.037527f, 0.224494f, 0.217863f, -0.925740f, 0.344755f, -1.445667f, -0.935542f, -0.427280f, -2.010803f, -1.174929f, 1.434105f, -1.168630f, 0.321896f, -0.561974f, -0.209305f, -1.063838f, 1.451708f, 0.266913f, -0.132535f, 0.798299f, 0.619547f, -0.324459f, 0.255630f, 0.488773f, -0.142060f}; + std::initializer_list Grid_shape{2, 3, 3, 2, 3}; + std::initializer_list Grid_data{-0.034431f, 1.048250f, 0.160255f, -0.446426f, 0.879791f, -0.683555f, 0.039704f, 0.269729f, 0.538601f, -1.107191f, 0.058867f, -0.310704f, 0.778040f, 0.403733f, 0.480956f, 0.721512f, -0.268657f, -0.076883f, 0.962704f, -0.967187f, -0.829464f, 0.087786f, -0.475353f, 0.068725f, 1.060032f, -0.139108f, -1.023162f, -0.545493f, 1.102040f, -0.263627f, -0.526173f, 0.540152f, 0.148556f, -1.058015f, 0.999344f, 0.675750f, 1.043022f, 0.525119f, -0.404585f, -0.391737f, 0.581547f, -0.232625f, 0.235264f, -1.162786f, -0.593187f, 0.445737f, -0.059159f, -0.576901f, -1.046721f, 0.762672f, -0.241271f, -1.179040f, 1.157741f, 0.583952f, -0.717767f, -0.875798f, 1.159575f, 0.005010f, -0.721707f, 0.690536f, -0.249959f, 0.082204f, -0.625120f, -1.016394f, -0.796947f, -0.131764f, -0.868737f, 1.182731f, 0.012988f, -0.459398f, 0.474264f, -1.063883f, -0.613791f, 0.450721f, -1.019595f, 0.598084f, 0.100866f, -1.000569f, -1.190919f, 0.379261f, 0.567202f, -0.239888f, -1.061107f, -0.691616f, 0.127540f, 0.043657f, 0.307172f, 0.212184f, -0.062900f, 0.633272f, 1.164016f, 0.999377f, 1.090411f, -0.405004f, -0.409578f, -0.132722f, 0.354671f, 0.485734f, -0.106963f, -0.775112f, -0.905400f, 1.155262f, -0.322627f, -0.162203f, -0.735432f, -0.594912f, 0.263568f, 0.505424f}; + std::initializer_list Y_shape{2, 2, 3, 3, 2}; + std::initializer_list Y_data{-0.462516f, -1.102514f, -1.314673f, -1.712534f, 0.361739f, 0.361739f, 0.825309f, 0.361739f, 0.787800f, -0.462516f, -0.462516f, -0.524797f, -2.153102f, -0.462516f, 0.825309f, 0.787800f, -0.462516f, -0.524797f, -0.733471f, 1.497654f, -0.450100f, 0.484384f, 0.434399f, 0.434399f, -1.703444f, 0.434399f, 0.505303f, -0.733471f, -0.733471f, 0.772854f, 0.452991f, -0.733471f, -1.703444f, 0.505303f, -0.733471f, 0.772854f, 0.224494f, 0.217863f, -0.437546f, -1.199298f, 1.328065f, -0.437546f, -0.437546f, 0.371986f, -0.925740f, -0.481714f, 0.409271f, 0.344755f, -0.935542f, 1.795444f, 0.409271f, 0.224494f, -0.437546f, -0.925740f, 0.798299f, 0.619547f, -1.174929f, -0.561974f, 0.266913f, -1.174929f, -1.174929f, 1.434105f, -0.324459f, -0.427280f, 1.451708f, 0.255630f, -0.142060f, -1.063838f, 1.451708f, 0.798299f, -1.174929f, -0.324459f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "nearest"; + std::string padding_mode = "reflection"; + int64_t align_corners = 0; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{-0.769854f, -0.805659f, 0.813652f, -0.010183f, 0.276463f, -0.771678f, -2.563015f, -1.243904f, 2.365071f, 0.730651f, -0.068795f, -1.495438f, 0.211578f, -1.042373f, 0.884036f, -0.746288f, 1.011368f, 0.194463f, -0.307214f, 0.556053f, 0.629364f, 0.083601f, 0.248627f, -0.822453f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{0.569884f, 1.163780f, -0.977608f, -0.145509f, 0.651234f, 1.099753f, -0.853766f, 0.509955f, 0.495437f, 0.723445f, -0.827299f, 0.856340f, -0.522676f, -0.738659f, 0.238269f, 1.016568f, -0.794666f, 0.640690f, -0.137431f, 0.383085f, 0.936085f, 0.325824f, -0.996188f, -0.361291f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{-0.771678f, 0.813652f, -0.771678f, 0.276463f, -0.771678f, 0.276463f, -1.495438f, 2.365071f, -1.495438f, -0.068795f, -1.495438f, -0.068795f, 0.211578f, 0.194463f, 1.011368f, 1.011368f, -0.746288f, 0.211578f, -0.307214f, -0.822453f, 0.248627f, 0.248627f, 0.083601f, -0.307214f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "nearest"; + std::string padding_mode = "reflection"; + int64_t align_corners = 0; + std::initializer_list X_shape{2, 2, 3, 3, 2}; + std::initializer_list X_data{-0.185898f, 0.403325f, 0.737314f, 0.545995f, -1.010481f, -1.204522f, -0.147342f, 0.232425f, -1.339485f, 0.013892f, -1.098319f, 0.478079f, 0.051159f, -0.906061f, -0.428560f, 0.583460f, 1.137472f, 1.487881f, 1.349931f, -0.118774f, 0.436410f, 1.334689f, -1.115846f, 0.159820f, 0.617671f, 0.546630f, 1.861115f, 0.500044f, 0.623446f, 0.541840f, -0.279259f, -0.573875f, 0.783115f, -1.125017f, -1.166457f, -0.827232f, 0.273074f, 0.702953f, 1.288608f, -1.037043f, 0.021860f, 0.575628f, -0.034170f, 1.400741f, 0.508057f, 0.994702f, -2.267981f, 1.677437f, 0.175134f, 0.712679f, -0.440408f, -1.248550f, 1.618839f, -0.214598f, 0.486398f, -0.478466f, 0.912471f, 0.499651f, -0.886606f, -0.929524f, 0.449260f, 0.017969f, -0.050906f, 1.799695f, -0.033007f, -1.884108f, -1.392415f, -0.852990f, -0.052969f, 0.819434f, 0.089723f, 0.598047f}; + std::initializer_list Grid_shape{2, 3, 3, 2, 3}; + std::initializer_list Grid_data{-0.118828f, 0.082315f, 0.328488f, -0.834821f, -0.138863f, -0.988801f, -0.976128f, 0.156412f, -1.171383f, 0.319534f, -1.105438f, -0.834991f, -0.248995f, -1.145138f, 0.969159f, 0.983228f, -0.626795f, 0.251376f, 0.613890f, 0.381328f, -0.160747f, -1.131853f, 0.872567f, -1.052516f, -0.222240f, 0.074438f, -0.395210f, -0.438906f, -1.037125f, 0.066119f, -0.136254f, 1.046163f, -0.395065f, 0.927498f, 0.056808f, -0.539139f, -0.285382f, -0.136177f, 0.012430f, -0.197703f, 0.356128f, 0.988219f, 0.188620f, 0.434655f, 0.741024f, 0.258662f, 0.553165f, 0.629461f, 1.123216f, -1.095185f, 0.410630f, -0.054374f, -0.215508f, -0.462650f, 0.721441f, 1.097745f, -0.979308f, 0.648336f, 0.827460f, 0.209729f, 0.014136f, 0.923431f, 0.035578f, -0.299309f, -0.088614f, 0.385002f, 0.300407f, -0.064744f, 0.378800f, 0.323185f, -0.972071f, 0.299012f, 0.734213f, 0.137618f, -0.109532f, 0.919238f, -1.048417f, -0.547724f, -0.542389f, 1.036863f, -1.160666f, 0.119013f, -1.162427f, -0.039461f, 0.447285f, -0.280625f, 1.164882f, 0.003820f, -0.611796f, 0.309439f, 0.624077f, -0.002384f, 1.026569f, -0.759499f, 0.512014f, 0.681403f, 0.596030f, -0.000440f, 0.342557f, -0.941414f, -0.941707f, -0.074588f, -0.150400f, 0.891031f, 0.871352f, 0.813657f, -0.549640f, -0.942044f}; + std::initializer_list Y_shape{2, 2, 3, 3, 2}; + std::initializer_list Y_data{-1.339485f, 0.737314f, 0.737314f, 0.403325f, 0.051159f, 0.232425f, 0.478079f, -1.010481f, 0.737314f, -0.147342f, -1.010481f, 0.545995f, -1.339485f, 1.137472f, 1.487881f, 1.487881f, -0.906061f, 0.737314f, 1.861115f, 0.436410f, 0.436410f, -0.118774f, -0.279259f, 0.546630f, 0.541840f, -1.115846f, 0.436410f, 0.617671f, -1.115846f, 1.334689f, 1.861115f, -1.166457f, -0.827232f, -0.827232f, -0.573875f, 0.436410f, 0.575628f, 1.677437f, 1.677437f, -0.440408f, -1.248550f, 1.400741f, 0.994702f, 0.702953f, 0.021860f, 1.400741f, -1.248550f, 1.400741f, -1.248550f, 1.618839f, -1.248550f, -0.034170f, 1.618839f, 0.702953f, -0.929524f, -1.884108f, -1.884108f, -0.052969f, 0.819434f, 0.017969f, 1.799695f, -0.478466f, -0.886606f, 0.017969f, 0.819434f, 0.017969f, 0.819434f, 0.089723f, 0.819434f, 0.449260f, 0.089723f, -0.478466f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "linear"; + std::string padding_mode = "zeros"; + int64_t align_corners = 1; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{0.010274f, 1.493496f, -0.264303f, 0.035897f, -0.751962f, -0.370195f, -0.514836f, 0.399928f, -0.191651f, -0.239505f, -1.931184f, -1.074773f, -0.121908f, 0.050673f, -0.741501f, -0.229127f, -0.360925f, 0.264077f, 1.537180f, 1.603202f, -1.241810f, -0.388456f, -0.609742f, 0.095097f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{-0.118589f, -0.020968f, -0.893597f, 1.170924f, -0.517539f, 0.698168f, -0.672718f, 0.008056f, 0.410793f, -1.101817f, 0.550440f, -0.918534f, 0.167456f, -0.237959f, 0.687868f, 1.166281f, 0.270439f, -0.034265f, -0.594534f, 0.447403f, -0.577587f, 0.495680f, -0.520113f, 0.813977f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{-0.115313f, -0.606595f, -0.518616f, -0.218999f, 0.948961f, 1.063015f, -0.210622f, -1.563324f, -1.265386f, -0.212304f, 0.117155f, 0.159843f, -0.342175f, 0.138844f, -0.402196f, -0.457139f, -0.432849f, -0.286783f, -0.191760f, -0.012426f, -0.621658f, -0.799488f, -0.763820f, -0.551571f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "linear"; + std::string padding_mode = "zeros"; + int64_t align_corners = 1; + std::initializer_list X_shape{2, 2, 3, 3, 2}; + std::initializer_list X_data{-1.787070f, -0.894227f, -0.113069f, 0.713917f, 0.041566f, -1.847208f, 0.013441f, -1.439041f, 1.051864f, 1.576791f, 1.180527f, -1.457019f, 0.298446f, 1.142738f, -0.961347f, -0.471509f, -0.074154f, 0.047739f, -0.679950f, -2.306940f, -0.552171f, -0.357144f, -0.492247f, -0.455872f, 0.399680f, 0.057915f, -0.362704f, 1.083763f, -0.084941f, -1.691393f, -1.913178f, 0.696366f, 1.172833f, 0.901506f, -1.189840f, -1.197158f, 0.007338f, 0.161468f, -1.048452f, -0.480832f, 0.391235f, 1.056413f, -0.116648f, 0.632195f, 0.840261f, -2.187738f, 0.302910f, -0.956190f, -0.362645f, 0.771747f, 0.524840f, -0.954672f, -1.084612f, -0.525794f, -0.969691f, -1.056405f, -0.364709f, 0.336189f, -0.178281f, 1.015025f, -0.532580f, 0.036602f, -0.434395f, -1.208987f, -1.084039f, 0.642844f, -0.819208f, -0.982898f, -0.109210f, -1.231957f, 1.083089f, -0.870451f}; + std::initializer_list Grid_shape{2, 3, 3, 2, 3}; + std::initializer_list Grid_data{0.350638f, -0.554259f, 0.740901f, -1.134597f, -0.450763f, -0.706065f, -0.712365f, -0.727142f, -1.130749f, 0.205940f, -0.237380f, -1.010413f, -0.000494f, -0.199898f, 0.495032f, -0.939943f, -0.337590f, 0.247001f, 0.508664f, 0.090780f, 0.325198f, 1.199561f, -0.415694f, 0.817854f, 1.033666f, -1.061540f, 0.290273f, 0.679739f, -0.187185f, 0.662278f, 0.040817f, 0.913540f, 0.025838f, -0.768267f, 0.911326f, 0.356885f, 1.020923f, 0.297892f, 0.637209f, 0.748214f, 0.202064f, -0.278959f, 0.247841f, -0.836700f, 0.040996f, -0.385697f, 0.075869f, -0.950110f, 0.733227f, -1.107135f, 0.513890f, 0.790272f, -1.099795f, 1.084212f, -0.892061f, -0.235640f, 0.621837f, -0.380523f, 1.069422f, -0.529383f, -0.160661f, -0.784422f, -0.556715f, 1.171015f, 0.902476f, 0.088357f, 0.098667f, -1.018314f, 0.905937f, -0.179914f, -0.500513f, -0.954987f, 0.986618f, 0.569025f, 0.722795f, 0.124254f, -0.814285f, 0.491561f, 0.138395f, 0.402690f, -0.298810f, -0.566298f, 0.985118f, 0.402260f, -0.487031f, 0.107159f, -0.260850f, -0.102620f, 0.672911f, -0.955102f, 1.086040f, 0.807667f, 0.001031f, -0.490841f, 0.244670f, -0.794290f, 0.779461f, -0.634633f, 0.229290f, -1.180597f, 0.574650f, 0.812338f, 0.900697f, 0.097950f, 0.708525f, 0.409153f, 0.804739f, 0.677169f}; + std::initializer_list Y_shape{2, 2, 3, 3, 2}; + std::initializer_list Y_data{0.171946f, -0.411342f, -1.046998f, -0.002345f, 0.246533f, 0.396970f, 0.664278f, 0.199883f, -0.636287f, 0.162358f, -0.061161f, 0.528084f, 0.041846f, 0.750291f, -0.476442f, 0.142258f, -0.067844f, 0.869081f, 0.360025f, -0.406785f, -0.701985f, -0.718142f, 0.519179f, -0.022693f, 0.618451f, 0.708731f, 0.224429f, 0.784241f, -0.812606f, -0.521137f, 0.266524f, 0.190886f, 0.231077f, -0.465330f, 0.204730f, 0.348489f, 0.356190f, 0.256096f, -0.038212f, -0.943162f, 0.258902f, -0.360112f, -0.920536f, 0.126677f, -0.523600f, -0.361337f, -0.154168f, 0.179761f, -1.141155f, -0.423488f, -0.225410f, -0.204886f, -1.162816f, -0.678226f, -0.384409f, -0.146245f, -0.622531f, 0.312188f, -0.828836f, -0.541017f, -0.778291f, -0.602484f, -0.328754f, -0.163964f, -0.508068f, 0.193021f, 0.273133f, -0.217934f, -0.562420f, 0.287725f, -1.097279f, -0.306201f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "linear"; + std::string padding_mode = "zeros"; + int64_t align_corners = 0; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{0.185965f, 0.133937f, -0.763030f, 0.733342f, 1.932445f, -0.582571f, -1.312078f, 0.738952f, 0.444459f, 0.742593f, -0.805960f, -0.202535f, 0.970323f, -0.801176f, 0.277655f, -1.938051f, -1.879800f, 0.287116f, 0.261958f, -0.358247f, -0.107750f, 0.748162f, -0.742330f, 0.344665f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{-0.460252f, 0.734353f, -1.069308f, 1.005361f, 1.198595f, -0.327629f, 0.474026f, 1.196645f, 0.361782f, 0.469280f, 0.440632f, -0.490951f, 0.292918f, -0.639568f, 1.024697f, -0.514217f, 0.274326f, -0.347614f, 0.600117f, 0.019780f, 0.659824f, -0.324940f, -0.704174f, 0.460072f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{1.646426f, 0.409452f, 0.132247f, -0.106052f, -0.009495f, 0.270785f, -0.702581f, -0.170769f, 0.223282f, -0.044740f, 0.006388f, 0.645576f, -0.476802f, -0.504368f, -0.897503f, -1.684608f, -1.162742f, -0.963921f, -0.197266f, -0.050021f, 0.151796f, 0.662485f, 0.175502f, -0.434265f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "linear"; + std::string padding_mode = "zeros"; + int64_t align_corners = 0; + std::initializer_list X_shape{2, 2, 3, 3, 2}; + std::initializer_list X_data{-0.299262f, -0.304887f, 0.906636f, -0.392850f, -0.050410f, 0.548199f, -1.235108f, -0.475848f, 0.635455f, 0.307462f, -1.241370f, -0.538672f, 0.863466f, 0.799983f, -0.090064f, -0.751721f, 0.956040f, -0.117709f, -2.183699f, -0.484444f, 1.105900f, 0.164466f, 0.720736f, 0.168044f, -0.656400f, 1.770106f, -0.544832f, 1.358424f, 0.981648f, -1.759268f, -0.526924f, 1.322339f, 0.148774f, 0.321413f, -1.257438f, -0.383775f, -2.117908f, -0.077921f, -0.197889f, 0.555813f, -1.517724f, 1.419652f, -0.891774f, 1.684663f, -1.524669f, -2.055758f, -0.299843f, -0.644860f, 0.428609f, -1.704372f, 1.257671f, -0.886508f, -0.029344f, -1.718824f, -0.294273f, 1.537690f, -1.366837f, -1.610098f, 0.650240f, -0.288219f, 0.837292f, 0.431683f, -0.405852f, 0.492271f, 0.416507f, 0.971658f, -0.183526f, 0.615709f, -0.081615f, 1.160796f, 1.431487f, 0.485687f}; + std::initializer_list Grid_shape{2, 3, 3, 2, 3}; + std::initializer_list Grid_data{0.884040f, -0.825214f, 0.496720f, -0.440955f, 1.195811f, 0.169268f, -1.042100f, 0.206524f, 0.145895f, -1.160650f, 0.240829f, 1.144915f, 0.345332f, -0.006382f, -0.248763f, 0.318888f, -0.534619f, 1.181719f, 1.037350f, 0.560600f, -0.446974f, -1.126746f, -0.690807f, 1.166754f, -1.101454f, -1.145775f, -0.086488f, 0.381780f, -1.194351f, -1.114106f, 0.006524f, -0.402521f, 0.836016f, 0.344533f, -1.041627f, -1.081571f, 0.824102f, -0.212785f, -0.524949f, 0.377977f, -0.235842f, 0.573897f, 0.304308f, -0.519568f, -0.961787f, 0.649611f, -0.720973f, -0.132725f, 0.164074f, -0.698360f, 0.653669f, -0.844065f, 0.294728f, 0.128341f, 0.440293f, -1.177701f, 0.069319f, 0.585007f, -0.768260f, 0.296941f, 0.004702f, 1.018020f, -0.254096f, 0.008198f, -0.521925f, -0.295744f, 0.343532f, -1.157334f, 0.910329f, 0.862921f, 0.508195f, 0.898317f, -0.373544f, 0.273330f, 0.061050f, -0.829794f, -0.461335f, -0.426012f, -0.296704f, -1.065526f, -0.843948f, -0.113955f, -0.182548f, -1.089296f, 0.256401f, 0.653393f, 0.999377f, 1.009925f, -0.838519f, -0.384579f, -0.569276f, 0.220093f, 0.321562f, 0.266984f, 0.701244f, 0.633093f, -0.644096f, 0.823778f, 0.809482f, 0.158802f, -1.044029f, -0.735991f, 0.334411f, 0.414891f, 1.118940f, 0.610743f, 0.434932f, -0.040928f}; + std::initializer_list Y_shape{2, 2, 3, 3, 2}; + std::initializer_list Y_data{0.222880f, -0.137918f, 0.042779f, 0.027606f, 0.146833f, 0.119531f, 0.062001f, 0.077615f, -0.124874f, -0.020856f, 0.248748f, -0.050235f, -0.185885f, -0.124030f, -0.148987f, -0.345107f, 0.753440f, -0.055873f, 0.674388f, 0.063018f, -0.054480f, -0.034452f, 0.780917f, 0.193151f, -0.140647f, -0.047364f, -0.095816f, -0.046983f, 0.254384f, -0.123703f, 0.191358f, 0.674903f, -0.311971f, 1.032054f, 0.672506f, 0.009147f, 0.281933f, 0.135835f, -0.145082f, -0.392560f, -0.229593f, -0.632284f, -0.936929f, -0.916689f, -0.502247f, -0.108609f, -0.645451f, 0.242939f, -0.165902f, -1.220095f, -0.015084f, -0.300940f, -0.352557f, -0.886474f, 0.109150f, 0.398365f, 0.235757f, 0.358618f, 0.082189f, 0.268617f, 0.077955f, -0.157573f, 0.023048f, -0.346908f, 0.360128f, 0.389098f, 0.122882f, 0.675956f, 0.735857f, 0.354858f, 0.244544f, 0.631102f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "linear"; + std::string padding_mode = "border"; + int64_t align_corners = 1; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{-1.916003f, 0.150784f, -0.179898f, 0.402727f, -0.549764f, 1.772484f, 1.014343f, 0.502823f, 0.976771f, -0.071957f, 0.519875f, 0.408665f, 1.435640f, -0.807775f, -0.181661f, -0.574026f, -0.335351f, -0.155602f, 0.348749f, 1.055618f, 0.737784f, -0.394725f, 0.597608f, 0.006105f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{-0.189838f, -1.050410f, -1.072351f, -0.930754f, -0.502573f, 0.186642f, -0.564332f, -0.042774f, -0.143740f, 1.097448f, -0.547044f, 1.127440f, -0.921224f, -1.001202f, 0.390232f, -0.698394f, 0.615509f, -0.663897f, 0.944958f, 1.161950f, 0.076823f, 0.256464f, 1.118784f, 0.711380f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{-1.078787f, -1.795786f, -0.023270f, -0.113413f, 0.444460f, -0.023826f, 0.807136f, 1.011742f, 0.674182f, 0.754935f, 0.472262f, 0.494688f, 1.347277f, -0.223507f, -0.417529f, -0.160549f, -0.353331f, -0.276367f, 0.376591f, 0.571813f, 0.551111f, 0.022384f, 0.166782f, -0.109583f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "linear"; + std::string padding_mode = "border"; + int64_t align_corners = 1; + std::initializer_list X_shape{2, 2, 3, 3, 2}; + std::initializer_list X_data{-0.332555f, 0.980958f, 0.002632f, -1.976749f, 0.979548f, 1.109773f, -0.534887f, 0.705692f, -0.143637f, -0.600830f, 0.315853f, -0.604687f, -0.300652f, -0.375240f, 0.377196f, -0.140920f, 1.159946f, 2.364598f, 0.320719f, 0.397938f, -0.680097f, -1.201632f, 0.270077f, -0.036712f, -0.972864f, 0.792393f, -1.159168f, -0.016679f, -0.665027f, 0.809646f, -1.684452f, 0.049476f, 0.065748f, 0.279619f, -1.079668f, 0.301309f, 1.010100f, -0.119015f, -0.104838f, 0.916627f, -0.522838f, 0.485269f, -1.221088f, 2.044754f, -0.669823f, 0.128370f, 0.080480f, 0.372679f, -0.046427f, -0.732652f, -0.395790f, 0.012594f, -0.170518f, -0.706783f, -0.862588f, -1.177275f, -1.165262f, 0.914826f, -0.661128f, -0.386656f, -0.599246f, 0.544643f, 0.930679f, -1.146137f, 0.212913f, -0.022433f, 1.692830f, 0.187511f, -0.631569f, -0.311540f, -0.885167f, -0.429959f}; + std::initializer_list Grid_shape{2, 3, 3, 2, 3}; + std::initializer_list Grid_data{-0.453992f, 0.394222f, 0.755023f, -0.025610f, 0.658840f, 0.982105f, -0.642922f, -0.265292f, -1.080379f, 0.275464f, 0.855228f, -0.233029f, 0.191483f, 0.383441f, -0.025595f, 0.932929f, 0.174866f, -1.179535f, -0.990943f, -1.188918f, 0.049460f, 0.648682f, -0.158317f, 1.078936f, -0.215883f, 0.245340f, 1.082089f, 0.607310f, -0.038283f, 1.155868f, -0.716957f, 0.446971f, 0.757844f, -0.743030f, -1.127212f, 0.383835f, -0.455267f, -0.605570f, 0.238686f, -0.870514f, 1.079285f, -0.107719f, -0.384303f, 1.003178f, 0.334130f, 0.228627f, -0.573757f, 1.143690f, -0.365482f, 0.998076f, -0.088210f, 0.601965f, 0.843747f, -0.893403f, -0.799804f, -1.186625f, 0.865515f, 1.031983f, -0.438564f, -0.587735f, 0.200868f, 0.646055f, 0.296203f, -0.250092f, -0.763290f, 1.026321f, -0.777136f, -1.159559f, -0.479127f, 0.239290f, 0.446029f, 0.464001f, -0.695158f, -0.460548f, -0.533616f, -0.581111f, -1.010728f, 0.245640f, -0.348981f, -1.155007f, -0.700701f, -0.720655f, -0.517635f, -0.741485f, -0.208103f, 0.430035f, -0.971177f, -0.102798f, -0.345348f, -0.613510f, -0.266458f, -0.508597f, 0.038577f, -0.866220f, 0.227567f, 1.101759f, 0.994334f, -0.538031f, 0.369874f, -1.134245f, 1.010332f, -1.195878f, -1.072351f, -1.077155f, -1.114385f, 0.162516f, -0.317319f, 0.287217f}; + std::initializer_list Y_shape{2, 2, 3, 3, 2}; + std::initializer_list Y_data{0.517362f, 1.168304f, -0.283719f, -0.056944f, -0.345007f, -1.383013f, -0.517978f, -0.099340f, 0.531814f, -0.051495f, 0.570203f, -0.350444f, -0.195512f, 0.335075f, 0.533103f, -0.173681f, 0.110927f, 0.549661f, -0.303447f, -0.209369f, -0.479343f, 0.113517f, -0.222508f, -0.981697f, -1.000072f, 0.163343f, -0.019158f, 0.217390f, -0.442252f, -1.020732f, -0.645033f, -0.481248f, -0.359233f, -0.271288f, -0.165768f, -0.092544f, -0.219889f, 0.671201f, -0.041137f, -0.289275f, -0.022793f, -0.130253f, -0.072692f, -0.451858f, 0.402947f, 0.168711f, 0.110811f, 0.202315f, -0.200036f, -0.331588f, 0.583341f, -0.522838f, 1.010100f, -0.018650f, 1.269564f, -0.168394f, -0.209390f, 0.740205f, -0.675828f, -0.325915f, -0.404694f, 0.067064f, -0.744102f, -0.639736f, -0.416580f, -0.317643f, 0.004590f, -0.665815f, -0.163600f, -0.661128f, -0.862588f, -0.132515f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "linear"; + std::string padding_mode = "border"; + int64_t align_corners = 0; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{-0.050553f, -0.825690f, -0.616085f, 0.337113f, 0.370334f, -0.105073f, -0.565382f, 0.396842f, -0.373193f, -0.780451f, -1.932970f, 1.104960f, -2.569945f, 0.661190f, -0.192302f, 0.734279f, 0.351872f, -1.068136f, 0.173665f, -0.778153f, -0.981877f, 1.485344f, 0.431733f, 0.428167f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{-0.330875f, 0.589988f, 0.011588f, -1.144325f, -1.038357f, 0.435055f, -1.053243f, -0.957144f, -0.715458f, 1.143742f, -0.341215f, -0.494762f, -0.810255f, 0.767649f, -0.193763f, 0.231402f, 0.286668f, 0.338432f, 0.768106f, 0.062272f, 0.124125f, -0.077928f, -0.932481f, -0.274618f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{0.204265f, -0.447104f, 0.027635f, -0.050553f, 0.370334f, -0.248695f, -1.306797f, -0.073120f, -1.391077f, -0.565382f, -1.932970f, -0.419110f, 0.351872f, 0.030903f, -0.124253f, 0.565919f, 0.276202f, -1.171718f, 0.431733f, 0.001712f, 0.689913f, 1.386595f, 0.443614f, -0.505878f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_no_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "linear"; + std::string padding_mode = "border"; + int64_t align_corners = 0; + std::initializer_list X_shape{2, 2, 3, 3, 2}; + std::initializer_list X_data{-0.727099f, 0.057663f, -0.548384f, 0.078163f, -0.133679f, 0.211872f, 0.271687f, -1.221973f, -2.630687f, -0.558102f, -0.327183f, 0.039894f, 1.222102f, 0.144418f, 0.696676f, -2.231791f, 0.910544f, 2.749837f, -0.354036f, -0.106102f, 2.453576f, 0.332319f, -1.743712f, 1.416859f, 0.260041f, -1.179930f, 0.407328f, 0.375476f, 2.028488f, 0.174825f, -1.467126f, 0.079045f, 0.870076f, -0.895165f, 0.631429f, 0.358222f, 1.484120f, -0.622331f, 0.727481f, 0.644213f, 1.299103f, -0.378573f, 1.360908f, 0.905514f, 0.180065f, 0.972162f, 1.246238f, -0.537204f, -1.241497f, -0.772822f, -0.149044f, -1.642060f, 0.120091f, 0.937023f, 0.422106f, 0.652040f, 0.045585f, -1.089530f, 0.356099f, 0.536075f, -1.840257f, -1.035736f, 0.348653f, 0.187942f, 0.150011f, 0.521798f, 1.271739f, 0.977495f, 0.811927f, 0.641729f, 0.964401f, -0.693074f}; + std::initializer_list Grid_shape{2, 3, 3, 2, 3}; + std::initializer_list Grid_data{1.017692f, -0.818194f, 0.525611f, -0.556812f, -0.124601f, 1.120205f, 0.153552f, -1.144168f, 1.103147f, -0.050771f, -0.600881f, -0.633732f, 1.029039f, 0.020253f, 0.662802f, 0.788674f, -0.465758f, 0.101853f, -0.776226f, 1.002064f, -0.634553f, 0.797064f, 0.304043f, 0.740241f, -0.845484f, -0.037319f, 0.621792f, -0.047898f, -0.017218f, 0.584766f, -0.896882f, -0.240587f, 0.546590f, 0.588539f, 1.114539f, -0.237379f, 0.284327f, -0.590432f, -0.201402f, -0.602420f, 0.889284f, 0.007310f, 0.488176f, 0.660055f, 0.223618f, 0.127703f, -0.087830f, -1.016490f, 0.193341f, -0.265853f, -1.008634f, 1.118021f, -0.127930f, -0.598904f, -1.168221f, -1.105256f, 0.456964f, -0.547805f, -0.518368f, -0.694346f, 0.968648f, -0.288466f, 0.777819f, 0.952657f, -0.930362f, 0.895254f, -0.229149f, 1.149323f, 0.612939f, -1.162419f, 0.222934f, 0.421831f, -0.435327f, 0.909973f, -0.993750f, -0.380767f, 1.143396f, 1.171977f, 0.599451f, -0.716336f, -1.032482f, -0.975683f, -0.299985f, 0.679795f, 0.379920f, -0.145729f, 1.079221f, 0.942322f, -0.560859f, -0.519668f, -0.014079f, 0.249021f, -0.008590f, 0.463277f, 0.827937f, -0.216375f, 0.589310f, 0.163207f, 0.460623f, 0.494016f, -0.320739f, -0.535032f, 0.512922f, -0.768302f, 0.630003f, -0.769945f, 0.823242f, 0.481487f}; + std::initializer_list Y_shape{2, 2, 3, 3, 2}; + std::initializer_list Y_data{-0.144687f, 0.794879f, 0.517780f, -0.372025f, -2.071523f, -0.953122f, -0.143000f, 0.040151f, 0.511071f, -0.723342f, 0.441486f, 0.101130f, -0.668215f, -0.313612f, 0.918245f, -0.165560f, -0.141496f, -0.002992f, -0.187333f, 0.433250f, -0.456623f, -0.082449f, -0.849978f, -0.635311f, -1.562003f, -0.323540f, 0.716348f, 0.089914f, 0.085623f, 0.617075f, -0.522245f, 2.013170f, 0.249061f, 0.948093f, 0.518262f, 0.230788f, -0.422900f, 1.315807f, -1.265941f, -0.772822f, 0.375354f, 0.159706f, 1.190603f, 0.217497f, -0.622331f, -0.640623f, -1.324261f, -0.126419f, 0.497220f, -0.421485f, -0.512049f, 0.218454f, -0.680520f, 0.432900f, 0.292848f, 0.338349f, 0.787015f, 0.977495f, 0.494135f, 0.649655f, 0.367739f, 0.766775f, 0.652040f, 1.018832f, 0.738819f, 0.107251f, 0.287288f, 0.515065f, 0.300961f, -0.279154f, 0.866776f, 0.738188f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "linear"; + std::string padding_mode = "reflection"; + int64_t align_corners = 1; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{-0.599439f, 0.317612f, -0.294302f, -0.530613f, 0.754687f, 0.092241f, -1.009405f, -1.155944f, 0.336327f, 0.159353f, -1.134330f, 0.510271f, 0.271972f, 1.301884f, 1.027400f, 1.193876f, 0.304363f, 1.027256f, 0.186801f, 0.719412f, -0.310900f, -1.123812f, -0.312771f, 2.729156f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{0.853801f, 0.833200f, -0.477474f, 0.131677f, 0.571825f, 0.858708f, -1.120796f, 1.194690f, -0.301706f, 0.488934f, -0.745307f, -0.923452f, -0.812682f, 0.707226f, -0.591920f, 0.697573f, 0.362777f, 0.477332f, -0.266909f, -0.379588f, -0.561456f, -0.670762f, 1.106438f, -0.065215f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{0.031577f, -0.232574f, 0.133168f, 0.515460f, 0.063332f, -0.470541f, 0.353729f, 0.159106f, 0.163701f, -0.770097f, -0.133556f, -0.925350f, 0.568498f, 0.636194f, 0.976680f, 0.921805f, 0.684184f, 1.189063f, -0.133022f, 0.070598f, 0.388079f, -0.232737f, 0.042589f, -0.965013f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "linear"; + std::string padding_mode = "reflection"; + int64_t align_corners = 1; + std::initializer_list X_shape{2, 2, 3, 3, 2}; + std::initializer_list X_data{-0.441629f, 0.199148f, 1.214051f, -0.000869f, 0.863692f, -0.067719f, -0.621662f, 0.235179f, 0.691041f, 0.176564f, 0.036477f, -0.085879f, 0.785440f, -1.837889f, -0.300151f, -1.710413f, 0.484432f, 2.160478f, -0.049246f, 0.372475f, -1.060470f, -1.000841f, -0.473439f, 0.963055f, 0.174518f, 0.932434f, 0.039338f, -0.343549f, -1.446623f, -0.673622f, 0.520395f, -0.279228f, -0.367065f, -0.871085f, 0.649273f, -0.835047f, 1.063542f, -1.829784f, 1.476173f, -1.048210f, -1.127299f, 1.204756f, -0.998390f, -1.014054f, -1.032717f, 0.977184f, 0.959897f, -0.749289f, 0.784492f, 1.343993f, 1.291144f, 0.099496f, 2.086763f, 0.529948f, -2.296640f, 0.570701f, 0.491216f, -0.003836f, -0.591929f, -0.076994f, 1.239698f, -0.888840f, 0.623497f, 0.769879f, 2.240972f, -2.081689f, 0.798466f, 1.207944f, -0.486804f, -0.488222f, -0.746382f, -0.220282f}; + std::initializer_list Grid_shape{2, 3, 3, 2, 3}; + std::initializer_list Grid_data{-0.169044f, 0.178997f, 1.112567f, -0.825642f, -0.359793f, 0.170758f, -0.081412f, 0.319486f, 0.630993f, -0.493702f, 0.093438f, 1.085657f, -0.679024f, -0.813753f, -0.920282f, 0.717311f, -1.100678f, -0.583561f, 0.810473f, -0.719377f, 0.975857f, -0.560957f, 0.189840f, 0.157082f, -0.029434f, 0.747413f, 1.019186f, -0.749235f, 0.673000f, 0.320624f, -0.022362f, -0.839050f, 0.355966f, 0.871005f, -1.030007f, -1.108265f, -1.179701f, 0.277273f, -0.344802f, -0.372753f, 1.117390f, -0.306079f, -0.762057f, 0.107942f, -0.658634f, -0.351593f, 0.633875f, 0.276953f, -0.823465f, 1.142446f, 0.811875f, -0.818022f, 0.522699f, 0.493103f, -0.861061f, -0.843352f, -0.993629f, 0.534540f, 0.209070f, 0.507143f, -0.527071f, 0.902309f, 0.153227f, -0.957513f, -0.302041f, 0.612404f, 0.263859f, -0.183579f, -0.838388f, -0.746482f, 1.035039f, -0.687403f, 0.850371f, -0.401659f, 0.011995f, -1.168548f, -0.390077f, 1.011575f, -1.077360f, 0.603794f, -1.009901f, 0.175023f, -1.087964f, -0.949961f, -0.968757f, -0.416100f, 0.163389f, -0.879807f, 0.304124f, 0.722748f, 0.978239f, 1.062535f, 0.790067f, -0.353356f, -0.110591f, 1.061730f, 0.596951f, -0.318231f, 0.905999f, -1.048710f, 1.027042f, 0.671407f, -0.880154f, -0.978736f, 0.938431f, 1.183815f, 0.104716f, -0.468883f}; + std::initializer_list Y_shape{2, 2, 3, 3, 2}; + std::initializer_list Y_data{-0.414201f, 0.167816f, -0.042305f, -0.423495f, -0.101419f, 0.120192f, -1.543294f, 0.344146f, 0.709278f, 0.248721f, -0.269138f, 0.158159f, 0.659876f, 0.226329f, 0.874509f, 0.240959f, 0.412611f, 0.225904f, -0.448580f, 0.057703f, -0.426538f, -0.401142f, -0.147435f, 0.401852f, -0.355426f, -0.286018f, -0.219687f, -0.564205f, 0.282723f, 0.363522f, -0.543706f, -0.787722f, -0.692217f, -0.594894f, 0.091005f, -0.328214f, 0.919003f, 0.408116f, 0.631220f, 0.303619f, -0.197801f, -0.308153f, 0.094457f, 1.027881f, -0.077622f, -0.597219f, -0.661449f, 0.947805f, 0.279352f, 0.828246f, 0.571205f, 1.646163f, 0.714257f, 0.049881f, -1.680014f, -0.056047f, 0.892393f, 0.250564f, 0.138843f, 0.178706f, 0.161286f, 0.036891f, -0.141908f, -0.510903f, 0.733949f, -0.112944f, -0.581858f, -0.269439f, 0.056781f, 0.200325f, 0.814038f, 0.277386f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "linear"; + std::string padding_mode = "reflection"; + int64_t align_corners = 0; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{-0.173652f, -1.513725f, -0.704586f, -1.952375f, -0.699404f, -0.806298f, 1.640852f, -0.138969f, -0.695411f, -1.352111f, 0.568797f, -0.564294f, -0.056468f, 0.641604f, -0.438370f, 0.450167f, -1.091401f, 1.669729f, -0.908544f, 0.244467f, 0.172109f, 1.156741f, -0.617128f, 1.155460f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{0.252250f, -0.151452f, 0.824706f, -0.588292f, -0.591147f, -0.155082f, -0.732938f, 0.457493f, -0.439559f, 0.492330f, 0.696447f, 0.700722f, -0.220298f, 0.654884f, -0.635434f, -1.195619f, -0.114204f, -0.870080f, -0.929674f, 0.305035f, 1.025429f, -0.472240f, -0.067881f, -0.869393f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{-1.538390f, -1.565293f, -0.581079f, -0.701030f, -0.725252f, -0.806298f, -0.850602f, -0.281588f, -0.151944f, 0.172138f, 0.177246f, -0.564294f, -0.316822f, -0.056468f, 0.212846f, -0.737167f, 0.585773f, 0.245182f, -0.111277f, -0.908544f, -0.463717f, -0.189009f, 0.510522f, -0.410307f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_no_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "linear"; + std::string padding_mode = "reflection"; + int64_t align_corners = 0; + std::initializer_list X_shape{2, 2, 3, 3, 2}; + std::initializer_list X_data{1.179856f, 1.432512f, 1.016210f, -0.661096f, 0.335863f, 0.565957f, -0.517555f, 2.232456f, -0.615173f, -0.073628f, -0.260768f, -1.952025f, 0.304237f, 0.902323f, -0.485170f, 0.781595f, -1.777093f, -0.274107f, -1.030698f, 0.181435f, 1.947646f, 1.007702f, -0.100718f, 0.154090f, -0.483193f, 1.565921f, -0.932274f, 0.313820f, -0.439116f, -0.411861f, -0.821795f, -1.685022f, -0.013518f, 0.519914f, -0.175407f, -0.507962f, 0.050913f, 0.981904f, 1.087165f, 1.758657f, 0.075954f, -0.481552f, 0.085590f, 0.537831f, -0.419622f, -1.756791f, 1.324879f, -0.267061f, -0.683518f, 0.605393f, 0.041004f, -0.756742f, 0.744950f, -0.508619f, -0.594679f, -1.165646f, -0.699604f, -0.271502f, 0.437731f, -2.206233f, 1.088781f, -0.629873f, -0.904741f, -1.233533f, 2.466710f, -0.117309f, -0.684130f, 0.598811f, 0.288846f, -1.195569f, 0.935300f, 0.962852f}; + std::initializer_list Grid_shape{2, 3, 3, 2, 3}; + std::initializer_list Grid_data{0.625842f, 0.210304f, -0.725943f, -0.553764f, -0.182412f, -0.296478f, -0.254040f, -0.820211f, 0.869312f, 0.622346f, 0.236815f, 0.271706f, 0.140482f, 0.897281f, 0.271537f, 0.182799f, -0.659653f, 0.400310f, -1.122656f, 0.378466f, -1.040147f, -0.496646f, 0.633526f, -0.714734f, 0.955528f, -0.663024f, 1.136629f, 0.369854f, -0.520025f, 0.731855f, -1.062711f, -0.760189f, -0.751812f, 0.157968f, 0.117892f, -1.032129f, 1.157953f, -0.001147f, -0.640796f, 0.028663f, -0.515104f, 0.331070f, 0.434411f, -0.340393f, 0.069958f, 0.714010f, -0.780518f, -0.267586f, -0.177029f, -0.793935f, 0.097737f, 0.044103f, -0.969274f, 0.246164f, 1.145360f, 0.638273f, -0.650926f, 1.098440f, -0.824873f, -0.610135f, 0.529312f, 0.954650f, 1.145143f, 1.033109f, -0.660775f, 0.274592f, -0.753497f, 0.026500f, 0.994206f, 0.590870f, -1.108049f, -0.516447f, -1.012489f, 0.565286f, -0.152334f, -0.877228f, -0.383453f, 0.393797f, 0.111096f, 1.125969f, -0.015932f, 0.377468f, -0.363512f, 0.143194f, 0.042988f, 1.030777f, 0.502813f, -0.683870f, -1.066269f, -1.141727f, -0.435790f, 0.155118f, 1.128919f, -0.117905f, 0.469189f, 0.609870f, -0.919201f, -0.992659f, 0.454699f, 0.559331f, -0.558762f, 0.188050f, -1.174933f, 0.015126f, 0.294147f, 0.011359f, -0.190476f, 0.499476f}; + std::initializer_list Y_shape{2, 2, 3, 3, 2}; + std::initializer_list Y_data{-0.274014f, 0.145076f, 0.451342f, -0.273219f, -1.128307f, 0.962473f, 0.629978f, 0.370138f, 0.901663f, 0.778787f, 1.179856f, 0.014218f, -0.634683f, 0.585419f, 0.972130f, 1.911376f, 0.389205f, 0.849839f, 0.738424f, 0.054296f, -1.034114f, 0.096287f, -0.408114f, -0.474491f, 0.784791f, 0.001762f, -1.672976f, -1.127656f, -1.030698f, 1.105979f, 0.979492f, -0.258014f, 0.693543f, 1.010218f, -0.008927f, -0.078404f, -0.384825f, 0.944247f, -0.508619f, 0.548774f, 0.068986f, 0.881841f, 0.869967f, -0.274754f, 0.337312f, -0.374188f, 0.161655f, 0.050913f, 0.146763f, 0.119233f, -0.438980f, 0.228062f, -0.187221f, -0.376543f, -2.077576f, -1.120214f, 0.962852f, -0.133462f, 0.314542f, -1.044921f, 1.568017f, -0.060947f, 0.838264f, -0.652863f, 0.978122f, -0.594679f, 0.366536f, 0.596221f, -0.120431f, -0.435362f, -0.328892f, -0.434798f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "cubic"; + std::string padding_mode = "zeros"; + int64_t align_corners = 1; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{0.741614f, -1.612838f, 0.274100f, -0.685296f, -0.032079f, -0.246424f, 0.089412f, -0.776545f, -0.152179f, 0.312533f, -1.503701f, -0.720829f, 0.877575f, 0.407229f, -0.889951f, 0.603605f, -0.140859f, 2.032775f, -0.520668f, 1.063163f, -1.008883f, 0.194195f, -0.303240f, -0.967884f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{-0.932019f, -0.034394f, 0.554511f, 0.484230f, 0.141120f, 0.485083f, -0.836516f, 0.999462f, 0.026764f, 0.775689f, 0.265464f, -0.133497f, 0.514005f, 1.139161f, 1.183700f, -1.010095f, 0.072779f, -0.862052f, 0.699178f, 0.861473f, -0.842637f, -0.069355f, 0.830374f, 0.793568f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{0.274192f, -0.348792f, -0.238780f, -0.048938f, -0.195915f, -0.488976f, -0.104505f, -0.351103f, -0.583059f, -1.533095f, -1.141282f, 0.187052f, 1.668728f, 0.345182f, 0.682750f, 1.893112f, -0.775917f, 1.920082f, -0.889375f, 1.071508f, 0.336517f, -0.933740f, -0.981629f, -0.893789f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "cubic"; + std::string padding_mode = "zeros"; + int64_t align_corners = 0; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{0.333395f, 0.977190f, 0.214232f, 0.363731f, -1.352515f, -0.980304f, -0.354887f, -0.481711f, -0.607915f, -0.309748f, 2.262781f, 0.963363f, 1.997079f, 0.987449f, -0.537662f, 1.011585f, 0.822184f, 0.567108f, 0.135401f, -0.943315f, -0.614181f, 0.030652f, 0.914757f, 0.971777f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{-0.487111f, 0.913573f, 0.641905f, -0.093110f, 0.512522f, 0.358369f, 0.655341f, -0.964320f, 0.370929f, -1.136512f, -0.789199f, -0.447185f, -0.116915f, -1.132446f, 0.029865f, 0.191588f, -0.476239f, 0.389224f, 1.048588f, -0.204978f, -0.639094f, -1.062994f, -0.876243f, -0.663705f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{-1.051920f, 0.501832f, -0.508839f, 0.563480f, 0.297178f, 0.246571f, 1.781955f, -0.353574f, 0.481200f, -0.258839f, -0.145200f, -0.469558f, 0.624262f, 0.351267f, 0.180256f, 0.571859f, 0.903895f, 1.383745f, -0.081406f, 0.133665f, 0.348401f, -0.164219f, 0.138237f, 0.203282f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "cubic"; + std::string padding_mode = "border"; + int64_t align_corners = 1; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{-0.480448f, 0.682093f, 0.237716f, -1.234307f, 2.139750f, 2.410321f, 0.491472f, -0.553422f, 0.032129f, -0.162503f, 0.144036f, -1.889875f, -0.293944f, -1.390146f, -1.552136f, 1.604720f, -1.707202f, 0.182427f, -0.631000f, 0.196649f, 0.427711f, -0.014224f, -1.319834f, -2.703346f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{0.503717f, 0.572989f, 0.179517f, -0.060398f, 0.503876f, 0.288627f, -1.148268f, 0.194010f, -0.532910f, -0.636357f, 0.464076f, 0.245386f, 0.203212f, -0.569260f, 0.554489f, 1.126118f, 0.146805f, 0.493232f, -1.052794f, 0.713394f, 0.416866f, 0.540634f, 0.500415f, -0.315629f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{0.885659f, -0.722912f, -0.180469f, 0.697015f, -0.322127f, -0.292851f, -0.867861f, -0.047527f, -0.447720f, 0.028100f, 0.191874f, -0.378776f, -0.321888f, -0.277691f, -0.037604f, -1.766707f, 0.320836f, 0.415106f, 0.179209f, -2.609096f, -0.929794f, -0.788240f, -1.212243f, 0.337704f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "cubic"; + std::string padding_mode = "border"; + int64_t align_corners = 0; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{-0.924256f, -2.309784f, 1.272769f, 0.548427f, -1.478527f, -3.472946f, -1.252325f, 0.268589f, 0.326270f, 0.105016f, 0.515184f, -0.951158f, -0.658693f, -2.018776f, 0.981625f, -0.401504f, 1.560519f, -0.129836f, -1.876357f, 0.511516f, -1.825582f, 0.358958f, -0.805392f, -1.409127f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{0.874856f, -1.090775f, 1.169192f, 0.447098f, 0.583418f, 0.267395f, 0.788144f, 1.129706f, -0.102229f, -0.984624f, 1.101916f, -0.253070f, -0.578731f, 0.738703f, 0.669694f, 0.160659f, -0.075327f, -0.229561f, 1.100291f, 0.731142f, 0.714643f, 0.765214f, -0.628031f, 0.250554f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{-2.647128f, -2.154235f, -0.768645f, -3.893546f, -1.698376f, -0.114530f, 0.458115f, -0.696657f, -0.370692f, -1.169692f, -0.754730f, 0.320002f, 1.683550f, -0.301499f, -0.176003f, -0.236653f, -0.278257f, 1.480160f, -0.700350f, 0.095525f, -0.891605f, -1.569065f, -1.633715f, -1.535763f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "cubic"; + std::string padding_mode = "reflection"; + int64_t align_corners = 1; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{-0.328038f, -0.658850f, -0.054298f, 0.012663f, -0.077366f, 0.644305f, -1.262985f, 0.922028f, 0.189962f, 0.518836f, 1.168413f, -0.286220f, 0.431207f, -0.295352f, -0.357675f, -0.311715f, 0.839514f, -0.651820f, -0.283934f, 0.430508f, 0.206334f, 0.765966f, -1.144732f, -0.507045f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{-0.372000f, -1.056863f, -0.360826f, -0.268314f, 0.691035f, -0.595044f, 0.720198f, 0.166462f, -0.201118f, -1.069416f, 1.184721f, -0.213980f, 0.755038f, -0.620722f, -1.168597f, -0.956522f, -0.614982f, -0.382162f, -0.169456f, 1.000817f, -1.106710f, 0.598940f, 1.009714f, 0.007723f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{-0.403118f, -0.158055f, -0.496030f, 0.161379f, -0.440603f, -0.193607f, -0.746082f, -0.076433f, 0.751030f, 0.360851f, -0.488453f, 0.664305f, -0.259139f, 0.411796f, -0.156648f, 0.281569f, 0.437515f, -0.313812f, 0.573781f, -0.265706f, 0.200380f, -0.906155f, -0.724311f, 0.760352f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_no_align_corners) { + OpTester test("GridSample", 20); + std::string mode = "cubic"; + std::string padding_mode = "reflection"; + int64_t align_corners = 0; + std::initializer_list X_shape{2, 2, 3, 2}; + std::initializer_list X_data{-0.290962f, 0.867797f, -0.085436f, -1.597520f, 0.695524f, 0.838739f, 0.513032f, 0.166242f, -0.546135f, -0.780313f, -0.512993f, -0.449479f, 1.594718f, 0.953375f, 0.692587f, -0.798364f, -0.128799f, -0.456210f, 2.098909f, -1.561220f, 1.713821f, -0.701970f, -0.287280f, -1.708048f}; + std::initializer_list Grid_shape{2, 3, 2, 2}; + std::initializer_list Grid_data{0.934471f, 0.728362f, -0.458301f, -1.040800f, 0.157908f, 0.753451f, -0.122762f, 0.100970f, 0.889432f, 0.495471f, 0.897108f, 0.176205f, 0.134514f, -0.287037f, -0.202498f, -0.637759f, 0.802292f, 1.094459f, 0.445338f, 0.034096f, -0.396126f, -1.184798f, -0.222199f, -0.851887f}; + std::initializer_list Y_shape{2, 2, 3, 2}; + std::initializer_list Y_data{1.037788f, -0.275160f, 0.953595f, -0.518196f, 0.118127f, -1.525148f, -0.413483f, 0.696689f, -0.450182f, -0.696169f, -0.561886f, -0.828986f, 0.343953f, 1.379632f, -0.417260f, -0.781500f, 1.666511f, 1.599268f, 0.106200f, 1.088396f, -2.079140f, -0.612122f, 1.822402f, 1.173807f}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); + test.AddAttribute("mode", mode); + test.AddAttribute("padding_mode", padding_mode); + test.AddAttribute("align_corners", align_corners); + test.AddOutput("Y", Y_shape, Y_data); + test.ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py new file mode 100644 index 0000000000000..e4d58e79243ef --- /dev/null +++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +# This code is used to generate the test cases for the GridSample operator +# in onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc + +import torch + +# Define the input dimensions +N, C, D, H, W = 2, 2, 3, 3, 2 + +# Define the modes, padding modes, and whether to align corners +modes = ["nearest", "bilinear", "bicubic"] +padding_modes = ["zeros", "border", "reflection"] +align_corners_options = [True, False] + +# Loop over the combinations of parameters +torch.manual_seed(0) +for opset_version in [16, 20]: + for mode in modes: + for padding_mode in padding_modes: + for align_corners in align_corners_options: + for ndim in [4, 5]: + if ndim == 5 and mode == "bicubic": + continue + + if opset_version < 20 and ndim == 5: + continue + + # Create a random input tensor with the specified dimensions + input_shape = (N,) + (C,) + (((D, H, W)) if ndim == 5 else ((H, W))) + input_tensor = torch.randn(*input_shape) + + # Create a random grid tensor with the specified dimensions + grid_shape = (N,) + (((D, H, W)) if ndim == 5 else ((H, W))) + (ndim - 2,) + + # Between -1.2 to + 1.2 + grid_tensor = torch.rand(*grid_shape) * 2.4 - 1.2 + + # Apply grid_sample + output_tensor = torch.nn.functional.grid_sample( + input_tensor, grid_tensor, mode=mode, padding_mode=padding_mode, align_corners=align_corners + ) + + X_data_str = "{" + ", ".join([f"{x:.6f}f" for x in input_tensor.numpy().flatten()]) + "}" + Grid_data_str = "{" + ", ".join([f"{x:.6f}f" for x in grid_tensor.numpy().flatten()]) + "}" + + Y_shape = output_tensor.shape + Y_data_str = "{" + ", ".join([f"{x:.6f}f" for x in output_tensor.numpy().flatten()]) + "}" + + onnx_mode = mode + if opset_version >= 20: + if mode == "bilinear": + onnx_mode = "linear" + elif mode == "bicubic": + onnx_mode = "cubic" + + onnx_align_corners = 1 if align_corners else 0 + + test_name = f"test_grid_sample_{opset_version}_{ndim}D_{mode}_{padding_mode}_{'align_corners' if align_corners else 'no_align_corners'}" + print(f"TEST(GridsampleTest, {test_name}) {{") + print(f'OpTester test("GridSample", {opset_version});') + print(f'std::string mode = "{onnx_mode}";') + print(f'std::string padding_mode = "{padding_mode}";') + print(f"int64_t align_corners = {onnx_align_corners};") + print(f"std::initializer_list X_shape {{ {', '.join(map(str, input_shape))} }};") + print(f"std::initializer_list X_data { X_data_str };") + print(f"std::initializer_list Grid_shape {{ {', '.join(map(str, grid_shape))} }};") + print(f"std::initializer_list Grid_data { Grid_data_str };") + print(f"std::initializer_list Y_shape {{ {', '.join(map(str, Y_shape))} }};") + print(f"std::initializer_list Y_data { Y_data_str };") + + print('test.AddInput("X", X_shape, X_data);') + print('test.AddInput("Grid", Grid_shape, Grid_data);') + print('test.AddAttribute("mode", mode);') + print('test.AddAttribute("padding_mode", padding_mode);') + print('test.AddAttribute("align_corners", align_corners);') + print('test.AddOutput("Y", Y_shape, Y_data);') + print("test.Run();") + print("}") + print("\n") diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index c552ec3aea72d..bfdc0b1d26953 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -239,24 +239,6 @@ "^test_ai_onnx_ml_label_encoder_string_int_no_default", "^test_ai_onnx_ml_label_encoder_tensor_mapping", "^test_ai_onnx_ml_label_encoder_tensor_value_only_mapping", - "^test_gridsample_aligncorners_true", - "^test_gridsample_bicubic_align_corners_0_additional_1", - "^test_gridsample_bicubic_align_corners_1_additional_1", - "^test_gridsample_bicubic", - "^test_gridsample_bilinear_align_corners_0_additional_1", - "^test_gridsample_bilinear_align_corners_1_additional_1", - "^test_gridsample_bilinear", - "^test_gridsample_border_padding", - "^test_gridsample", - "^test_gridsample_nearest_align_corners_0_additional_1", - "^test_gridsample_nearest_align_corners_1_additional_1", - "^test_gridsample_nearest", - "^test_gridsample_reflection_padding", - "^test_gridsample_volumetric_bilinear_align_corners_0", - "^test_gridsample_volumetric_bilinear_align_corners_1", - "^test_gridsample_volumetric_nearest_align_corners_0", - "^test_gridsample_volumetric_nearest_align_corners_1", - "^test_gridsample_zeros_padding", "^test_image_decoder_decode_bmp_rgb", "^test_image_decoder_decode_jpeg2k_rgb", "^test_image_decoder_decode_jpeg_bgr",