Scalars will now expand to fit the other operand's shape if it is a tensor; i.e., A[M,N] + 1 is a valid operation
pthomadakis committed Aug 30, 2024
1 parent cfd9a7a commit f14c977
Showing 4 changed files with 131 additions and 78 deletions.
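In effect, a scalar operand is now treated as a constant tensor of the other operand's shape. The following is a plain C++ illustration of the intended semantics, not COMET DSL code; the shapes and values are invented for the example:

#include <cstdio>

int main() {
  /// A[M,N] + 1: the scalar behaves as if expanded to an MxN tensor of ones.
  const int M = 2, N = 3;
  double A[M][N] = {{1, 2, 3}, {4, 5, 6}};
  const double scalar = 1.0;
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j)
      A[i][j] += scalar; /// elementwise, as if the rhs were broadcast to [M,N]
  for (int i = 0; i < M; ++i)
    printf("%g %g %g\n", A[i][0], A[i][1], A[i][2]);
  return 0;
}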
19 changes: 16 additions & 3 deletions frontends/comet_dsl/mlir/MLIRGen.cpp
@@ -46,6 +46,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include <map>
#include <numeric>
#include <cstdlib> /// for random num generation
@@ -467,9 +468,21 @@ namespace
default:
comet_debug() << "ERROR: unsupported operator type: ASCII Code(" << binop.getOp() << ")\n";
}

mlir::StringAttr opAttr = builder.getStringAttr(op);
mlir::Type elementType = builder.getF64Type();
auto returnDataType = mlir::RankedTensorType::get(1, elementType);
mlir::RankedTensorType returnDataType;
if(lhs.getType().cast<mlir::RankedTensorType>().getShape() != rhs.getType().cast<mlir::RankedTensorType>().getShape())
{
returnDataType = lhs.getType().cast<mlir::RankedTensorType>();
auto bcastRhs = builder.create<DenseConstantOp>(location, returnDataType, mlir::cast<DenseConstantOp>(rhs.getDefiningOp()).getValueAttr());
comet_vdump(bcastRhs);
rhs.replaceAllUsesWith(bcastRhs);
rhs = bcastRhs;
}
else {
mlir::Type elementType = builder.getF64Type();
returnDataType = mlir::RankedTensorType::get(1, elementType);
}
comet_vdump(rhs);
comet_vdump(lhs);

@@ -481,7 +494,7 @@
comet_debug() << "creating a new variable declaration, since the user did not declare it\n";

double data = 0.0;
auto dataAttribute = mlir::DenseElementsAttr::get(returnDataType, llvm::ArrayRef(data));
auto dataAttribute = mlir::DenseElementsAttr::get(mlir::RankedTensorType::get({1}, builder.getF64Type()), llvm::ArrayRef(data));
auto denseConst = builder.create<DenseConstantOp>(location, returnDataType, dataAttribute);

theOutput = denseConst;
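Condensed, the MLIRGen change above amounts to this: when the two operand shapes disagree, rebuild the scalar constant with the tensor operand's type so the binary op sees matching shapes. A sketch using the names from the hunk (builder, location, lhs, rhs), not a drop-in patch; like the committed code, it assumes the scalar sits on the rhs and is defined by a DenseConstantOp:

auto lhsTy = lhs.getType().cast<mlir::RankedTensorType>();
auto rhsTy = rhs.getType().cast<mlir::RankedTensorType>();
if (lhsTy.getShape() != rhsTy.getShape())
{
  /// Re-emit the scalar constant with the tensor operand's type; the value
  /// attribute itself stays a single-element tensor.
  auto scalarOp = mlir::cast<DenseConstantOp>(rhs.getDefiningOp());
  mlir::Value bcast = builder.create<DenseConstantOp>(location, lhsTy, scalarOp.getValueAttr());
  rhs.replaceAllUsesWith(bcast);
  rhs = bcast;
}

The constant keeps its tensor<1xf64> value attribute while taking the tensor's result type; the relaxed DenseConstantOp verifier in TADialect.cpp below is what makes that combination legal.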
2 changes: 1 addition & 1 deletion include/comet/Dialect/TensorAlgebra/IR/TAOps.td
@@ -260,7 +260,7 @@ def DenseConstantOp : TA_Op<"constant", [Pure]> {
let results = (outs F64Tensor);

/// Indicate that the operation has a custom parser and printer method.
let hasCustomAssemblyFormat = 1;
// let hasCustomAssemblyFormat = 1;

let builders = [
OpBuilder<(ins "DenseElementsAttr":$value),
135 changes: 84 additions & 51 deletions lib/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.cpp
@@ -117,55 +117,64 @@ namespace
/// Create these constants up-front to avoid large amounts of redundant
/// operations.
auto valueShape = memRefType.getShape();
SmallVector<Value, 8> constantIndices;

if (!valueShape.empty())
auto constTensor = op.getValue().getType().cast<mlir::TensorType>();
if(constTensor.getRank() == 1 && constTensor.getDimSize(0) == 1)
{
for (auto i : llvm::seq<int64_t>(
0, *std::max_element(valueShape.begin(), valueShape.end())))
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
auto float_attr = *constantValue.getValues<FloatAttr>().begin();
auto f_val = float_attr.getValue();
auto val = rewriter.create<ConstantFloatOp>(op->getLoc(), f_val, rewriter.getF64Type());
rewriter.create<linalg::FillOp>(op->getLoc(), ValueRange(val), ValueRange(alloc));
}
else
else
{
/// This is the case of a tensor of rank 0.
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
}
SmallVector<Value, 8> constantIndices;

/// The constant operation represents a multi-dimensional constant, so we
/// will need to generate a store for each of the elements. The following
/// functor recursively walks the dimensions of the constant shape,
/// generating a store when the recursion hits the base case.
SmallVector<Value, 2> indices;
auto valueIt = constantValue.getValues<FloatAttr>().begin();
std::function<void(uint64_t)> storeElements = [&](uint64_t dimension)
{
/// The last dimension is the base case of the recursion, at this point
/// we store the element at the given index.
if (dimension == valueShape.size())
if (!valueShape.empty())
{
rewriter.create<memref::StoreOp>(
loc, rewriter.create<ConstantOp>(loc, *valueIt++), alloc,
llvm::ArrayRef(indices));
return;
for (auto i : llvm::seq<int64_t>(
0, *std::max_element(valueShape.begin(), valueShape.end())))
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
}

/// Otherwise, iterate over the current dimension and add the indices to
/// the list.
for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i)
else
{
indices.push_back(constantIndices[i]);
storeElements(dimension + 1);
indices.pop_back();
/// This is the case of a tensor of rank 0.
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
}
};
/// The constant operation represents a multi-dimensional constant, so we
/// will need to generate a store for each of the elements. The following
/// functor recursively walks the dimensions of the constant shape,
/// generating a store when the recursion hits the base case.
SmallVector<Value, 2> indices;
auto valueIt = constantValue.getValues<FloatAttr>().begin();
std::function<void(uint64_t)> storeElements = [&](uint64_t dimension)
{
/// The last dimension is the base case of the recursion, at this point
/// we store the element at the given index.
if (dimension == valueShape.size())
{
rewriter.create<memref::StoreOp>(
loc, rewriter.create<ConstantOp>(loc, *valueIt++), alloc,
llvm::ArrayRef(indices));
return;
}

/// Start the element storing recursion from the first dimension.
storeElements(/*dimension=*/0);
/// Otherwise, iterate over the current dimension and add the indices to
/// the list.
for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i)
{
indices.push_back(constantIndices[i]);
storeElements(dimension + 1);
indices.pop_back();
}
};

/// Start the element storing recursion from the first dimension.
storeElements(/*dimension=*/0);
}

/// Replace this operation with the generated alloc.
op.replaceAllUsesWith(alloc);
op->replaceAllUsesWith(rewriter.create<ToTensorOp>(op->getLoc(),alloc));
rewriter.eraseOp(op);

comet_debug() << "ConstantOpLowering ends\n";
return success();
}
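In outline, ConstantOpLowering now takes a fast path for rank-1, size-1 constants: instead of unrolling one memref.store per element, it materializes the lone scalar and splats it with linalg.fill. A minimal sketch, assuming op, constantValue, alloc, and rewriter as in the hunk above:

auto constTensor = op.getValue().getType().cast<mlir::TensorType>();
if (constTensor.getRank() == 1 && constTensor.getDimSize(0) == 1)
{
  /// Single-element constant: splat its value across the whole buffer.
  auto floatAttr = *constantValue.getValues<FloatAttr>().begin();
  auto val = rewriter.create<ConstantFloatOp>(op->getLoc(), floatAttr.getValue(), rewriter.getF64Type());
  rewriter.create<linalg::FillOp>(op->getLoc(), ValueRange(val), ValueRange(alloc));
}
else
{
  /// Multi-element constant: fall through to the original recursive
  /// storeElements() walk, unchanged apart from indentation.
}

The lowering also changes what replaces the op: uses are rewired to a bufferization ToTensorOp wrapping the alloc rather than to the raw memref, so tensor-typed consumers keep verifying.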
@@ -624,25 +633,33 @@ namespace
comet_vdump(rhs);
comet_vdump(lhs);

auto rhsType = op->getOperand(0).getType();
auto lhsType = op->getOperand(1).getType();
// auto rh = op->getOperand(0).getType();
// auto lh = op->getOperand(1).getType();

[[maybe_unused]] auto f64Type = rewriter.getF64Type();
Value const_index_0 = rewriter.create<ConstantIndexOp>(loc, 0);
comet_vdump(const_index_0);
std::vector<Value> alloc_zero_loc = {const_index_0};

if (rhsType.isa<MemRefType>())
if (auto toTensorOp = llvm::dyn_cast_if_present<ToTensorOp>(rhs.getDefiningOp()))
{
comet_debug() << "RHS is a tensor\n";
rhs = rewriter.create<memref::LoadOp>(loc, rhs, alloc_zero_loc);
comet_vdump(rhs);
rhs = toTensorOp.getMemref();
// comet_debug() << "RHS is a tensor\n";
// rhs = rewriter.create<memref::LoadOp>(loc, rhs, alloc_zero_loc);
// comet_vdump(rhs);
}
if (lhsType.isa<MemRefType>())
if (auto toTensorOp = llvm::dyn_cast_if_present<ToTensorOp>(lhs.getDefiningOp()))
{
comet_debug() << "LHS is a tensor\n";
lhs = rewriter.create<memref::LoadOp>(loc, lhs, alloc_zero_loc);
lhs = toTensorOp.getMemref();
// comet_debug() << "RHS is a tensor\n";
// rhs = rewriter.create<memref::LoadOp>(loc, rhs, alloc_zero_loc);
// comet_vdump(rhs);
}
// if (lhsType.isa<MemRefType>())
// {
// comet_debug() << "LHS is a tensor\n";
// lhs = rewriter.create<memref::LoadOp>(loc, lhs, alloc_zero_loc);
// }

Value res;
bool res_comes_from_setop = false;
@@ -652,7 +669,18 @@
comet_pdump(u);
if (isa<tensorAlgebra::TensorSetOp>(u))
{
// u->dump();
// u->getBlock()->dump();
res = cast<tensorAlgebra::TensorSetOp>(u).getOperation()->getOperand(1);
// (++res.getUsers().begin())->dump();
if(!res.getUsers().empty() && isa<TensorSetOp>(*(++res.getUsers().begin())))
{
res = cast<tensorAlgebra::TensorSetOp>(*(++res.getUsers().begin())).getRhs();
}
if(auto toTensor = mlir::dyn_cast_or_null<ToTensorOp>(res.getDefiningOp()))
{
res = toTensor.getMemref();
}
comet_debug() << "Result from SetOp:\n";
comet_vdump(res);
res_comes_from_setop = true;
Expand All @@ -675,19 +703,23 @@ namespace
Value res_val;
if (op_attr.compare("+") == 0)
{
res_val = rewriter.create<AddFOp>(loc, rhs, lhs);
rewriter.create<linalg::AddOp>(loc, ValueRange{lhs, rhs}, ValueRange(res));
// res_val = rewriter.create<AddFOp>(loc, lhs, rhs);
}
else if (op_attr.compare("-") == 0)
{
res_val = rewriter.create<SubFOp>(loc, lhs, rhs);
rewriter.create<linalg::SubOp>(loc, ValueRange{lhs, rhs}, ValueRange(res));
// res_val = rewriter.create<SubFOp>(loc, lhs, rhs);
}
else if (op_attr.compare("*") == 0)
{
res_val = rewriter.create<MulFOp>(loc, rhs, lhs);
rewriter.create<linalg::MulOp>(loc, ValueRange{lhs, rhs}, ValueRange(res));
// res_val = rewriter.create<MulFOp>(loc, lhs, rhs);
}
else if (op_attr.compare("/") == 0)
{
res_val = rewriter.create<DivFOp>(loc, lhs, rhs);
rewriter.create<linalg::DivOp>(loc, ValueRange{lhs, rhs}, ValueRange(res));
// res_val = rewriter.create<DivFOp>(loc, lhs, rhs);
}
else
{
Expand All @@ -696,7 +728,8 @@ namespace

comet_vdump(res_val);
/// store res_val to res
[[maybe_unused]] auto storeOp = rewriter.create<memref::StoreOp>(loc, res_val, res, alloc_zero_loc);
// rewriter.create<linalg::CopyOp>(loc, res_val, res);
// [[maybe_unused]] auto storeOp = rewriter.create<memref::StoreOp>(loc, res_val, res, alloc_zero_loc);
// comet_vdump(storeOp);

op.replaceAllUsesWith(res);
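The elementwise lowering above moves in the same direction: operands produced by a ToTensorOp are unwrapped back to their memrefs, the destination buffer res is recovered by chasing the TensorSetOp user (peeling a ToTensorOp if present), and the arithmetic itself becomes a linalg named op in destination-passing style. A sketch reusing the hunk's names (op_attr, lhs, rhs, res, rewriter, loc):

/// Destination-passing style: inputs {lhs, rhs}; the output buffer res is
/// written in place, so no scalar load/compute/store sequence remains.
if (op_attr.compare("+") == 0)
  rewriter.create<linalg::AddOp>(loc, ValueRange{lhs, rhs}, ValueRange(res));
else if (op_attr.compare("-") == 0)
  rewriter.create<linalg::SubOp>(loc, ValueRange{lhs, rhs}, ValueRange(res));
else if (op_attr.compare("*") == 0)
  rewriter.create<linalg::MulOp>(loc, ValueRange{lhs, rhs}, ValueRange(res));
else if (op_attr.compare("/") == 0)
  rewriter.create<linalg::DivOp>(loc, ValueRange{lhs, rhs}, ValueRange(res));

Because the linalg named ops carry the elementwise loop semantics themselves, one pattern now covers both genuine rank-1 scalars and scalars broadcast to a tensor shape, leaving the iteration scheme to later bufferization and loop-conversion passes.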
53 changes: 30 additions & 23 deletions lib/Dialect/TensorAlgebra/IR/TADialect.cpp
@@ -60,26 +60,26 @@ void DenseConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &stat
/// or `false` on success. This allows for easily chaining together a set of
/// parser rules. These rules are used to populate an `mlir::OperationState`
/// similarly to the `build` methods described above.
mlir::ParseResult DenseConstantOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result)
{
mlir::DenseElementsAttr value;
if (parser.parseOptionalAttrDict(result.attributes) ||
parser.parseAttribute(value, "value", result.attributes))
return failure();

result.addTypes(value.getType());
return success();
}

/// The 'OpAsmPrinter' class is a stream that allows for formatting
/// strings, attributes, operands, types, etc.
void DenseConstantOp::print(mlir::OpAsmPrinter &printer)
{
printer << " ";
printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
printer << getValue();
}
// mlir::ParseResult DenseConstantOp::parse(mlir::OpAsmParser &parser,
// mlir::OperationState &result)
// {
// mlir::DenseElementsAttr value;
// if (parser.parseOptionalAttrDict(result.attributes) ||
// parser.parseAttribute(value, "value", result.attributes))
// return failure();

// result.addTypes(value.getType());
// return success();
// }

// /// The 'OpAsmPrinter' class is a stream that allows for formatting
// /// strings, attributes, operands, types, etc.
// void DenseConstantOp::print(mlir::OpAsmPrinter &printer)
// {
// printer << " ";
// printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
// printer << getValue();
// }

/// Verifier for the constant operation. This corresponds to the
/// `let hasVerifier = 1` in the op definition.
@@ -96,9 +96,16 @@ mlir::LogicalResult DenseConstantOp::verify()
auto attrType = getValue().getType().cast<mlir::TensorType>();
if (attrType.getRank() != resultType.getRank())
{
return emitOpError("return type must match the one of the attached value "
"attribute: ")
<< attrType.getRank() << " != " << resultType.getRank();
if(!(attrType.getRank() == 1 && attrType.getDimSize(0) == 1))
{
return emitOpError("return type must match the one of the attached value "
"attribute: ")
<< attrType.getRank() << " != " << resultType.getRank();
}
else
{
return mlir::success();
}
}

/// Check that each of the dimensions match between the two types.
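Summarized, this verifier relaxation is what legalizes the broadcast constants created in MLIRGen: a rank-1, single-element value attribute may now back a result of any rank. A sketch of the new rank check:

if (attrType.getRank() != resultType.getRank())
{
  /// A tensor<1xf64>-style attribute denotes a scalar splat and may pair
  /// with a result of any rank; every other rank mismatch is an error.
  if (attrType.getRank() == 1 && attrType.getDimSize(0) == 1)
    return mlir::success();
  return emitOpError("return type must match the one of the attached value "
                     "attribute: ")
         << attrType.getRank() << " != " << resultType.getRank();
}

On the splat path the per-dimension size check that follows is skipped entirely, which is what lets a tensor<1xf64> attribute stand behind, say, a tensor<4x4xf64> result.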
