diff --git a/frontends/comet_dsl/comet.cpp b/frontends/comet_dsl/comet.cpp
index ad8543d8..2fbdad37 100644
--- a/frontends/comet_dsl/comet.cpp
+++ b/frontends/comet_dsl/comet.cpp
@@ -39,6 +39,7 @@
 #include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Func/Transforms/Passes.h"
 
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
@@ -383,6 +384,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
 
     // Finally lowering index tree to SCF dialect
     optPM.addPass(mlir::comet::createLowerIndexTreeToSCFPass());
+    pm.addPass(mlir::func::createFuncBufferizePass()); // Needed for func
 
     // Dump index tree dialect.
     if (emitLoops)
diff --git a/frontends/comet_dsl/include/Parser.h b/frontends/comet_dsl/include/Parser.h
index 9d27a5cd..c64e508a 100644
--- a/frontends/comet_dsl/include/Parser.h
+++ b/frontends/comet_dsl/include/Parser.h
@@ -493,6 +493,13 @@ namespace tensorAlgebra
         } // CallExprAST is generated for random()
       }
+      else
+      {
+        if (args.size() == 0)
+        {
+          args.push_back(nullptr);
+        }
+      }
       comet_debug() << "generate CallExprAST node\n ";
       return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args[0]));
     }
diff --git a/frontends/comet_dsl/mlir/MLIRGen.cpp b/frontends/comet_dsl/mlir/MLIRGen.cpp
index 8eb1d61a..d50030f9 100644
--- a/frontends/comet_dsl/mlir/MLIRGen.cpp
+++ b/frontends/comet_dsl/mlir/MLIRGen.cpp
@@ -34,6 +34,7 @@
 #include "comet/Dialect/Utils/Utils.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -1336,6 +1337,32 @@ namespace
         sumVal = builder.create<ReduceOp>(location, builder.getF64Type(), tensorValue);
       }
     }
+    else
+    {
+      auto *expr = call.getArgs();
+      if (expr)
+      {
+        assert(false && "functions with arguments are currently not supported!");
+      }
+      mlir::Value tensorValue;
+      tensorValue = mlir::Value();
+      ArrayRef<mlir::Value> args{};
+      if (tensorValue)
+        args = ArrayRef<mlir::Value>(tensorValue);
+
+      auto c = functionMap.lookup(callee);
+      if (c.getFunctionType().getResults().size() > 0) // Function that returns a value
+      {
+        auto res = builder.create<GenericCallOp>(location, c.getFunctionType().getResults()[0], callee, args);
+        sumVal = res.getResults()[0];
+      }
+      else // Void function
+      {
+        builder.create<GenericCallOp>(location, callee, args);
+        sumVal = mlir::Value();
+      }
+    }
+    // comet_debug() << "Called: " << callee << "\n";
     // Otherwise this is a call to a user-defined function. Calls to user-defined
     // functions are mapped to a custom call that takes the callee name as an
@@ -2298,8 +2325,13 @@ namespace
 
     // Generic expression dispatch codegen.
     comet_debug() << " expr->getKind(): " << expr->getKind() << "\n";
-    if (!mlirGen(*expr))
-      return mlir::failure();
+
+    // If calling a void function this will return null, thus we cannot count on this for
+    // error checking
+    mlirGen(*expr);
+    // return mlir::failure();
+    // if (!mlirGen(*expr))
+    //   return mlir::failure();
   }
   return mlir::success();
 }
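Note on the MLIRGen change above: once void functions are allowed, a null mlir::Value is a legal outcome of generating a call expression, so the expression-dispatch loop can no longer treat a null return as its error signal. A minimal standalone sketch of that convention using upstream MLIR's func dialect (emitCallTo is a hypothetical helper, not COMET code):

    // Minimal sketch, not part of the patch: illustrates the
    // "null Value means void call" convention.
    #include "mlir/Dialect/Func/IR/FuncOps.h"
    #include "mlir/IR/Builders.h"

    static mlir::Value emitCallTo(mlir::OpBuilder &builder, mlir::Location loc,
                                  mlir::func::FuncOp callee, mlir::ValueRange args) {
      // func::CallOp infers its result types from the callee's signature,
      // so a void callee yields a zero-result call.
      auto call = builder.create<mlir::func::CallOp>(loc, callee, args);
      // A null mlir::Value now means "void call", not "codegen failure";
      // callers must report errors through another channel.
      return call->getNumResults() ? call->getResult(0) : mlir::Value();
    }
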
diff --git a/include/comet/Dialect/TensorAlgebra/IR/TAOps.td b/include/comet/Dialect/TensorAlgebra/IR/TAOps.td
index b7239f25..f930b9f8 100644
--- a/include/comet/Dialect/TensorAlgebra/IR/TAOps.td
+++ b/include/comet/Dialect/TensorAlgebra/IR/TAOps.td
@@ -860,7 +860,7 @@ def GenericCallOp : TA_Op<"generic_call",
   // The generic call operation returns a single value of TensorType or
   // StructType.
-  let results = (outs TA_AnyTensor);
+  let results = (outs Optional<TA_AnyTensor>);
 
   // Specialize assembly printing and parsing using a declarative format.
   let assemblyFormat = [{
diff --git a/lib/Conversion/TensorAlgebraToSCF/LateLowering.cpp b/lib/Conversion/TensorAlgebraToSCF/LateLowering.cpp
index fddf5050..37bc3ead 100644
--- a/lib/Conversion/TensorAlgebraToSCF/LateLowering.cpp
+++ b/lib/Conversion/TensorAlgebraToSCF/LateLowering.cpp
@@ -260,6 +260,40 @@ namespace
     }
   };
 
+  class ReturnOpLowering : public ConversionPattern
+  {
+  public:
+    explicit ReturnOpLowering(MLIRContext *ctx)
+        : ConversionPattern(tensorAlgebra::PrintElapsedTimeOp::getOperationName(), 1, ctx) {}
+
+    LogicalResult
+    matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                    ConversionPatternRewriter &rewriter) const override
+    {
+      auto ctx = rewriter.getContext();
+      auto module = op->getParentOfType<ModuleOp>();
+
+      auto start = operands[0];
+      auto end = operands[1];
+      std::string printElapsedTimeStr = "printElapsedTime";
+      auto f64Type = rewriter.getF64Type();
+
+      if (!hasFuncDeclaration(module, printElapsedTimeStr))
+      {
+        auto printElapsedTimeFunc = FunctionType::get(ctx, {f64Type, f64Type}, {});
+        // func @printElapsedTime(f64, f64) -> ()
+        func::FuncOp func1 = func::FuncOp::create(op->getLoc(), printElapsedTimeStr,
+                                                  printElapsedTimeFunc, ArrayRef<NamedAttribute>{});
+        func1.setPrivate();
+        module.push_back(func1);
+      }
+
+      rewriter.replaceOpWithNewOp<func::CallOp>(op, printElapsedTimeStr, SmallVector<Type, 2>{}, ValueRange{start, end});
+
+      return success();
+    }
+  };
+
 } // end anonymous namespace.
 
 /// This is a partial lowering to linear algebra of the tensor algebra operations that are
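The LateLowering pattern above declares printElapsedTime on first use and then replaces the op with a func.call. A minimal sketch of that get-or-insert idiom with upstream MLIR symbol-table APIs (getOrInsertDecl is a hypothetical helper; hasFuncDeclaration is COMET's own check and may differ):

    // Minimal sketch, assuming upstream MLIR APIs; not COMET code.
    #include "mlir/Dialect/Func/IR/FuncOps.h"
    #include "mlir/IR/BuiltinOps.h"

    static mlir::func::FuncOp getOrInsertDecl(mlir::OpBuilder &builder,
                                              mlir::ModuleOp module,
                                              llvm::StringRef name,
                                              mlir::FunctionType type) {
      // Reuse an existing declaration if the symbol is already in the module.
      if (auto fn = module.lookupSymbol<mlir::func::FuncOp>(name))
        return fn;
      mlir::OpBuilder::InsertionGuard guard(builder);
      builder.setInsertionPointToStart(module.getBody());
      // Private visibility marks this as an external declaration to link against.
      auto fn = builder.create<mlir::func::FuncOp>(module.getLoc(), name, type);
      fn.setPrivate();
      return fn;
    }
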
diff --git a/lib/Conversion/TensorAlgebraToSCF/LowerFunc.cpp b/lib/Conversion/TensorAlgebraToSCF/LowerFunc.cpp
index 7e348fca..30cdfde5 100644
--- a/lib/Conversion/TensorAlgebraToSCF/LowerFunc.cpp
+++ b/lib/Conversion/TensorAlgebraToSCF/LowerFunc.cpp
@@ -6,11 +6,30 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/Sequence.h"
 
-using namespace mlir;
+// *********** For debug purpose *********//
+// #ifndef DEBUG_MODE_LOWER_FUNC
+// #define DEBUG_MODE_LOWER_FUNC
+// #endif
+
+#ifdef DEBUG_MODE_LOWER_FUNC
+#define comet_debug() llvm::errs() << __FILE__ << " " << __LINE__ << " "
+#define comet_pdump(n)                                \
+  llvm::errs() << __FILE__ << " " << __LINE__ << " "; \
+  n->dump()
+#define comet_vdump(n)                                \
+  llvm::errs() << __FILE__ << " " << __LINE__ << " "; \
+  n.dump()
+#else
+#define comet_debug() llvm::nulls()
+#define comet_pdump(n)
+#define comet_vdump(n)
+#endif
+// *********** For debug purpose *********//
 
 //===----------------------------------------------------------------------===//
 // tensorAlgebra::FuncOp to func::FuncOp RewritePatterns
@@ -25,20 +44,20 @@ namespace {
 struct FuncOpLowering : public OpConversionPattern<tensorAlgebra::FuncOp> {
   using OpConversionPattern<tensorAlgebra::FuncOp>::OpConversionPattern;
-
   LogicalResult matchAndRewrite(tensorAlgebra::FuncOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const final {
     // We only lower the main function as we expect that all other functions
     // have been inlined.
-    if (op.getName() != "main")
-      return failure();
-
-    // Verify that the given main has no inputs and results.
-    if (op.getNumArguments() || op.getFunctionType().getNumResults()) {
-      return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
-        diag << "expected 'main' to have 0 inputs and 0 results";
-      });
+    if (op.getName() == "main")
+    {
+      // return failure();
+      // Verify that the given main has no inputs and results.
+      if (op.getNumArguments() || op.getFunctionType().getNumResults()) {
+        return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
+          diag << "expected 'main' to have 0 inputs and 0 results";
+        });
+      }
     }
 
     // Create a new non-tensorAlgebra function, with the same region.
@@ -78,15 +97,48 @@ struct ReturnOpLowering : public OpRewritePattern<tensorAlgebra::TAReturnOp> {
                                 PatternRewriter &rewriter) const final {
     // During this lowering, we expect that all function calls have been
     // inlined.
-    if (op.hasOperand())
-      return failure();
+    // if (op.hasOperand())
+    //   return failure();
+
+    if (op.hasOperand())
+    {
+      rewriter.replaceOpWithNewOp<func::ReturnOp>(op, op.getOperands());
+    }
+    else
+    {
+      rewriter.replaceOpWithNewOp<func::ReturnOp>(op);
+    }
+
+    return success();
+  }
+};
+
+struct GenericCallOpLowering : public OpRewritePattern<mlir::tensorAlgebra::GenericCallOp> {
+  using OpRewritePattern<mlir::tensorAlgebra::GenericCallOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::tensorAlgebra::GenericCallOp op,
+                                PatternRewriter &rewriter) const final {
+
+    // During this lowering, we expect that all function calls have been
+    // inlined.
+    // if (op.hasOperand())
+    //   return failure();
 
     // We lower "toy.return" directly to "func.return".
-    rewriter.replaceOpWithNewOp<func::ReturnOp>(op);
+    if (op.getResults().size() > 0)
+    {
+      auto res = rewriter.replaceOpWithNewOp<func::CallOp>(op, op->getAttrOfType<SymbolRefAttr>("callee"), op.getType(0), op.getOperands());
+    }
+    else
+    {
+      auto res = rewriter.replaceOpWithNewOp<func::CallOp>(op, op->getAttrOfType<SymbolRefAttr>("callee"), mlir::TypeRange(), op.getOperands());
+    }
+
     return success();
   }
 };
+
 void FuncOpLoweringPass::runOnOperation() {
   // The first thing to define is the conversion target. This will define the
   // final target for this lowering.
@@ -110,7 +162,7 @@ void FuncOpLoweringPass::runOnOperation() {
   // Now that the conversion target has been defined, we just need to provide
   // the set of patterns that will lower the Toy operations.
   RewritePatternSet patterns(&getContext());
-  patterns.add<FuncOpLowering, ReturnOpLowering>(
+  patterns.add<FuncOpLowering, ReturnOpLowering, GenericCallOpLowering>(
       &getContext());
 
   // With the target and rewrite patterns defined, we can now attempt the
diff --git a/lib/Dialect/TensorAlgebra/IR/TADialect.cpp b/lib/Dialect/TensorAlgebra/IR/TADialect.cpp
index 29258124..401880a0 100644
--- a/lib/Dialect/TensorAlgebra/IR/TADialect.cpp
+++ b/lib/Dialect/TensorAlgebra/IR/TADialect.cpp
@@ -235,7 +235,7 @@ void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
                           StringRef callee, ArrayRef<mlir::Value> arguments)
 {
   // Generic call always returns an unranked Tensor initially.
-  state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
+  // state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
   state.addOperands(arguments);
   state.addAttribute("callee",
                      mlir::SymbolRefAttr::get(builder.getContext(), callee));
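With GenericCallOp's result now Optional and the default result type removed from its build method, void callees carry no result. One conceivable follow-up, sketched here as an assumption rather than code from this patch, is a build overload that threads the callee's result types through explicitly:

    // Sketch of a possible overload, not part of this patch: pass result
    // types explicitly so void callees produce a zero-result ta.generic_call.
    void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
                              mlir::TypeRange results, StringRef callee,
                              ArrayRef<mlir::Value> arguments)
    {
      state.addTypes(results); // empty TypeRange for a void callee
      state.addOperands(arguments);
      state.addAttribute("callee",
                         mlir::SymbolRefAttr::get(builder.getContext(), callee));
    }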