diff --git a/examples/BuddyMobileNetV3/CMakeLists.txt b/examples/BuddyMobileNetV3/CMakeLists.txt index a38d56bdd5..c5a932c673 100644 --- a/examples/BuddyMobileNetV3/CMakeLists.txt +++ b/examples/BuddyMobileNetV3/CMakeLists.txt @@ -71,5 +71,5 @@ SET_TARGET_PROPERTIES(MOBILENETV3 PROPERTIES LINKER_LANGUAGE C) add_executable(buddy-mobilenetv3-run buddy-mobilenetv3-main.cpp) target_link_directories(buddy-mobilenetv3-run PRIVATE ${LLVM_LIBRARY_DIR}) -set(BUDDY_MOBILENETV3_LIBS MOBILENETV3 mlir_c_runner_utils ${PNG_LIBRARIES}) +set(BUDDY_MOBILENETV3_LIBS MOBILENETV3 mlir_c_runner_utils BuddyLibDIP ${PNG_LIBRARIES}) target_link_libraries(buddy-mobilenetv3-run ${BUDDY_MOBILENETV3_LIBS}) diff --git a/examples/BuddyMobileNetV3/buddy-mobilenetv3-main.cpp b/examples/BuddyMobileNetV3/buddy-mobilenetv3-main.cpp index 067f07168a..a9eb1a2aa1 100644 --- a/examples/BuddyMobileNetV3/buddy-mobilenetv3-main.cpp +++ b/examples/BuddyMobileNetV3/buddy-mobilenetv3-main.cpp @@ -15,6 +15,7 @@ //===----------------------------------------------------------------------===// #include +#include #include #include #include @@ -27,13 +28,13 @@ #include constexpr size_t ParamsSize = 2554968; -const std::string ImgName = "dog-224*224.png"; +const std::string ImgName = "dog.png"; // Declare the mobilenet C interface. extern "C" void _mlir_ciface_forward(MemRef *output, MemRef *arg0, MemRef *arg1, - dip::Image *input); + MemRef *input); /// Print [Log] label in bold blue format. void printLogLabel() { std::cout << "\033[34;1m[Log] \033[0m"; } @@ -115,7 +116,10 @@ int main() { std::string mobilenetDir = getenv("MOBILENETV3_EXAMPLE_PATH"); std::string imgPath = mobilenetDir + "/images/" + ImgName; dip::Image input(imgPath, dip::DIP_RGB, true /* norm */); - + MemRef inputResize = dip::Resize4D_NCHW( + &input, dip::INTERPOLATION_TYPE::BILINEAR_INTERPOLATION, + {1, 3, 224, 224} /*{image_cols, image_rows}*/); + MemRef output(sizesOutput); // Load model parameters from the specified file. @@ -126,7 +130,7 @@ int main() { loadParameters(paramsDir, intDir, paramsContainerf32, ParamsContainerInt64); // Call the forward function of the model. _mlir_ciface_forward(&output, ¶msContainerf32, &ParamsContainerInt64, - &input); + &inputResize); auto out = output.getData(); softmax(out, 1000); diff --git a/examples/DIPDialect/CMakeLists.txt b/examples/DIPDialect/CMakeLists.txt index 2b4f970bdd..2f897ad633 100644 --- a/examples/DIPDialect/CMakeLists.txt +++ b/examples/DIPDialect/CMakeLists.txt @@ -26,5 +26,8 @@ target_link_libraries(rotation2D ${DIP_LIBS}) add_executable(resize2D resize2D.cpp) target_link_libraries(resize2D ${DIP_LIBS}) -add_executable(resize4D resize4D.cpp) -target_link_libraries(resize4D ${DIP_LIBS}) +add_executable(resize4D_nhwc resize4D_nhwc.cpp) +target_link_libraries(resize4D_nhwc ${DIP_LIBS}) + +add_executable(resize4D_nchw resize4D_nchw.cpp) +target_link_libraries(resize4D_nchw ${DIP_LIBS}) diff --git a/examples/DIPDialect/resize4D_nchw.cpp b/examples/DIPDialect/resize4D_nchw.cpp new file mode 100644 index 0000000000..95d77cc27d --- /dev/null +++ b/examples/DIPDialect/resize4D_nchw.cpp @@ -0,0 +1,58 @@ +//====- resize4D.cpp - Example of buddy-opt tool =============================// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements a 4D resize example with dip.resize_4d operation. +// The dip.resize_4d operation will be compiled into an object file with the +// buddy-opt tool. +// This file will be linked with the object file to generate the executable +// file. +// +//===----------------------------------------------------------------------===// +#include "buddy/DIP/imgcodecs/loadsave.h" +#include +#include +#include +#include +#include +#include + +using namespace std; + +void testImplementation(int argc, char *argv[]) { + // Read as colar image. + dip::Image inputBatch(argv[1], dip::DIP_RGB, true); + + // Note : Both values in output image dimensions and scaling ratios must be + // positive numbers. + MemRef output = dip::Resize4D_NCHW( + &inputBatch, dip::INTERPOLATION_TYPE::BILINEAR_INTERPOLATION, + {1, 3, 224, 224} /*{image_cols, image_rows}*/); + + // Define Img with the output of Resize4D. + intptr_t outSizes[3] = {output.getSizes()[2], output.getSizes()[3], + output.getSizes()[1]}; + + Img outputImageResize4D(output.getData(), outSizes); + + // dip::imwrite(argv[2], outputImageResize4D); + + return; +} + +int main(int argc, char *argv[]) { + testImplementation(argc, argv); + return 0; +} diff --git a/examples/DIPDialect/resize4D.cpp b/examples/DIPDialect/resize4D_nhwc.cpp similarity index 97% rename from examples/DIPDialect/resize4D.cpp rename to examples/DIPDialect/resize4D_nhwc.cpp index 2930c0610f..affb8a8a09 100644 --- a/examples/DIPDialect/resize4D.cpp +++ b/examples/DIPDialect/resize4D_nhwc.cpp @@ -40,7 +40,7 @@ void testImplementation(int argc, char *argv[]) { // Note : Both values in output image dimensions and scaling ratios must be // positive numbers. - MemRef output = dip::Resize4D( + MemRef output = dip::Resize4D_NHWC( &inputBatch, dip::INTERPOLATION_TYPE::BILINEAR_INTERPOLATION, {1, 224, 224, 3} /*{image_cols, image_rows}*/); diff --git a/frontend/Interfaces/buddy/DIP/DIP.h b/frontend/Interfaces/buddy/DIP/DIP.h index c7cbd2bf82..8598b61fc5 100644 --- a/frontend/Interfaces/buddy/DIP/DIP.h +++ b/frontend/Interfaces/buddy/DIP/DIP.h @@ -23,6 +23,8 @@ #include "buddy/Core/Container.h" #include "buddy/DIP/ImageContainer.h" +#include "buddy/DIP/ImgContainer.h" +#include #include namespace dip { // Availale types of boundary extrapolation techniques provided in DIP dialect. @@ -70,19 +72,27 @@ void _mlir_ciface_resize_2d_nearest_neighbour_interpolation( float verticalScalingFactor, MemRef *output); // Declare the Resize4D C interface. -void _mlir_ciface_resize_4d_nearest_neighbour_interpolation( +void _mlir_ciface_resize_4d_nhwc_nearest_neighbour_interpolation( Img *input, float horizontalScalingFactor, float verticalScalingFactor, MemRef *output); +void _mlir_ciface_resize_4d_nchw_nearest_neighbour_interpolation( + dip::Image *input, float horizontalScalingFactor, + float verticalScalingFactor, MemRef *output); + void _mlir_ciface_resize_2d_bilinear_interpolation( Img *input, float horizontalScalingFactor, float verticalScalingFactor, MemRef *output); // Declare the Resize4D C interface. -void _mlir_ciface_resize_4d_bilinear_interpolation( +void _mlir_ciface_resize_4d_nhwc_bilinear_interpolation( Img *input, float horizontalScalingFactor, float verticalScalingFactor, MemRef *output); +void _mlir_ciface_resize_4d_nchw_bilinear_interpolation( + dip::Image *input, float horizontalScalingFactor, + float verticalScalingFactor, MemRef *output); + // Declare the Morphology 2D C interface. void _mlir_ciface_erosion_2d_constant_padding( Img input, MemRef *kernel, MemRef *output, @@ -213,17 +223,38 @@ inline MemRef Resize2D_Impl(Img *input, } // Helper function for applying 4D resize operation on images. -inline MemRef Resize4D_Impl(Img *input, - INTERPOLATION_TYPE type, - std::vector scalingRatios, - intptr_t outputSize[4]) { +inline MemRef Resize4D_NHWC_Impl(Img *input, + INTERPOLATION_TYPE type, + std::vector scalingRatios, + intptr_t outputSize[4]) { MemRef output(outputSize); if (type == INTERPOLATION_TYPE::NEAREST_NEIGHBOUR_INTERPOLATION) { - detail::_mlir_ciface_resize_4d_nearest_neighbour_interpolation( + detail::_mlir_ciface_resize_4d_nhwc_nearest_neighbour_interpolation( input, scalingRatios[0], scalingRatios[1], &output); } else if (type == INTERPOLATION_TYPE::BILINEAR_INTERPOLATION) { - detail::_mlir_ciface_resize_4d_bilinear_interpolation( + detail::_mlir_ciface_resize_4d_nhwc_bilinear_interpolation( + input, scalingRatios[0], scalingRatios[1], &output); + } else { + throw std::invalid_argument( + "Please chose a supported type of interpolation " + "(Nearest neighbour interpolation or Bilinear interpolation)\n"); + } + + return output; +} + +inline MemRef Resize4D_NCHW_Impl(dip::Image *input, + INTERPOLATION_TYPE type, + std::vector scalingRatios, + intptr_t outputSize[4]) { + MemRef output(outputSize); + + if (type == INTERPOLATION_TYPE::NEAREST_NEIGHBOUR_INTERPOLATION) { + detail::_mlir_ciface_resize_4d_nchw_nearest_neighbour_interpolation( + input, scalingRatios[0], scalingRatios[1], &output); + } else if (type == INTERPOLATION_TYPE::BILINEAR_INTERPOLATION) { + detail::_mlir_ciface_resize_4d_nchw_bilinear_interpolation( input, scalingRatios[0], scalingRatios[1], &output); } else { throw std::invalid_argument( @@ -369,16 +400,32 @@ inline MemRef Resize2D(Img *input, INTERPOLATION_TYPE type, } // User interface for 4D Resize. -inline MemRef Resize4D(Img *input, INTERPOLATION_TYPE type, - std::vector size) { +inline MemRef Resize4D_NHWC(Img *input, + INTERPOLATION_TYPE type, + std::vector size) { if (size.size() != 4) { throw std::invalid_argument("Dimension of an image should be 4\n"); } intptr_t outputSize[4] = {size[0], size[1], size[2], size[3]}; - return detail::Resize4D_Impl(input, type, - {(float)input->getSizes()[1] / (float)size[1], - (float)input->getSizes()[2] / (float)size[2]}, - outputSize); + return detail::Resize4D_NHWC_Impl( + input, type, + {(float)input->getSizes()[1] / (float)size[1], + (float)input->getSizes()[2] / (float)size[2]}, + outputSize); +} + +inline MemRef Resize4D_NCHW(dip::Image *input, + INTERPOLATION_TYPE type, + std::vector size) { + if (size.size() != 4) { + throw std::invalid_argument("Dimension of an image should be 4\n"); + } + intptr_t outputSize[4] = {size[0], size[1], size[2], size[3]}; + return detail::Resize4D_NCHW_Impl( + input, type, + {(float)input->getSizes()[2] / (float)size[2], + (float)input->getSizes()[3] / (float)size[3]}, + outputSize); } // User interface for 2D Resize. diff --git a/frontend/Interfaces/buddy/DIP/ImgContainer.h b/frontend/Interfaces/buddy/DIP/ImgContainer.h index 0450a75ae0..382974e967 100644 --- a/frontend/Interfaces/buddy/DIP/ImgContainer.h +++ b/frontend/Interfaces/buddy/DIP/ImgContainer.h @@ -33,7 +33,7 @@ enum ImageModes { DIP_RGB = 1, }; -inline bool isBigEndian() { +inline bool ifBigEndian() { int num = 1; char *ptr = (char *)# return (*ptr == 0); @@ -475,7 +475,7 @@ bool Image::decodePNG(const std::vector &fileData) { // Convert big or little Endian and convert 16 bits to 8 bits if (this->bitDepth == 16) png_set_strip_16(png_ptr); - else if (!isBigEndian()) + else if (!ifBigEndian()) png_set_swap(png_ptr); // Remove alpha channel diff --git a/frontend/Interfaces/lib/DIP.mlir b/frontend/Interfaces/lib/DIP.mlir index 05af467e48..3153d1ebe8 100644 --- a/frontend/Interfaces/lib/DIP.mlir +++ b/frontend/Interfaces/lib/DIP.mlir @@ -54,15 +54,27 @@ func.func @resize_2d_bilinear_interpolation(%inputImage : memref, %hori return } -func.func @resize_4d_nearest_neighbour_interpolation(%inputImage : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %outputImage : memref) attributes{llvm.emit_c_interface} +func.func @resize_4d_nhwc_nearest_neighbour_interpolation(%inputImage : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %outputImage : memref) attributes{llvm.emit_c_interface} { - dip.resize_4d NEAREST_NEIGHBOUR_INTERPOLATION %inputImage, %horizontal_scaling_factor, %vertical_scaling_factor, %outputImage : memref, f32, f32, memref + dip.resize_4d_nhwc NEAREST_NEIGHBOUR_INTERPOLATION %inputImage, %horizontal_scaling_factor, %vertical_scaling_factor, %outputImage : memref, f32, f32, memref return } -func.func @resize_4d_bilinear_interpolation(%inputImage : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %outputImage : memref) attributes{llvm.emit_c_interface} +func.func @resize_4d_nhwc_bilinear_interpolation(%inputImage : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %outputImage : memref) attributes{llvm.emit_c_interface} { - dip.resize_4d BILINEAR_INTERPOLATION %inputImage, %horizontal_scaling_factor, %vertical_scaling_factor, %outputImage : memref, f32, f32, memref + dip.resize_4d_nhwc BILINEAR_INTERPOLATION %inputImage, %horizontal_scaling_factor, %vertical_scaling_factor, %outputImage : memref, f32, f32, memref + return +} + +func.func @resize_4d_nchw_nearest_neighbour_interpolation(%inputImage : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %outputImage : memref) attributes{llvm.emit_c_interface} +{ + dip.resize_4d_nchw NEAREST_NEIGHBOUR_INTERPOLATION %inputImage, %horizontal_scaling_factor, %vertical_scaling_factor, %outputImage : memref, f32, f32, memref + return +} + +func.func @resize_4d_nchw_bilinear_interpolation(%inputImage : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %outputImage : memref) attributes{llvm.emit_c_interface} +{ + dip.resize_4d_nchw BILINEAR_INTERPOLATION %inputImage, %horizontal_scaling_factor, %vertical_scaling_factor, %outputImage : memref, f32, f32, memref return } diff --git a/midend/include/Dialect/DIP/DIPOps.td b/midend/include/Dialect/DIP/DIPOps.td index 3f654fe034..179e66359d 100644 --- a/midend/include/Dialect/DIP/DIPOps.td +++ b/midend/include/Dialect/DIP/DIPOps.td @@ -210,7 +210,7 @@ def DIP_Resize2DOp : DIP_Op<"resize_2d"> }]; } -def DIP_Resize4DOp : DIP_Op<"resize_4d"> +def DIP_Resize4D_NHWCOp : DIP_Op<"resize_4d_nhwc"> { let summary = [{ This operation intends to provide a utility for resizing images using the DIP dialect. @@ -228,10 +228,55 @@ def DIP_Resize4DOp : DIP_Op<"resize_4d"> lowering it every time for each new image (Refer to the example provided in examples directory for the DIP dialect). + The processed image format is (batch, height, weight, channel). + + Syntax : + + ```mlir + dip.resize_4d_nhwc INTERPOLATION_TYPE %inputImage, %horizontal_scaling_factor, %vertical_scaling_factor, %outputImage : memref, f32, f32, memref + ``` + + where ```INTERPOLATION_TYPE``` can be ```NEAREST_NEIGHBOUR_INTERPOLATION``` or + ```BILINEAR_INTERPOLATION```. + }]; + + let arguments = (ins Arg:$memrefI, + F32 : $horizontal_scaling_factor, + F32 : $vertical_scaling_factor, + Arg:$memrefO, + DIP_InterpolationAttr:$interpolation_type); + + let assemblyFormat = [{ + $interpolation_type $memrefI `,` $horizontal_scaling_factor `,` $vertical_scaling_factor `,` $memrefO attr-dict `:` type($memrefI) `,` type($horizontal_scaling_factor) `,` type($vertical_scaling_factor) `,` type($memrefO) + }]; +} + +def DIP_Resize4D_NCHWOp : DIP_Op<"resize_4d_nchw"> +{ + let summary = [{ + This operation intends to provide a utility for resizing images using the DIP dialect. + Image resizing has many applications such as data augmentation, dimension adjustment in ML + models, etc. and can thus be used in native MLIR pipelines catering to above mentioned + use-cases. + + As of now, two different mechanisms for pixel interpolation are provided namely nearest + neighbour interpolation and bilinear interpolation. The user can specify the desired type of + interpolation via an attribute provided as argument to the operation. The operation also + expects scaling ratios (Input image dimension / Output image dimension) for both dimensions + of input and output images as arguments. + + The operation is flexible for its use with images of different sizes without necessarily + lowering it every time for each new image (Refer to the example provided in examples + directory for the DIP dialect). + + The processed image format is (batch, channel, height, weight). + Syntax : ```mlir - dip.resize_4d INTERPOLATION_TYPE %inputImage, %horizontal_scaling_factor, %vertical_scaling_factor, %outputImage : memref, f32, f32, memref + dip.resize_4d_nchw INTERPOLATION_TYPE %inputImage, %horizontal_scaling_factor, %vertical_scaling_factor, %outputImage : memref, f32, f32, memref ``` where ```INTERPOLATION_TYPE``` can be ```NEAREST_NEIGHBOUR_INTERPOLATION``` or diff --git a/midend/include/Utils/DIPUtils.h b/midend/include/Utils/DIPUtils.h index 75849ac89b..a8b77e8f23 100644 --- a/midend/include/Utils/DIPUtils.h +++ b/midend/include/Utils/DIPUtils.h @@ -105,12 +105,12 @@ void fillPixels(OpBuilder &builder, Location loc, Value resXVec, Value resYVec, // Fill appropriate pixel 4D data in its corresponding rotated co-ordinate of // output image. -void fillPixels4D(OpBuilder &builder, Location loc, Value ivs0, Value ivs1, - Value resXVec, Value resYVec, Value xVec, Value yVec, - Value input, Value output, Value c0, Value strideVal, - Value outputRowLastElemF32, Value outputColLastElemF32, - Value inputRowLastElemF32, Value inputColLastElemF32, - Value c0F32); +void fillPixelsNearestNeighbour4D( + OpBuilder &builder, Location loc, Value ivs0, Value ivs1, Value resXVec, + Value resYVec, Value xVec, Value yVec, Value input, Value output, Value c0, + Value strideVal, Value outputRowLastElemF32, Value outputColLastElemF32, + Value inputRowLastElemF32, Value inputColLastElemF32, Value c0F32, + Value dataCondition); // Calculate tan(angle / 2) where angle is a function parameter. Value customTanVal(OpBuilder &builder, Location loc, Value angleVal); @@ -166,7 +166,7 @@ void fillPixelsBilinearInterpolate4D( Value input, Value output, Value c0, Value strideVal, Value xVecWeight, Value yVecWeight, Value outputRowLastElemF32, Value outputColLastElemF32, Value inputRowLastElemF32, Value inputColLastElemF32, Value c0F32, - Value c1F32); + Value c1F32, Value dataCondition); // Helper function for resizing an image using nearest neighbour interpolation // mechanism. @@ -188,7 +188,7 @@ void NearestNeighbourInterpolationResizing4D( Value horizontalScalingFactorVec, Value verticalScalingFactorVec, Value outputRowLastElemF32, Value outputColLastElemF32, Value inputRowLastElemF32, Value inputColLastElemF32, VectorType vectorTy32, - int64_t stride, Value c0, Value c0F32); + int64_t stride, Value c0, Value c0F32, Value dataCondition); // Helper function for resizing an image using bilinear interpolation mechanism. void BilinearInterpolationResizing( @@ -209,7 +209,7 @@ void BilinearInterpolationResizing4D( Value horizontalScalingFactorVec, Value verticalScalingFactorVec, Value outputRowLastElemF32, Value outputColLastElemF32, Value inputRowLastElemF32, Value inputColLastElemF32, VectorType vectorTy32, - int64_t stride, Value c0, Value c0F32, Value c1F32); + int64_t stride, Value c0, Value c0F32, Value c1F32, Value dataCondition); // Util function for morphological transformations ; compares two vectors and // returns a mask diff --git a/midend/lib/Conversion/LowerDIP/LowerDIPPass.cpp b/midend/lib/Conversion/LowerDIP/LowerDIPPass.cpp index e57431579e..6118ecc1c3 100644 --- a/midend/lib/Conversion/LowerDIP/LowerDIPPass.cpp +++ b/midend/lib/Conversion/LowerDIP/LowerDIPPass.cpp @@ -360,16 +360,17 @@ class DIPResize2DOpLowering : public OpRewritePattern { int64_t stride; }; -class DIPResize4DOpLowering : public OpRewritePattern { +class DIPResize4D_NHWCOpLowering + : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - explicit DIPResize4DOpLowering(MLIRContext *context, int64_t strideParam) + explicit DIPResize4D_NHWCOpLowering(MLIRContext *context, int64_t strideParam) : OpRewritePattern(context) { stride = strideParam; } - LogicalResult matchAndRewrite(dip::Resize4DOp op, + LogicalResult matchAndRewrite(dip::Resize4D_NHWCOp op, PatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto ctx = op->getContext(); @@ -384,7 +385,7 @@ class DIPResize4DOpLowering : public OpRewritePattern { auto inElemTy = input.getType().cast().getElementType(); dip::DIP_ERROR error = - dip::checkDIPCommonTypes(op, {input, output}); + dip::checkDIPCommonTypes(op, {input, output}); if (error == dip::DIP_ERROR::INCONSISTENT_TYPES) { return op->emitOpError() @@ -394,6 +395,10 @@ class DIPResize4DOpLowering : public OpRewritePattern { << inElemTy << "is passed"; } + // true: NHWC, false: NCHW + Value dataCondition = rewriter.create( + loc, rewriter.getI1Type(), rewriter.getBoolAttr(true)); + Value c0 = rewriter.create(loc, 0); Value c1 = rewriter.create(loc, 1); Value c2 = rewriter.create(loc, 2); @@ -401,7 +406,7 @@ class DIPResize4DOpLowering : public OpRewritePattern { Value c0F32 = indexToF32(rewriter, loc, c0); - Value inputBatch = rewriter.create(loc, input, c0); + // Value inputBatch = rewriter.create(loc, input, c0); Value inputRow = rewriter.create(loc, input, c1); Value inputCol = rewriter.create(loc, input, c2); Value inputColor = rewriter.create(loc, input, c3); @@ -433,7 +438,6 @@ class DIPResize4DOpLowering : public OpRewritePattern { FloatType f32 = FloatType::getF32(ctx); VectorType vectorTy32 = VectorType::get({stride}, f32); - // tsworld: line 157 Value horizontalScalingFactorVec = rewriter.create( loc, vectorTy32, horizontalScalingFactor); Value verticalScalingFactorVec = rewriter.create( @@ -441,7 +445,6 @@ class DIPResize4DOpLowering : public OpRewritePattern { // Obtain extreme allocatable value(s) in input and output for bounding // purpose. - // tsworld: before line 170 Value inputRowLastElem = rewriter.create(loc, inputRow, c1); Value inputRowLastElemF32 = indexToF32(rewriter, loc, inputRowLastElem); @@ -462,13 +465,13 @@ class DIPResize4DOpLowering : public OpRewritePattern { rewriter, loc, ctx, lowerBounds1, upperBounds1, steps, strideVal, input, output, horizontalScalingFactorVec, verticalScalingFactorVec, outputRowLastElemF32, outputColLastElemF32, inputRowLastElemF32, - inputColLastElemF32, vectorTy32, stride, c0, c0F32); + inputColLastElemF32, vectorTy32, stride, c0, c0F32, dataCondition); dip::NearestNeighbourInterpolationResizing4D( rewriter, loc, ctx, lowerBounds2, upperBounds2, steps, strideTailVal, input, output, horizontalScalingFactorVec, verticalScalingFactorVec, outputRowLastElemF32, outputColLastElemF32, inputRowLastElemF32, - inputColLastElemF32, vectorTy32, stride, c0, c0F32); + inputColLastElemF32, vectorTy32, stride, c0, c0F32, dataCondition); } else if (interpolationAttr == dip::InterpolationType::BilinearInterpolation) { Value c1F32 = indexToF32(rewriter, loc, c1); @@ -477,13 +480,155 @@ class DIPResize4DOpLowering : public OpRewritePattern { rewriter, loc, ctx, lowerBounds1, upperBounds1, steps, strideVal, input, output, horizontalScalingFactorVec, verticalScalingFactorVec, outputRowLastElemF32, outputColLastElemF32, inputRowLastElemF32, - inputColLastElemF32, vectorTy32, stride, c0, c0F32, c1F32); + inputColLastElemF32, vectorTy32, stride, c0, c0F32, c1F32, + dataCondition); dip::BilinearInterpolationResizing4D( rewriter, loc, ctx, lowerBounds2, upperBounds2, steps, strideTailVal, input, output, horizontalScalingFactorVec, verticalScalingFactorVec, outputRowLastElemF32, outputColLastElemF32, inputRowLastElemF32, - inputColLastElemF32, vectorTy32, stride, c0, c0F32, c1F32); + inputColLastElemF32, vectorTy32, stride, c0, c0F32, c1F32, + dataCondition); + } + + // Remove the original resize operation. + rewriter.eraseOp(op); + return success(); + } + +private: + int64_t stride; +}; + +class DIPResize4D_NCHWOpLowering + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + explicit DIPResize4D_NCHWOpLowering(MLIRContext *context, int64_t strideParam) + : OpRewritePattern(context) { + stride = strideParam; + } + + LogicalResult matchAndRewrite(dip::Resize4D_NCHWOp op, + PatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto ctx = op->getContext(); + + // Register operand values. + Value input = op->getOperand(0); + Value horizontalScalingFactor = op->getOperand(1); + Value verticalScalingFactor = op->getOperand(2); + Value output = op->getOperand(3); + auto interpolationAttr = op.getInterpolationType(); + Value strideVal = rewriter.create(loc, stride); + + auto inElemTy = input.getType().cast().getElementType(); + dip::DIP_ERROR error = + dip::checkDIPCommonTypes(op, {input, output}); + + if (error == dip::DIP_ERROR::INCONSISTENT_TYPES) { + return op->emitOpError() + << "input, and output must have the same element type"; + } else if (error == dip::DIP_ERROR::UNSUPPORTED_TYPE) { + return op->emitOpError() << "supports only f32, f64 and integer types. " + << inElemTy << "is passed"; + } + + // true: NHWC, false: NCHW + Value dataCondition = rewriter.create( + loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value c2 = rewriter.create(loc, 2); + Value c3 = rewriter.create(loc, 3); + + Value c0F32 = indexToF32(rewriter, loc, c0); + + // Value inputBatch = rewriter.create(loc, input, c0); + Value inputColor = rewriter.create(loc, input, c1); + Value inputRow = rewriter.create(loc, input, c2); + Value inputCol = rewriter.create(loc, input, c3); + + Value outputBatch = rewriter.create(loc, output, c0); + Value outputColor = rewriter.create(loc, output, c1); + Value outputRow = rewriter.create(loc, output, c2); + Value outputCol = rewriter.create(loc, output, c3); + + // Determine lower bound for second call of resize function (this is done + // for efficient tail processing). + Value outputColStrideRatio = + rewriter.create(loc, outputCol, strideVal); + Value outputColMultiple = + rewriter.create(loc, strideVal, outputColStrideRatio); + + SmallVector lowerBounds1{c0, c0, c0, c0}; + SmallVector upperBounds1{outputBatch, outputColor, outputRow, + outputColMultiple}; + + SmallVector steps{1, 1, 1, stride}; + Value strideTailVal = + rewriter.create(loc, outputCol, outputColMultiple); + + SmallVector lowerBounds2{c0, c0, c0, outputColMultiple}; + SmallVector upperBounds2{outputBatch, outputColor, outputRow, + outputCol}; + + FloatType f32 = FloatType::getF32(ctx); + VectorType vectorTy32 = VectorType::get({stride}, f32); + + Value horizontalScalingFactorVec = rewriter.create( + loc, vectorTy32, horizontalScalingFactor); + Value verticalScalingFactorVec = rewriter.create( + loc, vectorTy32, verticalScalingFactor); + + // Obtain extreme allocatable value(s) in input and output for bounding + // purpose. + Value inputRowLastElem = rewriter.create(loc, inputRow, c1); + Value inputRowLastElemF32 = indexToF32(rewriter, loc, inputRowLastElem); + + Value inputColLastElem = rewriter.create(loc, inputCol, c1); + Value inputColLastElemF32 = indexToF32(rewriter, loc, inputColLastElem); + + Value outputRowLastElem = + rewriter.create(loc, outputRow, c1); + Value outputRowLastElemF32 = indexToF32(rewriter, loc, outputRowLastElem); + + Value outputColLastElem = + rewriter.create(loc, outputCol, c1); + Value outputColLastElemF32 = indexToF32(rewriter, loc, outputColLastElem); + + if (interpolationAttr == + dip::InterpolationType::NearestNeighbourInterpolation) { + dip::NearestNeighbourInterpolationResizing4D( + rewriter, loc, ctx, lowerBounds1, upperBounds1, steps, strideVal, + input, output, horizontalScalingFactorVec, verticalScalingFactorVec, + outputRowLastElemF32, outputColLastElemF32, inputRowLastElemF32, + inputColLastElemF32, vectorTy32, stride, c0, c0F32, dataCondition); + + dip::NearestNeighbourInterpolationResizing4D( + rewriter, loc, ctx, lowerBounds2, upperBounds2, steps, strideTailVal, + input, output, horizontalScalingFactorVec, verticalScalingFactorVec, + outputRowLastElemF32, outputColLastElemF32, inputRowLastElemF32, + inputColLastElemF32, vectorTy32, stride, c0, c0F32, dataCondition); + } else if (interpolationAttr == + dip::InterpolationType::BilinearInterpolation) { + Value c1F32 = indexToF32(rewriter, loc, c1); + + dip::BilinearInterpolationResizing4D( + rewriter, loc, ctx, lowerBounds1, upperBounds1, steps, strideVal, + input, output, horizontalScalingFactorVec, verticalScalingFactorVec, + outputRowLastElemF32, outputColLastElemF32, inputRowLastElemF32, + inputColLastElemF32, vectorTy32, stride, c0, c0F32, c1F32, + dataCondition); + + dip::BilinearInterpolationResizing4D( + rewriter, loc, ctx, lowerBounds2, upperBounds2, steps, strideTailVal, + input, output, horizontalScalingFactorVec, verticalScalingFactorVec, + outputRowLastElemF32, outputColLastElemF32, inputRowLastElemF32, + inputColLastElemF32, vectorTy32, stride, c0, c0F32, c1F32, + dataCondition); } // Remove the original resize operation. @@ -1443,7 +1588,8 @@ void populateLowerDIPConversionPatterns(RewritePatternSet &patterns, patterns.add(patterns.getContext(), stride); patterns.add(patterns.getContext(), stride); patterns.add(patterns.getContext(), stride); - patterns.add(patterns.getContext(), stride); + patterns.add(patterns.getContext(), stride); + patterns.add(patterns.getContext(), stride); patterns.add(patterns.getContext(), stride); patterns.add(patterns.getContext(), stride); patterns.add(patterns.getContext(), stride); diff --git a/midend/lib/Utils/DIPUtils.cpp b/midend/lib/Utils/DIPUtils.cpp index 3116c60fd2..da41b65cd6 100644 --- a/midend/lib/Utils/DIPUtils.cpp +++ b/midend/lib/Utils/DIPUtils.cpp @@ -55,8 +55,12 @@ checkDIPCommonTypes(dip::Resize2DOp, const std::vector &args); template DIP_ERROR -checkDIPCommonTypes(dip::Resize4DOp, - const std::vector &args); +checkDIPCommonTypes(dip::Resize4D_NHWCOp, + const std::vector &args); + +template DIP_ERROR +checkDIPCommonTypes(dip::Resize4D_NCHWOp, + const std::vector &args); template DIP_ERROR checkDIPCommonTypes(dip::Erosion2DOp, @@ -113,7 +117,8 @@ DIP_ERROR checkDIPCommonTypes(DIPOP op, const std::vector &args) { } } else if (op->getName().stripDialect() == "rotate_2d" || op->getName().stripDialect() == "resize_2d" || - op->getName().stripDialect() == "resize_4d") { + op->getName().stripDialect() == "resize_4d_nhwc" || + op->getName().stripDialect() == "resize_4d_nchw") { auto inElemTy = getElementType(0); auto outElemTy = getElementType(1); @@ -389,12 +394,12 @@ void fillPixels(OpBuilder &builder, Location loc, Value resXVec, Value resYVec, // Fill appropriate pixel data in its corresponding co-ordinate of the output // image. -void fillPixels4D(OpBuilder &builder, Location loc, Value ivs0, Value ivs1, - Value resXVec, Value resYVec, Value xVec, Value yVec, - Value input, Value output, Value c0, Value strideVal, - Value outputRowLastElemF32, Value outputColLastElemF32, - Value inputRowLastElemF32, Value inputColLastElemF32, - Value c0F32) { +void fillPixelsNearestNeighbour4D( + OpBuilder &builder, Location loc, Value ivs0, Value ivs1, Value resXVec, + Value resYVec, Value xVec, Value yVec, Value input, Value output, Value c0, + Value strideVal, Value outputRowLastElemF32, Value outputColLastElemF32, + Value inputRowLastElemF32, Value inputColLastElemF32, Value c0F32, + Value dataCondition) { builder.create( loc, ValueRange{c0}, builder.getDimIdentityMap(), ValueRange{strideVal}, builder.getDimIdentityMap(), /*step*/ 1, std::nullopt, @@ -407,12 +412,36 @@ void fillPixels4D(OpBuilder &builder, Location loc, Value ivs0, Value ivs1, extractIndices(builder, loc, resXVec, resYVec, ivs[0], outputColLastElemF32, outputRowLastElemF32, c0F32); - Value pixelVal = builder.create( - loc, builder.getF32Type(), input, - ValueRange{ivs0, origIndices[1], origIndices[0], ivs1}); - builder.create( - loc, pixelVal, output, - ValueRange{ivs0, resIndices[1], resIndices[0], ivs1}); + auto ifop = builder.create( + loc, dataCondition, + [&](OpBuilder &builder, Location loc) { + Value pixelVal = builder.create( + loc, builder.getF32Type(), input, + ValueRange{ivs0, origIndices[1], origIndices[0], ivs1}); + builder.create(loc, pixelVal); + }, + [&](OpBuilder &builder, Location loc) { + Value pixelVal = builder.create( + loc, builder.getF32Type(), input, + ValueRange{ivs0, ivs1, origIndices[1], origIndices[0]}); + builder.create(loc, pixelVal); + }); + Value pixelVal = ifop.getResult(0); + + builder.create( + loc, dataCondition, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, pixelVal, output, + ValueRange{ivs0, resIndices[1], resIndices[0], ivs1}); + builder.create(loc); + }, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, pixelVal, output, + ValueRange{ivs0, ivs1, resIndices[1], resIndices[0]}); + builder.create(loc); + }); builder.create(loc); }); @@ -780,7 +809,7 @@ void fillPixelsBilinearInterpolate4D( Value input, Value output, Value c0, Value strideVal, Value xVecWeight, Value yVecWeight, Value outputRowLastElemF32, Value outputColLastElemF32, Value inputRowLastElemF32, Value inputColLastElemF32, Value c0F32, - Value c1F32) { + Value c1F32, Value dataCondition) { builder.create( loc, ValueRange{c0}, builder.getDimIdentityMap(), ValueRange{strideVal}, builder.getDimIdentityMap(), /*step*/ 1, std::nullopt, @@ -812,18 +841,46 @@ void fillPixelsBilinearInterpolate4D( builder.create(loc, c1F32, indexWeights[0]), builder.create(loc, c1F32, indexWeights[1])}; - Value pixelVal_a = builder.create( - loc, builder.getF32Type(), input, - ValueRange{ivs0, inputIndices_L[1], inputIndices_L[0], ivs1}); - Value pixelVal_b = builder.create( - loc, builder.getF32Type(), input, - ValueRange{ivs0, inputIndices_H[1], inputIndices_L[0], ivs1}); - Value pixelVal_c = builder.create( - loc, builder.getF32Type(), input, - ValueRange{ivs0, inputIndices_L[1], inputIndices_H[0], ivs1}); - Value pixelVal_d = builder.create( - loc, builder.getF32Type(), input, - ValueRange{ivs0, inputIndices_H[1], inputIndices_H[0], ivs1}); + auto ifop = builder.create( + loc, dataCondition, + [&](OpBuilder &builder, Location loc) { + Value pixelVal_a = builder.create( + loc, builder.getF32Type(), input, + ValueRange{ivs0, inputIndices_L[1], inputIndices_L[0], ivs1}); + Value pixelVal_b = builder.create( + loc, builder.getF32Type(), input, + ValueRange{ivs0, inputIndices_H[1], inputIndices_L[0], ivs1}); + Value pixelVal_c = builder.create( + loc, builder.getF32Type(), input, + ValueRange{ivs0, inputIndices_L[1], inputIndices_H[0], ivs1}); + Value pixelVal_d = builder.create( + loc, builder.getF32Type(), input, + ValueRange{ivs0, inputIndices_H[1], inputIndices_H[0], ivs1}); + builder.create( + loc, + ValueRange{pixelVal_a, pixelVal_b, pixelVal_c, pixelVal_d}); + }, + [&](OpBuilder &builder, Location loc) { + Value pixelVal_a = builder.create( + loc, builder.getF32Type(), input, + ValueRange{ivs0, ivs1, inputIndices_L[1], inputIndices_L[0]}); + Value pixelVal_b = builder.create( + loc, builder.getF32Type(), input, + ValueRange{ivs0, ivs1, inputIndices_H[1], inputIndices_L[0]}); + Value pixelVal_c = builder.create( + loc, builder.getF32Type(), input, + ValueRange{ivs0, ivs1, inputIndices_L[1], inputIndices_H[0]}); + Value pixelVal_d = builder.create( + loc, builder.getF32Type(), input, + ValueRange{ivs0, ivs1, inputIndices_H[1], inputIndices_H[0]}); + builder.create( + loc, + ValueRange{pixelVal_a, pixelVal_b, pixelVal_c, pixelVal_d}); + }); + Value pixelVal_a = ifop.getResult(0); + Value pixelVal_b = ifop.getResult(1); + Value pixelVal_c = ifop.getResult(2); + Value pixelVal_d = ifop.getResult(3); Value weightVal1 = builder.create(loc, indexWeights_UnitComplements[0], @@ -848,14 +905,25 @@ void fillPixelsBilinearInterpolate4D( builder.create(loc, interm1, interm2); Value pixel_interm2 = builder.create(loc, interm3, interm4); - Value pixel_interm3 = + Value pixelVal = builder.create(loc, pixel_interm1, pixel_interm2); - Value pixelVal = roundOff(builder, loc, pixel_interm3); + // Value pixelVal = roundOff(builder, loc, pixel_interm3); - builder.create( - loc, pixelVal, output, - ValueRange{ivs0, resIndices[1], resIndices[0], ivs1}); + builder.create( + loc, dataCondition, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, pixelVal, output, + ValueRange{ivs0, resIndices[1], resIndices[0], ivs1}); + builder.create(loc); + }, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, pixelVal, output, + ValueRange{ivs0, ivs1, resIndices[1], resIndices[0]}); + builder.create(loc); + }); builder.create(loc); }); @@ -902,7 +970,7 @@ void NearestNeighbourInterpolationResizing4D( Value horizontalScalingFactorVec, Value verticalScalingFactorVec, Value outputRowLastElemF32, Value outputColLastElemF32, Value inputRowLastElemF32, Value inputColLastElemF32, VectorType vectorTy32, - int64_t stride, Value c0, Value c0F32) { + int64_t stride, Value c0, Value c0F32, Value dataCondition) { affine::buildAffineLoopNest( builder, loc, lowerBounds, upperBounds, steps, [&](OpBuilder &builder, Location loc, ValueRange ivs) { @@ -911,18 +979,18 @@ void NearestNeighbourInterpolationResizing4D( Value xVec = iotaVec(builder, loc, ctx, ivs[3], strideVal, vectorTy32, c0, stride); - Value resXVecInterm = builder.create( - loc, xVec, horizontalScalingFactorVec); - Value resYVecInterm = - builder.create(loc, yVec, verticalScalingFactorVec); + Value resXVecInterm = + builder.create(loc, xVec, verticalScalingFactorVec); + Value resYVecInterm = builder.create( + loc, yVec, horizontalScalingFactorVec); Value resXVec = roundOff(builder, loc, resXVecInterm); Value resYVec = roundOff(builder, loc, resYVecInterm); - fillPixels4D(builder, loc, ivs[0], ivs[1], xVec, yVec, resXVec, resYVec, - input, output, c0, strideVal, outputRowLastElemF32, - outputColLastElemF32, inputRowLastElemF32, - inputColLastElemF32, c0F32); + fillPixelsNearestNeighbour4D( + builder, loc, ivs[0], ivs[1], xVec, yVec, resXVec, resYVec, input, + output, c0, strideVal, outputRowLastElemF32, outputColLastElemF32, + inputRowLastElemF32, inputColLastElemF32, c0F32, dataCondition); }); } @@ -976,7 +1044,7 @@ void BilinearInterpolationResizing4D( Value horizontalScalingFactorVec, Value verticalScalingFactorVec, Value outputRowLastElemF32, Value outputColLastElemF32, Value inputRowLastElemF32, Value inputColLastElemF32, VectorType vectorTy32, - int64_t stride, Value c0, Value c0F32, Value c1F32) { + int64_t stride, Value c0, Value c0F32, Value c1F32, Value dataCondition) { affine::buildAffineLoopNest( builder, loc, lowerBounds, upperBounds, steps, [&](OpBuilder &builder, Location loc, ValueRange ivs) { @@ -985,10 +1053,10 @@ void BilinearInterpolationResizing4D( Value xVec = iotaVec(builder, loc, ctx, ivs[3], strideVal, vectorTy32, c0, stride); - Value xVecInterm = builder.create( - loc, xVec, horizontalScalingFactorVec); - Value yVecInterm = - builder.create(loc, yVec, verticalScalingFactorVec); + Value xVecInterm = + builder.create(loc, xVec, verticalScalingFactorVec); + Value yVecInterm = builder.create( + loc, yVec, horizontalScalingFactorVec); Value xVecInterm_L = builder.create(loc, xVecInterm); Value xVecInterm_H = builder.create(loc, xVecInterm); @@ -1006,7 +1074,7 @@ void BilinearInterpolationResizing4D( yVecInterm_L, xVecInterm_H, yVecInterm_H, input, output, c0, strideVal, xVecWeight, yVecWeight, outputRowLastElemF32, outputColLastElemF32, inputRowLastElemF32, inputColLastElemF32, - c0F32, c1F32); + c0F32, c1F32, dataCondition); }); } diff --git a/tests/Dialect/DIP/resize4D_lowering.mlir b/tests/Dialect/DIP/resize4D_lowering.mlir deleted file mode 100644 index 0e58b040ae..0000000000 --- a/tests/Dialect/DIP/resize4D_lowering.mlir +++ /dev/null @@ -1,7 +0,0 @@ -// RUN: buddy-opt -verify-diagnostics %s | buddy-opt | FileCheck %s - -func.func @buddy_resize4d_NEAREST_NEIGHBOUR_INTERPOLATION_f32(%input : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %output : memref) -> () { - // CHECK: dip.resize_4d NEAREST_NEIGHBOUR_INTERPOLATION{{.*}} : memref, f32, f32, memref - dip.resize_4d NEAREST_NEIGHBOUR_INTERPOLATION %input, %horizontal_scaling_factor, %vertical_scaling_factor, %output : memref, f32, f32, memref - return -} diff --git a/tests/Dialect/DIP/resize4D_nchw_lowering.mlir b/tests/Dialect/DIP/resize4D_nchw_lowering.mlir new file mode 100644 index 0000000000..92f6cf3728 --- /dev/null +++ b/tests/Dialect/DIP/resize4D_nchw_lowering.mlir @@ -0,0 +1,7 @@ +// RUN: buddy-opt --lower-dip %s | buddy-opt | FileCheck %s + +func.func @buddy_resize4d_nchw_NEAREST_NEIGHBOUR_INTERPOLATION_f32(%input : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %output : memref) -> () { + // CHECK: memref.store %57, %arg3[%arg4, %arg5, %56, %54] : memref + dip.resize_4d_nchw NEAREST_NEIGHBOUR_INTERPOLATION %input, %horizontal_scaling_factor, %vertical_scaling_factor, %output : memref, f32, f32, memref + return +} diff --git a/tests/Dialect/DIP/resize4D_nchw_roundtrip.mlir b/tests/Dialect/DIP/resize4D_nchw_roundtrip.mlir new file mode 100644 index 0000000000..3850a88dc5 --- /dev/null +++ b/tests/Dialect/DIP/resize4D_nchw_roundtrip.mlir @@ -0,0 +1,25 @@ +// RUN: buddy-opt -verify-diagnostics %s | buddy-opt | FileCheck %s + +func.func @buddy_resize4d_nchw_NEAREST_NEIGHBOUR_INTERPOLATION_f32(%input : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %output : memref) -> () { + // CHECK: dip.resize_4d_nchw NEAREST_NEIGHBOUR_INTERPOLATION{{.*}} : memref, f32, f32, memref + dip.resize_4d_nchw NEAREST_NEIGHBOUR_INTERPOLATION %input, %horizontal_scaling_factor, %vertical_scaling_factor, %output : memref, f32, f32, memref + return +} + +func.func @buddy_resize4d_nchw_NEAREST_NEIGHBOUR_INTERPOLATION_f64(%input : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %output : memref) -> () { + // CHECK: dip.resize_4d_nchw NEAREST_NEIGHBOUR_INTERPOLATION{{.*}} : memref, f32, f32, memref + dip.resize_4d_nchw NEAREST_NEIGHBOUR_INTERPOLATION %input, %horizontal_scaling_factor, %vertical_scaling_factor, %output : memref, f32, f32, memref + return +} + +func.func @buddy_resize4d_nchw_BILINEAR_INTERPOLATION_f32(%input : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %output : memref) -> () { + // CHECK: dip.resize_4d_nchw BILINEAR_INTERPOLATION{{.*}} : memref, f32, f32, memref + dip.resize_4d_nchw BILINEAR_INTERPOLATION %input, %horizontal_scaling_factor, %vertical_scaling_factor, %output : memref, f32, f32, memref + return +} + +func.func @buddy_resize4d_nchw_BILINEAR_INTERPOLATION_f64(%input : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %output : memref) -> () { + // CHECK: dip.resize_4d_nchw BILINEAR_INTERPOLATION{{.*}} : memref, f32, f32, memref + dip.resize_4d_nchw BILINEAR_INTERPOLATION %input, %horizontal_scaling_factor, %vertical_scaling_factor, %output : memref, f32, f32, memref + return +} diff --git a/tests/Dialect/DIP/resize4D_nhwc_lowering.mlir b/tests/Dialect/DIP/resize4D_nhwc_lowering.mlir new file mode 100644 index 0000000000..79291b9340 --- /dev/null +++ b/tests/Dialect/DIP/resize4D_nhwc_lowering.mlir @@ -0,0 +1,7 @@ +// RUN: buddy-opt --lower-dip %s | buddy-opt | FileCheck %s + +func.func @buddy_resize4d_nhwc_NEAREST_NEIGHBOUR_INTERPOLATION_f32(%input : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %output : memref) -> () { + // CHECK: memref.store %57, %arg3[%arg4, %arg5, %56, %54] : memref + dip.resize_4d_nhwc NEAREST_NEIGHBOUR_INTERPOLATION %input, %horizontal_scaling_factor, %vertical_scaling_factor, %output : memref, f32, f32, memref + return +} diff --git a/tests/Dialect/DIP/resize4D_nhwc_roundtrip.mlir b/tests/Dialect/DIP/resize4D_nhwc_roundtrip.mlir new file mode 100644 index 0000000000..46bea2fdbd --- /dev/null +++ b/tests/Dialect/DIP/resize4D_nhwc_roundtrip.mlir @@ -0,0 +1,25 @@ +// RUN: buddy-opt -verify-diagnostics %s | buddy-opt | FileCheck %s + +func.func @buddy_resize4d_nhwc_NEAREST_NEIGHBOUR_INTERPOLATION_f32(%input : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %output : memref) -> () { + // CHECK: dip.resize_4d_nhwc NEAREST_NEIGHBOUR_INTERPOLATION{{.*}} : memref, f32, f32, memref + dip.resize_4d_nhwc NEAREST_NEIGHBOUR_INTERPOLATION %input, %horizontal_scaling_factor, %vertical_scaling_factor, %output : memref, f32, f32, memref + return +} + +func.func @buddy_resize4d_nhwc_NEAREST_NEIGHBOUR_INTERPOLATION_f64(%input : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %output : memref) -> () { + // CHECK: dip.resize_4d_nhwc NEAREST_NEIGHBOUR_INTERPOLATION{{.*}} : memref, f32, f32, memref + dip.resize_4d_nhwc NEAREST_NEIGHBOUR_INTERPOLATION %input, %horizontal_scaling_factor, %vertical_scaling_factor, %output : memref, f32, f32, memref + return +} + +func.func @buddy_resize4d_nhwc_BILINEAR_INTERPOLATION_f32(%input : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %output : memref) -> () { + // CHECK: dip.resize_4d_nhwc BILINEAR_INTERPOLATION{{.*}} : memref, f32, f32, memref + dip.resize_4d_nhwc BILINEAR_INTERPOLATION %input, %horizontal_scaling_factor, %vertical_scaling_factor, %output : memref, f32, f32, memref + return +} + +func.func @buddy_resize4d_nhwc_BILINEAR_INTERPOLATION_f64(%input : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %output : memref) -> () { + // CHECK: dip.resize_4d_nhwc BILINEAR_INTERPOLATION{{.*}} : memref, f32, f32, memref + dip.resize_4d_nhwc BILINEAR_INTERPOLATION %input, %horizontal_scaling_factor, %vertical_scaling_factor, %output : memref, f32, f32, memref + return +} diff --git a/tests/Dialect/DIP/resize4D_roundtrip.mlir b/tests/Dialect/DIP/resize4D_roundtrip.mlir deleted file mode 100644 index 35ccf91f64..0000000000 --- a/tests/Dialect/DIP/resize4D_roundtrip.mlir +++ /dev/null @@ -1,25 +0,0 @@ -// RUN: buddy-opt -verify-diagnostics %s | buddy-opt | FileCheck %s - -func.func @buddy_resize4d_NEAREST_NEIGHBOUR_INTERPOLATION_f32(%input : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %output : memref) -> () { - // CHECK: dip.resize_4d NEAREST_NEIGHBOUR_INTERPOLATION{{.*}} : memref, f32, f32, memref - dip.resize_4d NEAREST_NEIGHBOUR_INTERPOLATION %input, %horizontal_scaling_factor, %vertical_scaling_factor, %output : memref, f32, f32, memref - return -} - -func.func @buddy_resize4d_NEAREST_NEIGHBOUR_INTERPOLATION_f64(%input : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %output : memref) -> () { - // CHECK: dip.resize_4d NEAREST_NEIGHBOUR_INTERPOLATION{{.*}} : memref, f32, f32, memref - dip.resize_4d NEAREST_NEIGHBOUR_INTERPOLATION %input, %horizontal_scaling_factor, %vertical_scaling_factor, %output : memref, f32, f32, memref - return -} - -func.func @buddy_resize4d_BILINEAR_INTERPOLATION_f32(%input : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %output : memref) -> () { - // CHECK: dip.resize_4d BILINEAR_INTERPOLATION{{.*}} : memref, f32, f32, memref - dip.resize_4d BILINEAR_INTERPOLATION %input, %horizontal_scaling_factor, %vertical_scaling_factor, %output : memref, f32, f32, memref - return -} - -func.func @buddy_resize4d_BILINEAR_INTERPOLATION_f64(%input : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %output : memref) -> () { - // CHECK: dip.resize_4d BILINEAR_INTERPOLATION{{.*}} : memref, f32, f32, memref - dip.resize_4d BILINEAR_INTERPOLATION %input, %horizontal_scaling_factor, %vertical_scaling_factor, %output : memref, f32, f32, memref - return -}