Skip to content

Commit

Permalink
Add a operand type check for Corr2D op and a test
Browse files Browse the repository at this point in the history
  - Make sure that input, kernel, output and constant have the same
    value and use as inferred type
  - Adding a negative lit test to check params of the op
  • Loading branch information
ArtemSkrebkov committed Aug 1, 2022
1 parent 815bfb9 commit 566afc8
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 2 deletions.
13 changes: 11 additions & 2 deletions lib/Conversion/LowerDIP/LowerDIPPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,16 @@ class DIPCorr2DOpLowering : public OpRewritePattern<dip::Corr2DOp> {
auto boundaryOptionAttr = op.boundary_option();
Value strideVal = rewriter.create<ConstantIndexOp>(loc, stride);

auto memRefTy = input.getType().cast<MemRefType>();
auto elemTy = memRefTy.getElementType();
auto inElemTy = input.getType().cast<MemRefType>().getElementType();
auto kElemTy = kernel.getType().cast<MemRefType>().getElementType();
auto outElemTy = output.getType().cast<MemRefType>().getElementType();
auto constElemTy = constantValue.getType();
if (inElemTy != kElemTy || kElemTy != outElemTy || outElemTy != constElemTy) {
return op->emitOpError() << "input, kernel, output and constant must have the same element type";
}
// NB: we can infer element type for all operation to be the same as input
// since we verified that the operand types are the same
auto elemTy = inElemTy;

IntegerType i1 = IntegerType::get(ctx, 1);

Expand All @@ -99,6 +107,7 @@ class DIPCorr2DOpLowering : public OpRewritePattern<dip::Corr2DOp> {
FloatType f32 = FloatType::getF32(ctx);
IntegerType i32 = IntegerType::get(ctx, 32);
Value zeroPaddingElem = {};
// TODO: extend for other types and add a check for supported types
if (elemTy.isF32()) {
zeroPaddingElem = rewriter.create<ConstantFloatOp>(loc, (APFloat)(float)0, f32);
} else if (elemTy.isInteger(32)) {
Expand Down
61 changes: 61 additions & 0 deletions tests/Dialect/DIP/correlation2D_invalid_type.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
//
// x86
//
// RUN: not buddy-opt %s -lower-dip="DIP-strip-mining=64" -arith-expand --convert-vector-to-scf --lower-affine --convert-scf-to-cf --convert-vector-to-llvm \
// RUN: --convert-memref-to-llvm --convert-func-to-llvm --reconcile-unrealized-casts 2>&1 | FileCheck %s

memref.global "private" @global_input_f32 : memref<3x3xf32> = dense<[[0. , 1. , 2. ],
[10., 11., 12.],
[20., 21., 22.]]>


memref.global "private" @global_identity_f32 : memref<3x3xf32> = dense<[[0., 0., 0.],
[0., 1., 0.],
[0., 0., 0.]]>

memref.global "private" @global_output_f32 : memref<3x3xf32> = dense<[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]>

memref.global "private" @global_input_i32 : memref<3x3xi32> = dense<[[0 , 1 , 2 ],
[10, 11, 12],
[20, 21, 22]]>


memref.global "private" @global_identity_i32 : memref<3x3xi32> = dense<[[0, 0, 0],
[0, 1, 0],
[0, 0, 0]]>

memref.global "private" @global_output_i32 : memref<3x3xi32> = dense<[[0, 0, 0],
[0, 0, 0],
[0, 0, 0]]>

func.func @main() -> i32 {
%input_f32 = memref.get_global @global_input_f32 : memref<3x3xf32>
%identity_f32 = memref.get_global @global_identity_f32 : memref<3x3xf32>
%output_f32 = memref.get_global @global_output_f32 : memref<3x3xf32>
%c_f32 = arith.constant 0. : f32

%input_i32 = memref.get_global @global_input_i32 : memref<3x3xi32>
%identity_i32 = memref.get_global @global_identity_i32 : memref<3x3xi32>
%output_i32 = memref.get_global @global_output_i32 : memref<3x3xi32>
%c_i32 = arith.constant 0 : i32

%x = arith.constant 1 : index
%y = arith.constant 1 : index

dip.corr_2d <CONSTANT_PADDING> %input_i32, %identity_f32, %output_f32, %x, %x, %c_f32 : memref<3x3xi32>, memref<3x3xf32>, memref<3x3xf32>, index, index, f32
// CHECK: 'dip.corr_2d' op input, kernel, output and constant must have the same element type

dip.corr_2d <CONSTANT_PADDING> %input_f32, %identity_i32, %output_f32, %x, %x, %c_f32 : memref<3x3xf32>, memref<3x3xi32>, memref<3x3xf32>, index, index, f32
// CHECK: 'dip.corr_2d' op input, kernel, output and constant must have the same element type

dip.corr_2d <CONSTANT_PADDING> %input_f32, %identity_f32, %output_i32, %x, %x, %c_f32 : memref<3x3xf32>, memref<3x3xf32>, memref<3x3xi32>, index, index, f32
// CHECK: 'dip.corr_2d' op input, kernel, output and constant must have the same element type

dip.corr_2d <CONSTANT_PADDING> %input_f32, %identity_f32, %output_f32, %x, %x, %c_i32 : memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>, index, index, i32
// CHECK: 'dip.corr_2d' op input, kernel, output and constant must have the same element type

%ret = arith.constant 0 : i32
return %ret : i32
}

0 comments on commit 566afc8

Please sign in to comment.