Skip to content

Commit

Permalink
[WIP] #34 Some work to support arguments in functional calls. Still i…
Browse files Browse the repository at this point in the history
…n progress.
  • Loading branch information
pthomadakis committed Oct 22, 2023
1 parent e892ac3 commit 4977dd2
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 36 deletions.
2 changes: 2 additions & 0 deletions frontends/comet_dsl/comet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"

#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
Expand Down Expand Up @@ -384,6 +385,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,

// Finally lowering index tree to SCF dialect
optPM.addPass(mlir::comet::createLowerIndexTreeToSCFPass());
optPM.addPass(mlir::createTensorBufferizePass());
pm.addPass(mlir::func::createFuncBufferizePass()); // Needed for func

// Dump index tree dialect.
Expand Down
36 changes: 27 additions & 9 deletions frontends/comet_dsl/include/AST.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ namespace tensorAlgebra
Expr_GetTime,
Expr_ForLoop,
Expr_ForEnd,
Expr_Mask
Expr_Mask,
Expr_FuncArg,
};

ExprAST(ExprASTKind kind, Location location)
Expand Down Expand Up @@ -156,6 +157,22 @@ namespace tensorAlgebra
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
};


/// Expression class for a single formal argument in a function prototype,
/// e.g. `Tensor<double> A` or `int n`. Carries the argument's name and its
/// declared (or defaulted) type; the source location is stored by the
/// ExprAST base, so it is not duplicated here.
class FuncArgAST : public ExprAST
{
  std::string name;
  VarType type;

public:
  FuncArgAST(Location loc, llvm::StringRef name, VarType type)
      : ExprAST(Expr_FuncArg, std::move(loc)), name(name), type(std::move(type)) {}

  /// Argument identifier as written in the source.
  llvm::StringRef getName() const { return name; }

  /// Declared type of the argument.
  const VarType &getType() const { return type; }

  /// LLVM-style RTTI.
  static bool classof(const ExprAST *C) { return C->getKind() == Expr_FuncArg; }
};
// /// Expression class for defining a variable.
// class VarDeclExprAST : public ExprAST
// {
Expand Down Expand Up @@ -474,15 +491,16 @@ namespace tensorAlgebra
class CallExprAST : public ExprAST
{
std::string Callee;
std::unique_ptr<ExprAST> Arg;
std::vector<std::unique_ptr<ExprAST>> Args;

public:
CallExprAST(Location loc, const std::string &Callee,
std::unique_ptr<ExprAST> Arg)
: ExprAST(Expr_Call, loc), Callee(Callee), Arg(std::move(Arg)) {}
std::vector<std::unique_ptr<ExprAST>> Args)
: ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {}

llvm::StringRef getCallee() { return Callee; }
ExprAST *getArgs() { return Arg.get(); }
ExprAST* getArg(int index) { return Args[index].get(); }
size_t getNumArgs() {return Args.size();}
/// LLVM style RTTI
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; }
};
Expand Down Expand Up @@ -594,17 +612,17 @@ namespace tensorAlgebra
{
Location location;
std::string name;
std::vector<std::unique_ptr<VariableExprAST>> args;
std::vector<std::unique_ptr<FuncArgAST>> args;

public:
PrototypeAST(Location location, const std::string &name,
std::vector<std::unique_ptr<VariableExprAST>> args)
std::vector<std::unique_ptr<FuncArgAST>> args)
: location(location), name(name), args(std::move(args)) {}

const Location &loc() { return location; }
const std::string &getName() const { return name; }
//TODO(gkestor): check VariableExprAST
const std::vector<std::unique_ptr<VariableExprAST>> &getArgs()
//TODO(gkestor): check FuncArgAST
const std::vector<std::unique_ptr<FuncArgAST>> &getArgs()
{
return args;
}
Expand Down
113 changes: 96 additions & 17 deletions frontends/comet_dsl/include/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -495,13 +495,13 @@ namespace tensorAlgebra
}
else
{
if (args.size() == 0)
{
args.push_back(nullptr);
}
// if (args.size() == 0)
// {
// args.push_back(nullptr);
// }
}
comet_debug() << "generate CallExprAST node\n ";
return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args[0]));
return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args));
}

/// primary
Expand Down Expand Up @@ -603,6 +603,7 @@ namespace tensorAlgebra
while (true)
{
int tokPrec = getTokPrecedence();
comet_debug() << lexer.getCurToken() << "\n";
comet_debug() << " tokPrec: " << tokPrec << ", exprPrec: " << exprPrec << "\n";
// If this is a binop that binds at least as tightly as the current binop,
// consume it, otherwise we are done.
Expand Down Expand Up @@ -666,6 +667,11 @@ namespace tensorAlgebra
auto lhs = parsePrimary();
if (!lhs)
return nullptr;
// if(lexer.getCurToken() == ';')
// {
// comet_debug() << "return single operand\n";
// return lhs;
// }

comet_debug() << "finished lhs parse\n";
comet_debug() << " call parseBinOpRHS\n";
Expand Down Expand Up @@ -724,6 +730,7 @@ namespace tensorAlgebra
if (!type)
type = std::make_unique<VarType>();
lexer.consume(Token('='));
comet_debug() << "Parse declaration for " << id << "\n";
auto expr = parseExpression();
return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
std::move(*type), std::move(expr));
Expand Down Expand Up @@ -1341,9 +1348,9 @@ namespace tensorAlgebra
{
comet_debug() << __FILE__ << __LINE__ << " TensorOpExprAST rhs is Expr_Call\n";

//CallExprAST *call = llvm::cast<CallExprAST>(RHS.get());
//llvm::StringRef callee = call->getCallee();
//comet_debug() << __FILE__ << __LINE__ << " callee: " << callee << "\n";
CallExprAST *call = llvm::cast<CallExprAST>(RHS.get());
llvm::StringRef callee = call->getCallee();
comet_debug() << __FILE__ << __LINE__ << " callee: " << callee << "\n";
}
else if (RHS.get()->getKind() == tensorAlgebra::ExprAST::Expr_Transpose)
{
Expand Down Expand Up @@ -1468,6 +1475,7 @@ namespace tensorAlgebra
else if (lexer.getCurToken() == tok_identifier &&
lexer.lookAhead() == ' ')
{
comet_debug() << "ParseVarExpression lhs parse\n";
auto varOp = ParseVarExpression();
if (!varOp)
return nullptr;
Expand Down Expand Up @@ -1517,6 +1525,76 @@ namespace tensorAlgebra
return exprList;
}

std::unique_ptr<FuncArgAST> parseFuncArg()
{
auto tok = lexer.getCurToken();
if(tok == tok_tensor)
{
auto loc = lexer.getLastLocation();
lexer.getNextToken(); // eat Tensor

std::unique_ptr<VarType> type; // Type is optional, it can be inferred
if (lexer.getCurToken() == '<')
{
lexer.consume(Token('<')); // eat <
type = std::make_unique<VarType>();
if (lexer.getCurToken() == tok_double)
{
type->elt_ty = VarType::TY_DOUBLE;
}
else if (lexer.getCurToken() == tok_float)
{
type->elt_ty = VarType::TY_FLOAT;
}
else if (lexer.getCurToken() == tok_int)
{
type->elt_ty = VarType::TY_INT;
}
lexer.getNextToken(); // eat el_type
if (lexer.getCurToken() != '>')
return parseError<FuncArgAST>(">", "to end type");
lexer.getNextToken(); // eat >
}

if (lexer.getCurToken() != tok_identifier)
return parseError<FuncArgAST>("identifier",
"after 'Tensor' declaration");
std::string id(lexer.getId());
lexer.getNextToken(); // eat id

return std::make_unique<FuncArgAST>(std::move(loc), id, std::move(*type));
}
else if(tok == tok_int || tok == tok_float || tok == tok_double )
{
auto loc = lexer.getLastLocation();
std::unique_ptr<VarType> type;
type = std::make_unique<VarType>();
if (lexer.getCurToken() == tok_double)
{
type->elt_ty = VarType::TY_DOUBLE;
}
else if (lexer.getCurToken() == tok_float)
{
type->elt_ty = VarType::TY_FLOAT;
}
else if (lexer.getCurToken() == tok_int)
{
type->elt_ty = VarType::TY_INT;
}
lexer.getNextToken(); // eat type

std::string id(lexer.getId());
lexer.getNextToken(); // eat id

return std::make_unique<FuncArgAST>(std::move(loc), id, std::move(*type));
}
else
{
return parseError<FuncArgAST>("Unexpected token in function arguments");
}
}


/// prototype ::= def id '(' decl_list ')'
/// decl_list ::= identifier | identifier, decl_list
std::unique_ptr<PrototypeAST> parsePrototype()
Expand All @@ -1533,22 +1611,23 @@ namespace tensorAlgebra
return parseError<PrototypeAST>("(", "in prototype");
lexer.consume(Token('('));

std::vector<std::unique_ptr<VariableExprAST>> args;
std::vector<std::unique_ptr<FuncArgAST>> args;
if (lexer.getCurToken() != ')')
{
do
{
std::string name(lexer.getId());
auto loc = lexer.getLastLocation();
lexer.consume(tok_identifier);
auto decl = std::make_unique<VariableExprAST>(std::move(loc), name);
args.push_back(std::move(decl));
auto arg = parseFuncArg();
// std::string name(lexer.getId());
// auto loc = lexer.getLastLocation();
// lexer.consume(tok_identifier);
// auto decl = std::make_unique<VariableExprAST>(std::move(loc), name);
args.push_back(std::move(arg));
if (lexer.getCurToken() != ',')
break;
lexer.consume(Token(','));
if (lexer.getCurToken() != tok_identifier)
return parseError<PrototypeAST>(
"identifier", "after ',' in function parameter list");
// if (lexer.getCurToken() != tok_identifier)
// return parseError<PrototypeAST>(
// "identifier", "after ',' in function parameter list");
} while (true);
}
if (lexer.getCurToken() != ')')
Expand Down
48 changes: 41 additions & 7 deletions frontends/comet_dsl/mlir/MLIRGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

#include "comet/Dialect/Utils/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Attributes.h"
Expand Down Expand Up @@ -370,10 +371,12 @@ namespace
// return nullptr;
// }

// function.
// // Declare all the function arguments in the symbol table.
for (const auto nameValue :
llvm::zip(protoArgs, entryBlock.getArguments()))
{
comet_debug() << "Proto Args "<< std::get<1>(nameValue) << "\n";
if (failed(declare(std::get<0>(nameValue)->getName(),
std::get<1>(nameValue))))
return nullptr;
Expand Down Expand Up @@ -1245,6 +1248,8 @@ namespace
{
if (!(expr = mlirGen(**ret.getExpr())))
return mlir::failure();

expr = builder.create<mlir::tensor::CastOp>(location, mlir::UnrankedTensorType::get(builder.getF64Type()), expr);
}

// Otherwise, this return operation has zero operands.
Expand Down Expand Up @@ -1306,13 +1311,15 @@ namespace
/// builtin. Other identifiers are assumed to be user-defined functions.
mlir::Value mlirGen(CallExprAST &call)
{
comet_debug()<< "CallExprAST\n";

llvm::StringRef callee = call.getCallee();
auto location = loc(call.loc());

mlir::Value sumVal;
if (callee == "SUM")
{
auto *expr = call.getArgs();
auto *expr = call.getArg(0);
// Check if it SUM(A[i,j]) or SUM(A[i,j] * B[j,k])
// Case 1: SUM(A[i,j])
if (llvm::isa<LabeledTensorExprAST>(expr))
Expand All @@ -1339,16 +1346,25 @@ namespace
}
else
{
auto *expr = call.getArgs();
if(expr)
std::vector<mlir::Value> expr_args;
comet_debug()<< "Generic Call\n";
comet_debug() <<"Num args: " << call.getNumArgs() << "\n";
// auto exprs = call.getArgs();
if(call.getNumArgs() > 0 )
{
assert(false && "functions with argument are currently not supported!");
for(size_t i = 0; i < call.getNumArgs(); i++)
{
auto res = builder.create<mlir::tensor::CastOp>(location, mlir::UnrankedTensorType::get(builder.getF64Type()), mlirGen(*call.getArg(i)));
expr_args.push_back(res);
}
comet_debug() <<"Num args: " << call.getNumArgs() << "\n";
// assert(false && "functions with argument are currently not supported!");
}
mlir::Value tensorValue;
tensorValue = mlir::Value();
ArrayRef<mlir::Value> args{};
if(tensorValue)
args = ArrayRef<mlir::Value> (tensorValue);
ArrayRef<mlir::Value> args(expr_args);
// if(tensorValue)
// args = ArrayRef<mlir::Value> (tensorValue);

auto c = functionMap.lookup(callee);
if(c.getFunctionType().getResults().size() > 0) // Function that returns a value
Expand Down Expand Up @@ -2291,6 +2307,13 @@ namespace
if (mlir::failed(mlirGenTensorFillRandom(loc(tensor_op->loc()), tensor_name)))
return mlir::success();
}
else
{
LabeledTensorExprAST *lhsLabeledTensorExprAST = llvm::cast<LabeledTensorExprAST>(tensor_op->getLHS());
auto call_res = mlirGen(*call);
auto lhs_tensor = symbolTable.lookup(lhsLabeledTensorExprAST->getTensorName());
builder.create<TensorSetOp>(loc(tensor_op->loc()), call_res, lhs_tensor);
}
// TODO: put check here, if the user mis-spells something...

continue;
Expand All @@ -2307,6 +2330,17 @@ namespace
mlirGen(*transpose, *lhsLabeledTensorExprAST);
continue;
}
else if(tensor_op->getRHS()->getKind() == ExprAST::ExprASTKind::Expr_Call)
{
comet_debug() << __LINE__ << " in TensorOpExprAST, rhs is Expr_Call\n";

LabeledTensorExprAST *lhsLabeledTensorExprAST = llvm::cast<LabeledTensorExprAST>(tensor_op->getLHS());
CallExprAST * call = llvm::cast<CallExprAST>(tensor_op->getRHS());
auto call_res = mlirGen(*call);
auto lhs_tensor = symbolTable.lookup(lhsLabeledTensorExprAST->getTensorName());
builder.create<TensorSetOp>(loc(tensor_op->loc()), call_res, lhs_tensor);
continue;
}
else if ((tensor_op->getRHS()->getKind() == ExprAST::ExprASTKind::Expr_LabeledTensor &&
tensor_op->getLHS()->getKind() == ExprAST::ExprASTKind::Expr_LabeledTensor)) // TODO: we should not reach this case
{
Expand Down
2 changes: 1 addition & 1 deletion frontends/comet_dsl/parser/AST.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ void ASTDumper::dump(CallExprAST *Node)
{
INDENT();
llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n";
dump(Node->getArgs());
// dump(Node->getArgs());
indent();
llvm::errs() << "]\n";
}
Expand Down
Loading

0 comments on commit 4977dd2

Please sign in to comment.