diff --git a/lib/Conversion/LowerDIP/LowerDIPPass.cpp b/lib/Conversion/LowerDIP/LowerDIPPass.cpp index f16629dfa..613595ff5 100644 --- a/lib/Conversion/LowerDIP/LowerDIPPass.cpp +++ b/lib/Conversion/LowerDIP/LowerDIPPass.cpp @@ -73,8 +73,16 @@ class DIPCorr2DOpLowering : public OpRewritePattern { auto boundaryOptionAttr = op.boundary_option(); Value strideVal = rewriter.create(loc, stride); - auto memRefTy = input.getType().cast(); - auto elemTy = memRefTy.getElementType(); + auto inElemTy = input.getType().cast().getElementType(); + auto kElemTy = kernel.getType().cast().getElementType(); + auto outElemTy = output.getType().cast().getElementType(); + auto constElemTy = constantValue.getType(); + if (inElemTy != kElemTy || kElemTy != outElemTy || outElemTy != constElemTy) { + return op->emitOpError() << "input, kernel, output and constant must have the same element type"; + } + // NB: we can infer element type for all operation to be the same as input + // since we verified that the operand types are the same + auto elemTy = inElemTy; IntegerType i1 = IntegerType::get(ctx, 1); @@ -99,6 +107,7 @@ class DIPCorr2DOpLowering : public OpRewritePattern { FloatType f32 = FloatType::getF32(ctx); IntegerType i32 = IntegerType::get(ctx, 32); Value zeroPaddingElem = {}; + // TODO: extend for other types and add a check for supported types if (elemTy.isF32()) { zeroPaddingElem = rewriter.create(loc, (APFloat)(float)0, f32); } else if (elemTy.isInteger(32)) {