diff --git a/frontends/comet_dsl/comet.cpp b/frontends/comet_dsl/comet.cpp index 2fbdad37..f983fafb 100644 --- a/frontends/comet_dsl/comet.cpp +++ b/frontends/comet_dsl/comet.cpp @@ -40,6 +40,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/Func/Transforms/Passes.h" +#include "mlir/Dialect/Tensor/Transforms/Passes.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" @@ -384,6 +385,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context, // Finally lowering index tree to SCF dialect optPM.addPass(mlir::comet::createLowerIndexTreeToSCFPass()); + optPM.addPass(mlir::createTensorBufferizePass()); pm.addPass(mlir::func::createFuncBufferizePass()); // Needed for func // Dump index tree dialect. diff --git a/frontends/comet_dsl/include/AST.h b/frontends/comet_dsl/include/AST.h index fc7faf85..aceb0d02 100644 --- a/frontends/comet_dsl/include/AST.h +++ b/frontends/comet_dsl/include/AST.h @@ -89,7 +89,8 @@ namespace tensorAlgebra Expr_GetTime, Expr_ForLoop, Expr_ForEnd, - Expr_Mask + Expr_Mask, + Expr_FuncArg, }; ExprAST(ExprASTKind kind, Location location) @@ -156,6 +157,22 @@ namespace tensorAlgebra static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; } }; + + class FuncArgAST : public ExprAST + { + Location loc; + std::string name; + VarType type; + + public: + FuncArgAST(Location loc, llvm::StringRef name, VarType type) + : ExprAST(Expr_FuncArg, loc), loc(std::move(loc)), name(name), type(std::move(type)){} + + llvm::StringRef getName() {return name;} + const VarType &getType() {return type;} + + static bool classof(const ExprAST *C) { return C->getKind() == Expr_FuncArg; } + }; // /// Expression class for defining a variable. // class VarDeclExprAST : public ExprAST // { @@ -474,15 +491,16 @@ namespace tensorAlgebra class CallExprAST : public ExprAST { std::string Callee; - std::unique_ptr Arg; + std::vector> Args; public: CallExprAST(Location loc, const std::string &Callee, - std::unique_ptr Arg) - : ExprAST(Expr_Call, loc), Callee(Callee), Arg(std::move(Arg)) {} + std::vector> Args) + : ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {} llvm::StringRef getCallee() { return Callee; } - ExprAST *getArgs() { return Arg.get(); } + ExprAST* getArg(int index) { return Args[index].get(); } + size_t getNumArgs() {return Args.size();} /// LLVM style RTTI static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; } }; @@ -594,17 +612,17 @@ namespace tensorAlgebra { Location location; std::string name; - std::vector> args; + std::vector> args; public: PrototypeAST(Location location, const std::string &name, - std::vector> args) + std::vector> args) : location(location), name(name), args(std::move(args)) {} const Location &loc() { return location; } const std::string &getName() const { return name; } - //TODO(gkestor): check VariableExprAST - const std::vector> &getArgs() + //TODO(gkestor): check FuncArgAST + const std::vector> &getArgs() { return args; } diff --git a/frontends/comet_dsl/include/Parser.h b/frontends/comet_dsl/include/Parser.h index c64e508a..429b04be 100644 --- a/frontends/comet_dsl/include/Parser.h +++ b/frontends/comet_dsl/include/Parser.h @@ -495,13 +495,13 @@ namespace tensorAlgebra } else { - if (args.size() == 0) - { - args.push_back(nullptr); - } + // if (args.size() == 0) + // { + // args.push_back(nullptr); + // } } comet_debug() << "generate CallExprAST node\n "; - return std::make_unique(std::move(loc), name, std::move(args[0])); + return std::make_unique(std::move(loc), name, std::move(args)); } /// primary @@ -603,6 +603,7 @@ namespace tensorAlgebra while (true) { int tokPrec = getTokPrecedence(); + comet_debug() << lexer.getCurToken() << "\n"; comet_debug() << " tokPrec: " << tokPrec << ", exprPrec: " << exprPrec << "\n"; // If this is a binop that binds at least as tightly as the current binop, // consume it, otherwise we are done. @@ -666,6 +667,11 @@ namespace tensorAlgebra auto lhs = parsePrimary(); if (!lhs) return nullptr; + // if(lexer.getCurToken() == ';') + // { + // comet_debug() << "return single operand\n"; + // return lhs; + // } comet_debug() << "finished lhs parse\n"; comet_debug() << " call parseBinOpRHS\n"; @@ -724,6 +730,7 @@ namespace tensorAlgebra if (!type) type = std::make_unique(); lexer.consume(Token('=')); + comet_debug() << "Parse declaration for " << id << "\n"; auto expr = parseExpression(); return std::make_unique(std::move(loc), std::move(id), std::move(*type), std::move(expr)); @@ -1341,9 +1348,9 @@ namespace tensorAlgebra { comet_debug() << __FILE__ << __LINE__ << " TensorOpExprAST rhs is Expr_Call\n"; - //CallExprAST *call = llvm::cast(RHS.get()); - //llvm::StringRef callee = call->getCallee(); - //comet_debug() << __FILE__ << __LINE__ << " callee: " << callee << "\n"; + CallExprAST *call = llvm::cast(RHS.get()); + llvm::StringRef callee = call->getCallee(); + comet_debug() << __FILE__ << __LINE__ << " callee: " << callee << "\n"; } else if (RHS.get()->getKind() == tensorAlgebra::ExprAST::Expr_Transpose) { @@ -1468,6 +1475,7 @@ namespace tensorAlgebra else if (lexer.getCurToken() == tok_identifier && lexer.lookAhead() == ' ') { + comet_debug() << "ParseVarExpression lhs parse\n"; auto varOp = ParseVarExpression(); if (!varOp) return nullptr; @@ -1517,6 +1525,76 @@ namespace tensorAlgebra return exprList; } + std::unique_ptr parseFuncArg() + { + auto tok = lexer.getCurToken(); + if(tok == tok_tensor) + { + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat Tensor + + std::unique_ptr type; // Type is optional, it can be inferred + if (lexer.getCurToken() == '<') + { + lexer.consume(Token('<')); // eat < + type = std::make_unique(); + if (lexer.getCurToken() == tok_double) + { + type->elt_ty = VarType::TY_DOUBLE; + } + else if (lexer.getCurToken() == tok_float) + { + type->elt_ty = VarType::TY_FLOAT; + } + else if (lexer.getCurToken() == tok_int) + { + type->elt_ty = VarType::TY_INT; + } + lexer.getNextToken(); // eat el_type + if (lexer.getCurToken() != '>') + return parseError(">", "to end type"); + lexer.getNextToken(); // eat > + } + + if (lexer.getCurToken() != tok_identifier) + return parseError("identifier", + "after 'Tensor' declaration"); + std::string id(lexer.getId()); + lexer.getNextToken(); // eat id + + return std::make_unique(std::move(loc), id, std::move(*type)); + } + else if(tok == tok_int || tok == tok_float || tok == tok_double ) + { + auto loc = lexer.getLastLocation(); + std::unique_ptr type; + type = std::make_unique(); + if (lexer.getCurToken() == tok_double) + { + type->elt_ty = VarType::TY_DOUBLE; + } + else if (lexer.getCurToken() == tok_float) + { + type->elt_ty = VarType::TY_FLOAT; + } + else if (lexer.getCurToken() == tok_int) + { + type->elt_ty = VarType::TY_INT; + } + lexer.getNextToken(); // eat type + + std::string id(lexer.getId()); + lexer.getNextToken(); // eat id + + return std::make_unique(std::move(loc), id, std::move(*type)); + } + else + { + return parseError("Unexpected token in function arguments"); + } + } + + /// prototype ::= def id '(' decl_list ')' /// decl_list ::= identifier | identifier, decl_list std::unique_ptr parsePrototype() @@ -1533,22 +1611,23 @@ namespace tensorAlgebra return parseError("(", "in prototype"); lexer.consume(Token('(')); - std::vector> args; + std::vector> args; if (lexer.getCurToken() != ')') { do { - std::string name(lexer.getId()); - auto loc = lexer.getLastLocation(); - lexer.consume(tok_identifier); - auto decl = std::make_unique(std::move(loc), name); - args.push_back(std::move(decl)); + auto arg = parseFuncArg(); + // std::string name(lexer.getId()); + // auto loc = lexer.getLastLocation(); + // lexer.consume(tok_identifier); + // auto decl = std::make_unique(std::move(loc), name); + args.push_back(std::move(arg)); if (lexer.getCurToken() != ',') break; lexer.consume(Token(',')); - if (lexer.getCurToken() != tok_identifier) - return parseError( - "identifier", "after ',' in function parameter list"); + // if (lexer.getCurToken() != tok_identifier) + // return parseError( + // "identifier", "after ',' in function parameter list"); } while (true); } if (lexer.getCurToken() != ')') diff --git a/frontends/comet_dsl/mlir/MLIRGen.cpp b/frontends/comet_dsl/mlir/MLIRGen.cpp index d50030f9..5374fb79 100644 --- a/frontends/comet_dsl/mlir/MLIRGen.cpp +++ b/frontends/comet_dsl/mlir/MLIRGen.cpp @@ -34,6 +34,7 @@ #include "comet/Dialect/Utils/Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Verifier.h" #include "mlir/IR/Attributes.h" @@ -370,10 +371,12 @@ namespace // return nullptr; // } + // function. // // Declare all the function arguments in the symbol table. for (const auto nameValue : llvm::zip(protoArgs, entryBlock.getArguments())) { + comet_debug() << "Proto Args "<< std::get<1>(nameValue) << "\n"; if (failed(declare(std::get<0>(nameValue)->getName(), std::get<1>(nameValue)))) return nullptr; @@ -1245,6 +1248,8 @@ namespace { if (!(expr = mlirGen(**ret.getExpr()))) return mlir::failure(); + + expr = builder.create(location, mlir::UnrankedTensorType::get(builder.getF64Type()), expr); } // Otherwise, this return operation has zero operands. @@ -1306,13 +1311,15 @@ namespace /// builtin. Other identifiers are assumed to be user-defined functions. mlir::Value mlirGen(CallExprAST &call) { + comet_debug()<< "CallExprAST\n"; + llvm::StringRef callee = call.getCallee(); auto location = loc(call.loc()); mlir::Value sumVal; if (callee == "SUM") { - auto *expr = call.getArgs(); + auto *expr = call.getArg(0); // Check if it SUM(A[i,j]) or SUM(A[i,j] * B[j,k]) // Case 1: SUM(A[i,j]) if (llvm::isa(expr)) @@ -1339,16 +1346,25 @@ namespace } else { - auto *expr = call.getArgs(); - if(expr) + std::vector expr_args; + comet_debug()<< "Generic Call\n"; + comet_debug() <<"Num args: " << call.getNumArgs() << "\n"; + // auto exprs = call.getArgs(); + if(call.getNumArgs() > 0 ) { - assert(false && "functions with argument are currently not supported!"); + for(size_t i = 0; i < call.getNumArgs(); i++) + { + auto res = builder.create(location, mlir::UnrankedTensorType::get(builder.getF64Type()), mlirGen(*call.getArg(i))); + expr_args.push_back(res); + } + comet_debug() <<"Num args: " << call.getNumArgs() << "\n"; + // assert(false && "functions with argument are currently not supported!"); } mlir::Value tensorValue; tensorValue = mlir::Value(); - ArrayRef args{}; - if(tensorValue) - args = ArrayRef (tensorValue); + ArrayRef args(expr_args); + // if(tensorValue) + // args = ArrayRef (tensorValue); auto c = functionMap.lookup(callee); if(c.getFunctionType().getResults().size() > 0) // Function that returns a value @@ -2291,6 +2307,13 @@ namespace if (mlir::failed(mlirGenTensorFillRandom(loc(tensor_op->loc()), tensor_name))) return mlir::success(); } + else + { + LabeledTensorExprAST *lhsLabeledTensorExprAST = llvm::cast(tensor_op->getLHS()); + auto call_res = mlirGen(*call); + auto lhs_tensor = symbolTable.lookup(lhsLabeledTensorExprAST->getTensorName()); + builder.create(loc(tensor_op->loc()), call_res, lhs_tensor); + } // TODO: put check here, if the user mis-spells something... continue; @@ -2307,6 +2330,17 @@ namespace mlirGen(*transpose, *lhsLabeledTensorExprAST); continue; } + else if(tensor_op->getRHS()->getKind() == ExprAST::ExprASTKind::Expr_Call) + { + comet_debug() << __LINE__ << " in TensorOpExprAST, rhs is Expr_Call\n"; + + LabeledTensorExprAST *lhsLabeledTensorExprAST = llvm::cast(tensor_op->getLHS()); + CallExprAST * call = llvm::cast(tensor_op->getRHS()); + auto call_res = mlirGen(*call); + auto lhs_tensor = symbolTable.lookup(lhsLabeledTensorExprAST->getTensorName()); + builder.create(loc(tensor_op->loc()), call_res, lhs_tensor); + continue; + } else if ((tensor_op->getRHS()->getKind() == ExprAST::ExprASTKind::Expr_LabeledTensor && tensor_op->getLHS()->getKind() == ExprAST::ExprASTKind::Expr_LabeledTensor)) // TODO: we should not reach this case { diff --git a/frontends/comet_dsl/parser/AST.cpp b/frontends/comet_dsl/parser/AST.cpp index 1d12d250..ee131006 100644 --- a/frontends/comet_dsl/parser/AST.cpp +++ b/frontends/comet_dsl/parser/AST.cpp @@ -249,7 +249,7 @@ void ASTDumper::dump(CallExprAST *Node) { INDENT(); llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n"; - dump(Node->getArgs()); + // dump(Node->getArgs()); indent(); llvm::errs() << "]\n"; } diff --git a/lib/Conversion/TensorAlgebraToSCF/LowerFunc.cpp b/lib/Conversion/TensorAlgebraToSCF/LowerFunc.cpp index 30cdfde5..0acaca17 100644 --- a/lib/Conversion/TensorAlgebraToSCF/LowerFunc.cpp +++ b/lib/Conversion/TensorAlgebraToSCF/LowerFunc.cpp @@ -10,6 +10,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/Sequence.h" +using namespace mlir; // *********** For debug purpose *********// // #ifndef DEBUG_MODE_LOWER_FUNC @@ -127,7 +128,7 @@ struct GenericCallOpLowering : public OpRewritePattern 0) { - auto res = rewriter.replaceOpWithNewOp(op, op->getAttrOfType("callee"), op.getType(0), op.getOperands()); + auto res = rewriter.replaceOpWithNewOp(op, op->getAttrOfType("callee"), mlir::UnrankedTensorType::get(rewriter.getF64Type()), op.getOperands()); } else { diff --git a/lib/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.cpp b/lib/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.cpp index e392e158..4fba112d 100644 --- a/lib/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.cpp +++ b/lib/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.cpp @@ -686,7 +686,7 @@ namespace lhs = rewriter.create(loc, lhs, alloc_zero_loc); } - assert((rhsType.isF64() || rhsType.isa()) && (lhsType.isF64() || lhsType.isa()) && "Scalar Operands data type must be either F64 or memref"); + // assert((rhsType.isF64() || rhsType.isa()) && (lhsType.isF64() || lhsType.isa()) && "Scalar Operands data type must be either F64 or memref"); Value res; bool res_comes_from_setop = false;