From 566afc878e579e4ebb4295da6d08167ed6c2b620 Mon Sep 17 00:00:00 2001 From: askrebko Date: Mon, 1 Aug 2022 12:49:57 +0300 Subject: [PATCH] Add a operand type check for Corr2D op and a test - Make sure that input, kernel, output and constant have the same value and use as inferred type - Adding a negative lit test to check params of the op --- lib/Conversion/LowerDIP/LowerDIPPass.cpp | 13 +++- .../DIP/correlation2D_invalid_type.mlir | 61 +++++++++++++++++++ 2 files changed, 72 insertions(+), 2 deletions(-) create mode 100644 tests/Dialect/DIP/correlation2D_invalid_type.mlir diff --git a/lib/Conversion/LowerDIP/LowerDIPPass.cpp b/lib/Conversion/LowerDIP/LowerDIPPass.cpp index f16629dfa..613595ff5 100644 --- a/lib/Conversion/LowerDIP/LowerDIPPass.cpp +++ b/lib/Conversion/LowerDIP/LowerDIPPass.cpp @@ -73,8 +73,16 @@ class DIPCorr2DOpLowering : public OpRewritePattern { auto boundaryOptionAttr = op.boundary_option(); Value strideVal = rewriter.create(loc, stride); - auto memRefTy = input.getType().cast(); - auto elemTy = memRefTy.getElementType(); + auto inElemTy = input.getType().cast().getElementType(); + auto kElemTy = kernel.getType().cast().getElementType(); + auto outElemTy = output.getType().cast().getElementType(); + auto constElemTy = constantValue.getType(); + if (inElemTy != kElemTy || kElemTy != outElemTy || outElemTy != constElemTy) { + return op->emitOpError() << "input, kernel, output and constant must have the same element type"; + } + // NB: we can infer element type for all operation to be the same as input + // since we verified that the operand types are the same + auto elemTy = inElemTy; IntegerType i1 = IntegerType::get(ctx, 1); @@ -99,6 +107,7 @@ class DIPCorr2DOpLowering : public OpRewritePattern { FloatType f32 = FloatType::getF32(ctx); IntegerType i32 = IntegerType::get(ctx, 32); Value zeroPaddingElem = {}; + // TODO: extend for other types and add a check for supported types if (elemTy.isF32()) { zeroPaddingElem = rewriter.create(loc, (APFloat)(float)0, f32); } else if (elemTy.isInteger(32)) { diff --git a/tests/Dialect/DIP/correlation2D_invalid_type.mlir b/tests/Dialect/DIP/correlation2D_invalid_type.mlir new file mode 100644 index 000000000..db49215bf --- /dev/null +++ b/tests/Dialect/DIP/correlation2D_invalid_type.mlir @@ -0,0 +1,61 @@ +// +// x86 +// +// RUN: not buddy-opt %s -lower-dip="DIP-strip-mining=64" -arith-expand --convert-vector-to-scf --lower-affine --convert-scf-to-cf --convert-vector-to-llvm \ +// RUN: --convert-memref-to-llvm --convert-func-to-llvm --reconcile-unrealized-casts 2>&1 | FileCheck %s + +memref.global "private" @global_input_f32 : memref<3x3xf32> = dense<[[0. , 1. , 2. ], + [10., 11., 12.], + [20., 21., 22.]]> + + +memref.global "private" @global_identity_f32 : memref<3x3xf32> = dense<[[0., 0., 0.], + [0., 1., 0.], + [0., 0., 0.]]> + +memref.global "private" @global_output_f32 : memref<3x3xf32> = dense<[[0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.]]> + +memref.global "private" @global_input_i32 : memref<3x3xi32> = dense<[[0 , 1 , 2 ], + [10, 11, 12], + [20, 21, 22]]> + + +memref.global "private" @global_identity_i32 : memref<3x3xi32> = dense<[[0, 0, 0], + [0, 1, 0], + [0, 0, 0]]> + +memref.global "private" @global_output_i32 : memref<3x3xi32> = dense<[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]> + +func.func @main() -> i32 { + %input_f32 = memref.get_global @global_input_f32 : memref<3x3xf32> + %identity_f32 = memref.get_global @global_identity_f32 : memref<3x3xf32> + %output_f32 = memref.get_global @global_output_f32 : memref<3x3xf32> + %c_f32 = arith.constant 0. : f32 + + %input_i32 = memref.get_global @global_input_i32 : memref<3x3xi32> + %identity_i32 = memref.get_global @global_identity_i32 : memref<3x3xi32> + %output_i32 = memref.get_global @global_output_i32 : memref<3x3xi32> + %c_i32 = arith.constant 0 : i32 + + %x = arith.constant 1 : index + %y = arith.constant 1 : index + + dip.corr_2d %input_i32, %identity_f32, %output_f32, %x, %x, %c_f32 : memref<3x3xi32>, memref<3x3xf32>, memref<3x3xf32>, index, index, f32 + // CHECK: 'dip.corr_2d' op input, kernel, output and constant must have the same element type + + dip.corr_2d %input_f32, %identity_i32, %output_f32, %x, %x, %c_f32 : memref<3x3xf32>, memref<3x3xi32>, memref<3x3xf32>, index, index, f32 + // CHECK: 'dip.corr_2d' op input, kernel, output and constant must have the same element type + + dip.corr_2d %input_f32, %identity_f32, %output_i32, %x, %x, %c_f32 : memref<3x3xf32>, memref<3x3xf32>, memref<3x3xi32>, index, index, f32 + // CHECK: 'dip.corr_2d' op input, kernel, output and constant must have the same element type + + dip.corr_2d %input_f32, %identity_f32, %output_f32, %x, %x, %c_i32 : memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>, index, index, i32 + // CHECK: 'dip.corr_2d' op input, kernel, output and constant must have the same element type + + %ret = arith.constant 0 : i32 + return %ret : i32 +}