From f14c977d781c47ec3d9e0e8d32af356a0585790e Mon Sep 17 00:00:00 2001
From: Polykarpos Thomadakis
Date: Fri, 30 Aug 2024 09:58:01 -0700
Subject: [PATCH] Scalars will now expand to fit the other operand's shape if
 it is a tensor; e.g., A[M,N] + 1 is a valid operation

---
 frontends/comet_dsl/mlir/MLIRGen.cpp          |  19 ++-
 .../comet/Dialect/TensorAlgebra/IR/TAOps.td   |   2 +-
 .../TensorAlgebraToSCF/TensorAlgebraToSCF.cpp | 135 +++++++++++-------
 lib/Dialect/TensorAlgebra/IR/TADialect.cpp    |  53 ++++---
 4 files changed, 131 insertions(+), 78 deletions(-)

diff --git a/frontends/comet_dsl/mlir/MLIRGen.cpp b/frontends/comet_dsl/mlir/MLIRGen.cpp
index 8eaee967..76db2b1c 100644
--- a/frontends/comet_dsl/mlir/MLIRGen.cpp
+++ b/frontends/comet_dsl/mlir/MLIRGen.cpp
@@ -46,6 +46,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopedHashTable.h"
 #include "llvm/Support/raw_ostream.h"
+#include 
 #include 
 #include 
 #include /// for random num generation
@@ -467,9 +468,21 @@ namespace
       default:
         comet_debug() << "ERROR: unsupported operator type: ASCII Code(" << binop.getOp() << ")\n";
       }
+
       mlir::StringAttr opAttr = builder.getStringAttr(op);
-      mlir::Type elementType = builder.getF64Type();
-      auto returnDataType = mlir::RankedTensorType::get(1, elementType);
+      mlir::RankedTensorType returnDataType;
+      if(lhs.getType().cast().getShape() != rhs.getType().cast().getShape())
+      {
+        returnDataType = lhs.getType().cast();
+        auto bcastRhs = builder.create(location, returnDataType, mlir::cast(rhs.getDefiningOp()).getValueAttr());
+        comet_vdump(bcastRhs);
+        rhs.replaceAllUsesWith(bcastRhs);
+        rhs = bcastRhs;
+      }
+      else {
+        mlir::Type elementType = builder.getF64Type();
+        returnDataType = mlir::RankedTensorType::get(1, elementType);
+      }
 
       comet_vdump(rhs);
       comet_vdump(lhs);
@@ -481,7 +494,7 @@ namespace
         comet_debug() << "creating a new variable declaration, since the user did not declare it\n";
 
         double data = 0.0;
-        auto dataAttribute = mlir::DenseElementsAttr::get(returnDataType, llvm::ArrayRef(data));
+        auto dataAttribute = mlir::DenseElementsAttr::get(mlir::RankedTensorType::get({1}, builder.getF64Type()), llvm::ArrayRef(data));
         auto denseConst = builder.create(location, returnDataType, dataAttribute);
         theOutput = denseConst;
 
diff --git a/include/comet/Dialect/TensorAlgebra/IR/TAOps.td b/include/comet/Dialect/TensorAlgebra/IR/TAOps.td
index 54ff6abd..40398645 100644
--- a/include/comet/Dialect/TensorAlgebra/IR/TAOps.td
+++ b/include/comet/Dialect/TensorAlgebra/IR/TAOps.td
@@ -260,7 +260,7 @@ def DenseConstantOp : TA_Op<"constant", [Pure]> {
   let results = (outs F64Tensor);
 
   /// Indicate that the operation has a custom parser and printer method.
-  let hasCustomAssemblyFormat = 1;
+  // let hasCustomAssemblyFormat = 1;
 
   let builders = [
     OpBuilder<(ins "DenseElementsAttr":$value),
diff --git a/lib/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.cpp b/lib/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.cpp
index 7c6e27eb..09dd38fb 100644
--- a/lib/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.cpp
+++ b/lib/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.cpp
@@ -117,55 +117,64 @@ namespace
     /// Create these constants up-front to avoid large amounts of redundant
     /// operations.
     auto valueShape = memRefType.getShape();
-    SmallVector constantIndices;
-
-    if (!valueShape.empty())
+    auto constTensor = op.getValue().getType().cast();
+    if(constTensor.getRank() == 1 && constTensor.getDimSize(0) == 1)
     {
-      for (auto i : llvm::seq(
-               0, *std::max_element(valueShape.begin(), valueShape.end())))
-        constantIndices.push_back(rewriter.create(loc, i));
+      auto float_attr = *constantValue.getValues().begin();
+      auto f_val = float_attr.getValue();
+      auto val = rewriter.create(op->getLoc(), f_val, rewriter.getF64Type());
+      rewriter.create(op->getLoc(), ValueRange(val), ValueRange(alloc));
     }
-    else
+    else
     {
-      /// This is the case of a tensor of rank 0.
-      constantIndices.push_back(rewriter.create(loc, 0));
-    }
+      SmallVector constantIndices;
 
-    /// The constant operation represents a multi-dimensional constant, so we
-    /// will need to generate a store for each of the elements. The following
-    /// functor recursively walks the dimensions of the constant shape,
-    /// generating a store when the recursion hits the base case.
-    SmallVector indices;
-    auto valueIt = constantValue.getValues().begin();
-    std::function storeElements = [&](uint64_t dimension)
-    {
-      /// The last dimension is the base case of the recursion, at this point
-      /// we store the element at the given index.
-      if (dimension == valueShape.size())
+      if (!valueShape.empty())
       {
-        rewriter.create(
-            loc, rewriter.create(loc, *valueIt++), alloc,
-            llvm::ArrayRef(indices));
-        return;
+        for (auto i : llvm::seq(
+                 0, *std::max_element(valueShape.begin(), valueShape.end())))
+          constantIndices.push_back(rewriter.create(loc, i));
       }
-
-      /// Otherwise, iterate over the current dimension and add the indices to
-      /// the list.
-      for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i)
+      else
      {
-        indices.push_back(constantIndices[i]);
-        storeElements(dimension + 1);
-        indices.pop_back();
+        /// This is the case of a tensor of rank 0.
+        constantIndices.push_back(rewriter.create(loc, 0));
      }
-    };
 
+      /// The constant operation represents a multi-dimensional constant, so we
+      /// will need to generate a store for each of the elements. The following
+      /// functor recursively walks the dimensions of the constant shape,
+      /// generating a store when the recursion hits the base case.
+      SmallVector indices;
+      auto valueIt = constantValue.getValues().begin();
+      std::function storeElements = [&](uint64_t dimension)
+      {
+        /// The last dimension is the base case of the recursion, at this point
+        /// we store the element at the given index.
+        if (dimension == valueShape.size())
+        {
+          rewriter.create(
+              loc, rewriter.create(loc, *valueIt++), alloc,
+              llvm::ArrayRef(indices));
+          return;
+        }
 
-    /// Start the element storing recursion from the first dimension.
-    storeElements(/*dimension=*/0);
+        /// Otherwise, iterate over the current dimension and add the indices to
+        /// the list.
+        for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i)
+        {
+          indices.push_back(constantIndices[i]);
+          storeElements(dimension + 1);
+          indices.pop_back();
+        }
+      };
+
+      /// Start the element storing recursion from the first dimension.
+      storeElements(/*dimension=*/0);
+    }
 
     /// Replace this operation with the generated alloc.
-    op.replaceAllUsesWith(alloc);
+    op->replaceAllUsesWith(rewriter.create(op->getLoc(),alloc));
     rewriter.eraseOp(op);
-    comet_debug() << "ConstantOpLowering ends\n";
     return success();
   }
@@ -624,25 +633,33 @@ namespace
       comet_vdump(rhs);
       comet_vdump(lhs);
 
-      auto rhsType = op->getOperand(0).getType();
-      auto lhsType = op->getOperand(1).getType();
+      // auto rh = op->getOperand(0).getType();
+      // auto lh = op->getOperand(1).getType();
       [[maybe_unused]] auto f64Type = rewriter.getF64Type();
 
       Value const_index_0 = rewriter.create(loc, 0);
       comet_vdump(const_index_0);
       std::vector alloc_zero_loc = {const_index_0};
 
-      if (rhsType.isa())
+      if (auto toTensorOp = llvm::dyn_cast_if_present(rhs.getDefiningOp()))
       {
-        comet_debug() << "RHS is a tensor\n";
-        rhs = rewriter.create(loc, rhs, alloc_zero_loc);
-        comet_vdump(rhs);
+        rhs = toTensorOp.getMemref();
+        // comet_debug() << "RHS is a tensor\n";
+        // rhs = rewriter.create(loc, rhs, alloc_zero_loc);
+        // comet_vdump(rhs);
       }
 
-      if (lhsType.isa())
+      if (auto toTensorOp = llvm::dyn_cast_if_present(lhs.getDefiningOp()))
      {
-        comet_debug() << "LHS is a tensor\n";
-        lhs = rewriter.create(loc, lhs, alloc_zero_loc);
+        lhs = toTensorOp.getMemref();
+        // comet_debug() << "RHS is a tensor\n";
+        // rhs = rewriter.create(loc, rhs, alloc_zero_loc);
+        // comet_vdump(rhs);
      }
+      // if (lhsType.isa())
+      // {
+      //   comet_debug() << "LHS is a tensor\n";
+      //   lhs = rewriter.create(loc, lhs, alloc_zero_loc);
+      // }
 
       Value res;
       bool res_comes_from_setop = false;
@@ -652,7 +669,18 @@ namespace
         comet_pdump(u);
         if (isa(u))
         {
+          // u->dump();
+          // u->getBlock()->dump();
           res = cast(u).getOperation()->getOperand(1);
+          // (++res.getUsers().begin())->dump();
+          if(!res.getUsers().empty() && isa(*(++res.getUsers().begin())))
+          {
+            res = cast(*(++res.getUsers().begin())).getRhs();
+          }
+          if(auto toTensor = mlir::dyn_cast_or_null(res.getDefiningOp()))
+          {
+            res = toTensor.getMemref();
+          }
           comet_debug() << "Result from SetOp:\n";
           comet_vdump(res);
           res_comes_from_setop = true;
@@ -675,19 +703,23 @@ namespace
       Value res_val;
       if (op_attr.compare("+") == 0)
       {
-        res_val = rewriter.create(loc, rhs, lhs);
+        rewriter.create(loc, ValueRange{lhs, rhs}, ValueRange(res));
+        // res_val = rewriter.create(loc, lhs, rhs);
       }
       else if (op_attr.compare("-") == 0)
       {
-        res_val = rewriter.create(loc, lhs, rhs);
+        rewriter.create(loc, ValueRange{lhs, rhs}, ValueRange(res));
+        // res_val = rewriter.create(loc, lhs, rhs);
       }
       else if (op_attr.compare("*") == 0)
       {
-        res_val = rewriter.create(loc, rhs, lhs);
+        rewriter.create(loc, ValueRange{lhs, rhs}, ValueRange(res));
+        // res_val = rewriter.create(loc, lhs, rhs);
       }
       else if (op_attr.compare("/") == 0)
       {
-        res_val = rewriter.create(loc, lhs, rhs);
+        rewriter.create(loc, ValueRange{lhs, rhs}, ValueRange(res));
+        // res_val = rewriter.create(loc, lhs, rhs);
       }
       else
       {
@@ -696,7 +728,8 @@ namespace
       comet_vdump(res_val);
 
       /// store res_val to res
-      [[maybe_unused]] auto storeOp = rewriter.create(loc, res_val, res, alloc_zero_loc);
+      // rewriter.create(loc, res_val, res);
+      // [[maybe_unused]] auto storeOp = rewriter.create(loc, res_val, res, alloc_zero_loc);
       comet_vdump(storeOp);
 
       op.replaceAllUsesWith(res);
diff --git a/lib/Dialect/TensorAlgebra/IR/TADialect.cpp b/lib/Dialect/TensorAlgebra/IR/TADialect.cpp
index 48402fa4..7eba8c86 100644
--- a/lib/Dialect/TensorAlgebra/IR/TADialect.cpp
+++ b/lib/Dialect/TensorAlgebra/IR/TADialect.cpp
@@ -60,26 +60,26 @@ void DenseConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &stat
 /// or `false` on success. This allows for easily chaining together a set of
 /// parser rules. These rules are used to populate an `mlir::OperationState`
 /// similarly to the `build` methods described above.
-mlir::ParseResult DenseConstantOp::parse(mlir::OpAsmParser &parser,
-                                         mlir::OperationState &result)
-{
-  mlir::DenseElementsAttr value;
-  if (parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseAttribute(value, "value", result.attributes))
-    return failure();
-
-  result.addTypes(value.getType());
-  return success();
-}
-
-/// The 'OpAsmPrinter' class is a stream that allows for formatting
-/// strings, attributes, operands, types, etc.
-void DenseConstantOp::print(mlir::OpAsmPrinter &printer)
-{
-  printer << " ";
-  printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
-  printer << getValue();
-}
+// mlir::ParseResult DenseConstantOp::parse(mlir::OpAsmParser &parser,
+//                                          mlir::OperationState &result)
+// {
+//   mlir::DenseElementsAttr value;
+//   if (parser.parseOptionalAttrDict(result.attributes) ||
+//       parser.parseAttribute(value, "value", result.attributes))
+//     return failure();
+
+//   result.addTypes(value.getType());
+//   return success();
+// }
+
+// /// The 'OpAsmPrinter' class is a stream that allows for formatting
+// /// strings, attributes, operands, types, etc.
+// void DenseConstantOp::print(mlir::OpAsmPrinter &printer)
+// {
+//   printer << " ";
+//   printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
+//   printer << getValue();
+// }
 
 /// Verifier for the constant operation. This corresponds to the
 /// `let hasVerifier = 1` in the op definition.
@@ -96,9 +96,16 @@ mlir::LogicalResult DenseConstantOp::verify()
   auto attrType = getValue().getType().cast();
   if (attrType.getRank() != resultType.getRank())
   {
-    return emitOpError("return type must match the one of the attached value "
-                       "attribute: ")
-           << attrType.getRank() << " != " << resultType.getRank();
+    if(!(attrType.getRank() == 1 && attrType.getDimSize(0) == 1))
+    {
+      return emitOpError("return type must match the one of the attached value "
+                         "attribute: ")
+             << attrType.getRank() << " != " << resultType.getRank();
+    }
+    else
+    {
+      return mlir::success();
+    }
  }
 
   /// Check that each of the dimensions match between the two types.
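
The broadcast idea behind the MLIRGen.cpp change, in isolation: when one operand of a binary expression is a rank-1, single-element constant and the other is a ranked tensor, the scalar value is splat-expanded to the tensor's shape so the operation becomes element-wise over matching shapes (e.g., A[M,N] + 1). The sketch below is illustrative only and is not part of the commit; the helper name splatToShape is hypothetical and it assumes an f64 element type, whereas the patch appears to build the broadcast constant directly from the scalar's value attribute and the tensor operand's type (the relaxed DenseConstantOp::verify above accepts exactly that pairing).

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/ArrayRef.h"

/// Expand a scalar f64 value into a dense splat attribute that matches the
/// shape of the tensor operand. DenseElementsAttr::get treats a one-element
/// value list as a splat over the whole shape, so no per-element copy is made.
static mlir::DenseElementsAttr splatToShape(double scalar,
                                            mlir::RankedTensorType tensorType) {
  // Assumes tensorType has an f64 element type, matching the F64Tensor result
  // of DenseConstantOp in TAOps.td.
  return mlir::DenseElementsAttr::get(tensorType, llvm::ArrayRef<double>(scalar));
}

For a target type such as tensor<4x4xf64>, splatToShape(1.0, type) yields a dense splat attribute of that shape, which can back a constant whose shape now matches the other operand.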