diff --git a/frontends/comet_dsl/mlir/MLIRGen.cpp b/frontends/comet_dsl/mlir/MLIRGen.cpp index 3396cae5..45778086 100644 --- a/frontends/comet_dsl/mlir/MLIRGen.cpp +++ b/frontends/comet_dsl/mlir/MLIRGen.cpp @@ -46,6 +46,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopedHashTable.h" #include "llvm/Support/raw_ostream.h" +#include #include #include #include /// for random num generation @@ -467,9 +468,21 @@ namespace default: comet_debug() << "ERROR: unsupported operator type: ASCII Code(" << binop.getOp() << ")\n"; } + mlir::StringAttr opAttr = builder.getStringAttr(op); - mlir::Type elementType = builder.getF64Type(); - auto returnDataType = mlir::RankedTensorType::get(1, elementType); + mlir::RankedTensorType returnDataType; + if(lhs.getType().cast().getShape() != rhs.getType().cast().getShape()) + { + returnDataType = lhs.getType().cast(); + auto bcastRhs = builder.create(location, returnDataType, mlir::cast(rhs.getDefiningOp()).getValueAttr()); + comet_vdump(bcastRhs); + rhs.replaceAllUsesWith(bcastRhs); + rhs = bcastRhs; + } + else { + mlir::Type elementType = builder.getF64Type(); + returnDataType = mlir::RankedTensorType::get(1, elementType); + } comet_vdump(rhs); comet_vdump(lhs); @@ -481,7 +494,7 @@ namespace comet_debug() << "creating a new variable declaration, since the user did not declare it\n"; double data = 0.0; - auto dataAttribute = mlir::DenseElementsAttr::get(returnDataType, llvm::ArrayRef(data)); + auto dataAttribute = mlir::DenseElementsAttr::get(mlir::RankedTensorType::get({1}, builder.getF64Type()), llvm::ArrayRef(data)); auto denseConst = builder.create(location, returnDataType, dataAttribute); theOutput = denseConst; diff --git a/frontends/numpy-scipy/README.md b/frontends/numpy-scipy/README.md index a527bb83..03cecca3 100644 --- a/frontends/numpy-scipy/README.md +++ b/frontends/numpy-scipy/README.md @@ -7,9 +7,9 @@ Latest Version : v0.2 Requirements: 1. COMET backend installed on your computer -2. python 3.6 and above +2. python 3.8 and above 3. numpy -4. scipy >=1.9 +4. scipy >=1.14 5. jinja2 diff --git a/frontends/numpy-scipy/cometpy/MLIRGen/builders.py b/frontends/numpy-scipy/cometpy/MLIRGen/builders.py index 12b3b9a6..29c108c4 100644 --- a/frontends/numpy-scipy/cometpy/MLIRGen/builders.py +++ b/frontends/numpy-scipy/cometpy/MLIRGen/builders.py @@ -238,6 +238,43 @@ def build_op(self): ) +class ScalarOp_Builder: + indentation_size = 4 + + scalar_op_wrapper_text = jinja2.Template ( + ("" * indentation_size) + +'%t{{dest}} = "ta.scalar"({{operators}})' + +' <{op = "{{op}}"}> ' + +' : ({{inputtype}})' + +' -> ({{outputtype}})' + + "\n", + undefined=jinja2.StrictUndefined, + ) + + def __init__(self, data): + + self.dest = data["out_id"] + self.operators = "{}".format(",".join("%t"+str(v) for v in data["operands"])) + self.tensors_shapes =[] + for l in data["shapes"]: + if isinstance(l, int): + self.tensors_shapes.append('f64') + else: + self.tensors_shapes.append('tensor<1xf64>') + + self.op = data["op"] + + def build_op(self): + input_type = [] + + return self.scalar_op_wrapper_text.render( + dest = self.dest, + operators = self.operators, + op = self.op, + inputtype = ",".join(self.tensors_shapes[:2]), + outputtype = self.tensors_shapes[-1], + ) + class ArithOp_Builder: formats_str = ['Dense', 'CSR', 'COO', 'CSC'] indentation_size = 4 @@ -583,10 +620,10 @@ class PrintBuilder: def __init__(self, data): #operand, input_labels, dtype, label_map): self.operand = data["operands"][0] - self.outtype = "x".join(str(v) for v in data["shapes"][0]) - if len(data["shapes"][0])==1 and data["shapes"][0][0] == 1: + if data["shapes"] == 1 or data["shapes"] == [1]: self.outtype = data["value_type"] else: + self.outtype = "x".join(str(v) for v in data["shapes"][0]) self.outtype = "tensor<{}x{}>".format(self.outtype, data["value_type"]) def build_op(self): diff --git a/frontends/numpy-scipy/cometpy/MLIRGen/lowering.py b/frontends/numpy-scipy/cometpy/MLIRGen/lowering.py index 7591df3b..8c6962b4 100644 --- a/frontends/numpy-scipy/cometpy/MLIRGen/lowering.py +++ b/frontends/numpy-scipy/cometpy/MLIRGen/lowering.py @@ -472,16 +472,16 @@ def generate_llvm_args_from_ndarrays(num_in, *ndargs): A2tile_crd = np.array([-1], dtype=np.int64) # CSR if ndarray.format == 'csr': - A1pos = np.array([ndarray.get_shape()[0]], dtype=np.int64) + A1pos = np.array([ndarray.shape[0]], dtype=np.int64) A1crd = np.array([-1], dtype=np.int64) A2pos = ndarray.indptr.astype('int64') A2crd = ndarray.indices.astype('int64') # Based on the desc_sizes array in SparseUtils.cpp:read_input_sizes_2D - # llvm_args += [*np_array_to_memref(np.array([1, 1, ndarray.get_shape()[0] + 1, ndarray.getnnz(), ndarray.getnnz(), ndarray.get_shape()[0], ndarray.get_shape()[1]], dtype='int64'))] + # llvm_args += [*np_array_to_memref(np.array([1, 1, ndarray.shape[0] + 1, ndarray.nnz, ndarray.nnz, ndarray.shape[0], ndarray.shape[1]], dtype='int64'))] # With tiles - llvm_args += [*np_array_to_memref(np.array([1, 1, 0, 0, ndarray.get_shape()[0] + 1, ndarray.getnnz(), 0, 0, ndarray.getnnz(), ndarray.get_shape()[0], ndarray.get_shape()[1]], dtype='int64'))] + llvm_args += [*np_array_to_memref(np.array([1, 1, 0, 0, ndarray.shape[0] + 1, ndarray.nnz, 0, 0, ndarray.nnz, ndarray.shape[0], ndarray.shape[1]], dtype='int64'))] # COO elif ndarray.format == 'coo': A1pos = np.array([0, ndarray.nnz], dtype=np.int64) @@ -491,21 +491,21 @@ def generate_llvm_args_from_ndarrays(num_in, *ndargs): # Based on the desc_sizes array in SparseUtils.cpp:read_input_sizes_2D - # llvm_args += [*np_array_to_memref(np.array([2, ndarray.nnz, 1, ndarray.getnnz(), ndarray.getnnz(), ndarray.get_shape()[0], ndarray.get_shape()[1]], dtype='int64'))] + # llvm_args += [*np_array_to_memref(np.array([2, ndarray.nnz, 1, ndarray.nnz, ndarray.nnz, ndarray.shape[0], ndarray.shape[1]], dtype='int64'))] # With tiles - llvm_args += [*np_array_to_memref(np.array([2, ndarray.nnz, 0, 0, 1, ndarray.getnnz(), 0, 0, ndarray.getnnz(), ndarray.get_shape()[0], ndarray.get_shape()[1]], dtype='int64'))] + llvm_args += [*np_array_to_memref(np.array([2, ndarray.nnz, 0, 0, 1, ndarray.nnz, 0, 0, ndarray.nnz, ndarray.shape[0], ndarray.shape[1]], dtype='int64'))] # CSC elif ndarray.format == 'csc': A1pos = ndarray.indptr.astype('int64') A1crd = ndarray.indices.astype('int64') - A2pos = np.array([ndarray.get_shape()[1]], dtype=np.int64) + A2pos = np.array([ndarray.shape[1]], dtype=np.int64) # Based on the desc_sizes array in SparseUtils.cpp:read_input_sizes_2D - # llvm_args += [*np_array_to_memref(np.array([ndarray.get_shape()[1] + 1, ndarray.nnz, 1, 1, ndarray.getnnz(), ndarray.get_shape()[0], ndarray.get_shape()[1]], dtype='int64'))] + # llvm_args += [*np_array_to_memref(np.array([ndarray.shape[1] + 1, ndarray.nnz, 1, 1, ndarray.nnz, ndarray.shape[0], ndarray.shape[1]], dtype='int64'))] # With tiles - llvm_args += [*np_array_to_memref(np.array([ndarray.get_shape()[1] + 1, ndarray.nnz, 0, 0, 1, 1, 0, 0, ndarray.getnnz(), ndarray.get_shape()[0], ndarray.get_shape()[1]], dtype='int64'))] + llvm_args += [*np_array_to_memref(np.array([ndarray.shape[1] + 1, ndarray.nnz, 0, 0, 1, 1, 0, 0, ndarray.nnz, ndarray.shape[0], ndarray.shape[1]], dtype='int64'))] Aval = ndarray.data.astype('float64') # Based on the desc_A1pos/crd, desc_A2pos/crd, desc_Aval arrays in SparseUtils.cpp: read_input_2D diff --git a/frontends/numpy-scipy/cometpy/comet.py b/frontends/numpy-scipy/cometpy/comet.py index 5c5ca68f..3e917976 100644 --- a/frontends/numpy-scipy/cometpy/comet.py +++ b/frontends/numpy-scipy/cometpy/comet.py @@ -198,7 +198,8 @@ def visit_FunctionDef(self, node): self.tsemantics[self.tcurr] = { 'shape': list(self.inputs[i].shape), 'format': format, - 'dimsSSA': [self.get_index_constant(d) for d in self.inputs[i].shape] + 'dimsSSA': [self.get_index_constant(d) for d in self.inputs[i].shape], + 'scalar': False, } self.declarations.append( { @@ -321,9 +322,11 @@ def visit_Name(self, node: ast.Name): def visit_Constant(self, node: Constant) : out_id = self.tcurr + self.tsemantics[self.tcurr] = {'shape': [1,], 'format': DENSE, 'scalar': True} self.declarations.append( { "type": "V", + "value": f"{node.value:e}", "todo": "l", "id": out_id, }) @@ -333,6 +336,52 @@ def visit_Constant(self, node: Constant) : def create_binOp(self, node, operands, no_assign): op0_sems = self.tsemantics[operands[0]] op1_sems = self.tsemantics[operands[1]] + if op0_sems['scalar'] and op1_sems['scalar']: + op = '' + if isinstance(node.op, ast.Add): + op = '+' + elif isinstance(node.op, ast.Sub): + op = '-' + elif isinstance(node.op, ast.Mult): + op = '*' + elif isinstance(node.op, ast.Div): + op = '/' + else: + raise "Unexpected operator {}".format(node.op) + self.ops.append( + { + "op_type": "scalar", + "op": op, + "operands": operands[::-1], + "shapes": [op1_sems['shape'], op0_sems['shape'], [1,]], + "out_id": self.tcurr, + } + ) + self.tsemantics[self.tcurr] = {'shape': [1,], 'format': DENSE, 'scalar': True} + + self.tcurr += 1 + self.declarations.append({ + "type": "V", + "value": f"{0:e}", + "is_input": False, + "todo": "l", + "format": DENSE, + "shape": [1,], + # "dimsSSA": [self.get_index_constant(d) for d in op_semantics['shape']], + "id": self.tcurr, + }) + self.ops.append( + { + "op_type": "=", + "shapes": [[1,]]*2, + "lhs": self.tcurr, + "rhs": self.tcurr-1, + "beta": no_assign, + }) + self.tsemantics[self.tcurr] = self.tsemantics[self.tcurr-1] + + self.tcurr +=1 + return self.tcurr-1 format = self.sp_elw_add_sub_conversions[op0_sems['format']][op1_sems['format']] if self.tsemantics[operands[0]]['format'] != DENSE: @@ -343,19 +392,19 @@ def create_binOp(self, node, operands, no_assign): self.need_opt_comp_workspace = op0_sems['format'] == CSR and op1_sems['format'] == CSR if isinstance(node.op, ast.Add): self.ops.append( - { - "op_type": "+", - "shapes": [op_semantics['shape']] * 3, - "operands": operands, - "op_ilabels": [[self.get_next_indexlabel_with_val(d) for d in op_semantics['shape']]] * 3, - "beta": 0, - "out_id": self.tcurr, - }) + { + "op_type": "+", + "shapes": [op_semantics['shape']] * 3, + "operands": operands, + "op_ilabels": [[self.get_next_indexlabel_with_val(d) for d in op_semantics['shape']]] * 3, + "beta": 0, + "out_id": self.tcurr, + }) for d in op_semantics['shape']: self.reset_indexlabel_with_val(d) - self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'], 'format': format, 'dimsSSA': [self.get_index_constant(d) for d in op_semantics['shape']]} + self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'], 'format': format, 'dimsSSA': [self.get_index_constant(d) for d in op_semantics['shape']], 'scalar': False} # if not no_assign: self.tcurr += 1 self.declarations.append({ @@ -389,7 +438,7 @@ def create_binOp(self, node, operands, no_assign): for d in op_semantics['shape']: self.reset_indexlabel_with_val(d) - self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'], 'format': format, 'dimsSSA': [self.get_index_constant(d) for d in op_semantics['shape']]} + self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'], 'format': format, 'dimsSSA': [self.get_index_constant(d) for d in op_semantics['shape']], 'scalar': False} # if not no_assign: self.tcurr += 1 self.declarations.append({ @@ -428,7 +477,7 @@ def create_binOp(self, node, operands, no_assign): format = self.sp_elw_mult_conversions[op0_sems['format']][op1_sems['format']] for d in op_semantics['shape']: self.reset_indexlabel_with_val(d) - self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'], 'format': format, 'dimsSSA': [self.get_index_constant(d) for d in op_semantics['shape']]} + self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'], 'format': format, 'dimsSSA': [self.get_index_constant(d) for d in op_semantics['shape']], 'scalar': False} # if not no_assign: self.tcurr += 1 @@ -487,7 +536,7 @@ def visit_Method_Call(self, node: Call, obj): if node.func.attr == "transpose": out_format = self.sp_mattr_conversions[op_semantics['format']] - self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'][::-1], 'format': out_format, 'dimsSSA': [self.get_index_constant(d) for d in op_semantics['shape'][::-1]]} + self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'][::-1], 'format': out_format, 'dimsSSA': [self.get_index_constant(d) for d in op_semantics['shape'][::-1]], 'scalar': False} in_ilabels = [self.get_next_indexlabel_with_val(d) for d in op_semantics['shape']] for d in op_semantics['shape']: @@ -529,7 +578,7 @@ def visit_Method_Call(self, node: Call, obj): self.tsemantics[self.tcurr] = self.tsemantics[self.tcurr-1] elif node.func.attr == "sum": - self.tsemantics[self.tcurr] = {'shape': [1,], 'format': DENSE} + self.tsemantics[self.tcurr] = {'shape': 1, 'format': DENSE, 'scalar': True} self.ops.append( { "op_type": "s", @@ -538,13 +587,13 @@ def visit_Method_Call(self, node: Call, obj): "out_id": self.tcurr, }) # if not no_assign: - self.declarations.append( - { - "type": "V", - "todo": "l", - "format": DENSE, - "id": self.tcurr, - }) + # self.declarations.append( + # { + # "type": "V", + # "todo": "l", + # "format": DENSE, + # "id": self.tcurr, + # }) elif node.func.attr == "multiply": op1 = NewVisitor.visit(self, node.args[0]) op1_sems = self.tsemantics[op1] @@ -566,7 +615,7 @@ def visit_Method_Call(self, node: Call, obj): }) format = self.sp_elw_mult_conversions[op_semantics['format']][op1_sems['format']] - self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'], 'format': format, 'dimsSSA': [self.get_index_constant(d) for d in op_semantics['shape'][::-1]]} + self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'], 'format': format, 'dimsSSA': [self.get_index_constant(d) for d in op_semantics['shape'][::-1]], 'scalar': False} # if not no_assign: self.tcurr += 1 self.declarations.append( @@ -690,7 +739,7 @@ def visit_Bin_Einsum_Call(self, operands, llabels, mask,semiring, beta, no_assig "op_ilabels": [in_ilabels, out_ilabels] }) - self.tsemantics[self.tcurr] = {'shape': shape, 'format': format, 'dimsSSA': [self.get_index_constant(d) for d in shape]} + self.tsemantics[self.tcurr] = {'shape': shape, 'format': format, 'dimsSSA': [self.get_index_constant(d) for d in shape], 'scalar': False} if not no_assign: self.tcurr += 1 self.declarations.append( @@ -775,7 +824,7 @@ def visit_Einsum_Call(self, node: Call): # ("*", operands, indices+','+indices+'->'+indices, self.tcurr, semiring)) format = self.sp_elw_mult_conversions[op0_sems['format']][op1_sems['format']] - self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'], 'format': format, 'dimsSSA': [self.get_index_constant(d) for d in op_semantics['shape']] } + self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'], 'format': format, 'dimsSSA': [self.get_index_constant(d) for d in op_semantics['shape']], 'scalar': False } self.tcurr += 1 self.declarations.append( { @@ -856,7 +905,10 @@ def wrapper(*pos_args, **kwargs): v.visit(parsed_func) in_types = [] for arg in v.in_args: - in_types.append(("%t"+str(arg), "tensor<{}xf64>".format("x".join(str(d) for d in v.tsemantics[arg]['shape'])))) + if isinstance(v.tsemantics[arg]['shape'], int): + in_types.append(("%t"+str(arg), "tensor<1xf64>")) + else: + in_types.append(("%t"+str(arg), "tensor<{}xf64>".format("x".join(str(d) for d in v.tsemantics[arg]['shape'])))) irb = builders.MLIRFunctionBuilder( func_def.name, input_types=in_types, @@ -868,6 +920,7 @@ def wrapper(*pos_args, **kwargs): dense_tensors = [] + scalars = [] for dec in v.declarations: if dec["type"] == "T": @@ -879,11 +932,17 @@ def wrapper(*pos_args, **kwargs): irb.add_statement(t.build_tensor()) elif dec["type"] == "C": irb.add_statement('%d{} = arith.constant {} : index '.format(dec["id"], dec["value"])) + elif dec["type"] == "V": + scalars.append('%t{} = ta.constant dense<{}> : tensor<1xf64> '.format(dec["id"], dec["value"])) + # irb.add_statement('%t{} = ta.constant dense<{}> : tensor<1xf64> '.format(dec["id"], dec["value"])) for t in dense_tensors: irb.add_statement(t.build_tensor()) + for t in scalars: + irb.add_statement(t) + for op in v.ops: if op["op_type"] == 'c': op["formats"] = [v.tsemantics[t]['format'] for t in op["operands"]] + [v.tsemantics[op["out_id"]]['format']] @@ -893,6 +952,8 @@ def wrapper(*pos_args, **kwargs): else: op["mask"] = (op["mask"][0], op["mask"][1], None) irb.add_statement(builders.ArithOp_Builder(op).build_op()) + elif op["op_type"] == 'scalar': + irb.add_statement(builders.ScalarOp_Builder(op).build_op()) elif op["op_type"] == 's': irb.add_statement(builders.TensorSumBuilder(op).build_op()) elif op["op_type"] == 'p': diff --git a/frontends/numpy-scipy/integration_tests/ops/scalar.py b/frontends/numpy-scipy/integration_tests/ops/scalar.py new file mode 100644 index 00000000..b5643c7b --- /dev/null +++ b/frontends/numpy-scipy/integration_tests/ops/scalar.py @@ -0,0 +1,30 @@ +import time +import numpy as np +import scipy as sp +from cometpy import comet + +def run_numpy(): + a = 5 + 1 + b = a + 5 + 1 + c = b / 2 + d = c * 3 + e = d - 1 + + return e + +@comet.compile(flags=None) +def run_comet_with_jit(): + a = 5 + 1 + b = a + 5 + 1 + c = b / 2 + d = c * 3 + e = d - 1 + + return e + +expected_result = run_numpy() +result_with_jit = run_comet_with_jit() +if sp.sparse.issparse(expected_result): + expected_result = expected_result.todense() + result_with_jit = result_with_jit.todense() +np.testing.assert_almost_equal(result_with_jit, expected_result) \ No newline at end of file diff --git a/frontends/numpy-scipy/setup.py b/frontends/numpy-scipy/setup.py index baf484df..b637b053 100644 --- a/frontends/numpy-scipy/setup.py +++ b/frontends/numpy-scipy/setup.py @@ -54,7 +54,7 @@ install_requires=[ 'jinja2', 'numpy', - 'scipy>=1.10' + 'scipy>=1.14' ], python_requires=">=3.8", ) \ No newline at end of file diff --git a/include/comet/Dialect/TensorAlgebra/IR/TAOps.td b/include/comet/Dialect/TensorAlgebra/IR/TAOps.td index 54ff6abd..40398645 100644 --- a/include/comet/Dialect/TensorAlgebra/IR/TAOps.td +++ b/include/comet/Dialect/TensorAlgebra/IR/TAOps.td @@ -260,7 +260,7 @@ def DenseConstantOp : TA_Op<"constant", [Pure]> { let results = (outs F64Tensor); /// Indicate that the operation has a custom parser and printer method. - let hasCustomAssemblyFormat = 1; + // let hasCustomAssemblyFormat = 1; let builders = [ OpBuilder<(ins "DenseElementsAttr":$value), diff --git a/integration_test/ops/scalar.ta b/integration_test/ops/scalar.ta new file mode 100644 index 00000000..7452037b --- /dev/null +++ b/integration_test/ops/scalar.ta @@ -0,0 +1,15 @@ +# RUN: comet-opt --convert-to-loops --convert-to-llvm %s &> scalars.llvm +# RUN: mlir-cpu-runner scalars.llvm -O3 -e main -entry-point-result=void -shared-libs=%comet_utility_library_dir/libcomet_runner_utils%shlibext | FileCheck %s + +def main() { + var a = 5 + 1; + var b = a + 5 + 1; + var c = b / 2; + var d = c * 3; + var e = d - 1; + print(e); +} + +# Print the result for verification. +# CHECK: data = +# CHECK-NEXT: 17, \ No newline at end of file diff --git a/lib/Conversion/TensorAlgebraToSCF/EarlyLowering.cpp b/lib/Conversion/TensorAlgebraToSCF/EarlyLowering.cpp index f3400add..150cdf01 100644 --- a/lib/Conversion/TensorAlgebraToSCF/EarlyLowering.cpp +++ b/lib/Conversion/TensorAlgebraToSCF/EarlyLowering.cpp @@ -254,6 +254,7 @@ namespace tensorAlgebra::TensorSetOp, tensorAlgebra::IndexLabelOp, tensorAlgebra::DenseConstantOp, + tensorAlgebra::ScalarOp, tensorAlgebra::SparseTensorConstructOp>(); if (failed(applyPartialConversion(function, target, std::move(patterns)))) diff --git a/lib/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.cpp b/lib/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.cpp index f09ae845..09dd38fb 100644 --- a/lib/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.cpp +++ b/lib/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.cpp @@ -104,6 +104,9 @@ namespace comet_vdump(alloc_op); alloc = alloc_op; } + else { + alloc = rewriter.create(loc, memRefType); + } } else { @@ -114,55 +117,64 @@ namespace /// Create these constants up-front to avoid large amounts of redundant /// operations. auto valueShape = memRefType.getShape(); - SmallVector constantIndices; - - if (!valueShape.empty()) + auto constTensor = op.getValue().getType().cast(); + if(constTensor.getRank() == 1 && constTensor.getDimSize(0) == 1) { - for (auto i : llvm::seq( - 0, *std::max_element(valueShape.begin(), valueShape.end()))) - constantIndices.push_back(rewriter.create(loc, i)); + auto float_attr = *constantValue.getValues().begin(); + auto f_val = float_attr.getValue(); + auto val = rewriter.create(op->getLoc(), f_val, rewriter.getF64Type()); + rewriter.create(op->getLoc(), ValueRange(val), ValueRange(alloc)); } - else + else { - /// This is the case of a tensor of rank 0. - constantIndices.push_back(rewriter.create(loc, 0)); - } + SmallVector constantIndices; - /// The constant operation represents a multi-dimensional constant, so we - /// will need to generate a store for each of the elements. The following - /// functor recursively walks the dimensions of the constant shape, - /// generating a store when the recursion hits the base case. - SmallVector indices; - auto valueIt = constantValue.getValues().begin(); - std::function storeElements = [&](uint64_t dimension) - { - /// The last dimension is the base case of the recursion, at this point - /// we store the element at the given index. - if (dimension == valueShape.size()) + if (!valueShape.empty()) { - rewriter.create( - loc, rewriter.create(loc, *valueIt++), alloc, - llvm::ArrayRef(indices)); - return; + for (auto i : llvm::seq( + 0, *std::max_element(valueShape.begin(), valueShape.end()))) + constantIndices.push_back(rewriter.create(loc, i)); } - - /// Otherwise, iterate over the current dimension and add the indices to - /// the list. - for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) + else { - indices.push_back(constantIndices[i]); - storeElements(dimension + 1); - indices.pop_back(); + /// This is the case of a tensor of rank 0. + constantIndices.push_back(rewriter.create(loc, 0)); } - }; + /// The constant operation represents a multi-dimensional constant, so we + /// will need to generate a store for each of the elements. The following + /// functor recursively walks the dimensions of the constant shape, + /// generating a store when the recursion hits the base case. + SmallVector indices; + auto valueIt = constantValue.getValues().begin(); + std::function storeElements = [&](uint64_t dimension) + { + /// The last dimension is the base case of the recursion, at this point + /// we store the element at the given index. + if (dimension == valueShape.size()) + { + rewriter.create( + loc, rewriter.create(loc, *valueIt++), alloc, + llvm::ArrayRef(indices)); + return; + } - /// Start the element storing recursion from the first dimension. - storeElements(/*dimension=*/0); + /// Otherwise, iterate over the current dimension and add the indices to + /// the list. + for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) + { + indices.push_back(constantIndices[i]); + storeElements(dimension + 1); + indices.pop_back(); + } + }; + + /// Start the element storing recursion from the first dimension. + storeElements(/*dimension=*/0); + } /// Replace this operation with the generated alloc. - op.replaceAllUsesWith(alloc); + op->replaceAllUsesWith(rewriter.create(op->getLoc(),alloc)); rewriter.eraseOp(op); - comet_debug() << "ConstantOpLowering ends\n"; return success(); } @@ -621,25 +633,33 @@ namespace comet_vdump(rhs); comet_vdump(lhs); - auto rhsType = op->getOperand(0).getType(); - auto lhsType = op->getOperand(1).getType(); + // auto rh = op->getOperand(0).getType(); + // auto lh = op->getOperand(1).getType(); [[maybe_unused]] auto f64Type = rewriter.getF64Type(); Value const_index_0 = rewriter.create(loc, 0); comet_vdump(const_index_0); std::vector alloc_zero_loc = {const_index_0}; - if (rhsType.isa()) + if (auto toTensorOp = llvm::dyn_cast_if_present(rhs.getDefiningOp())) { - comet_debug() << "RHS is a tensor\n"; - rhs = rewriter.create(loc, rhs, alloc_zero_loc); - comet_vdump(rhs); + rhs = toTensorOp.getMemref(); + // comet_debug() << "RHS is a tensor\n"; + // rhs = rewriter.create(loc, rhs, alloc_zero_loc); + // comet_vdump(rhs); } - if (lhsType.isa()) + if (auto toTensorOp = llvm::dyn_cast_if_present(lhs.getDefiningOp())) { - comet_debug() << "LHS is a tensor\n"; - lhs = rewriter.create(loc, lhs, alloc_zero_loc); + lhs = toTensorOp.getMemref(); + // comet_debug() << "RHS is a tensor\n"; + // rhs = rewriter.create(loc, rhs, alloc_zero_loc); + // comet_vdump(rhs); } + // if (lhsType.isa()) + // { + // comet_debug() << "LHS is a tensor\n"; + // lhs = rewriter.create(loc, lhs, alloc_zero_loc); + // } Value res; bool res_comes_from_setop = false; @@ -649,7 +669,18 @@ namespace comet_pdump(u); if (isa(u)) { + // u->dump(); + // u->getBlock()->dump(); res = cast(u).getOperation()->getOperand(1); + // (++res.getUsers().begin())->dump(); + if(!res.getUsers().empty() && isa(*(++res.getUsers().begin()))) + { + res = cast(*(++res.getUsers().begin())).getRhs(); + } + if(auto toTensor = mlir::dyn_cast_or_null(res.getDefiningOp())) + { + res = toTensor.getMemref(); + } comet_debug() << "Result from SetOp:\n"; comet_vdump(res); res_comes_from_setop = true; @@ -672,19 +703,23 @@ namespace Value res_val; if (op_attr.compare("+") == 0) { - res_val = rewriter.create(loc, rhs, lhs); + rewriter.create(loc, ValueRange{lhs, rhs}, ValueRange(res)); + // res_val = rewriter.create(loc, lhs, rhs); } else if (op_attr.compare("-") == 0) { - res_val = rewriter.create(loc, lhs, rhs); + rewriter.create(loc, ValueRange{lhs, rhs}, ValueRange(res)); + // res_val = rewriter.create(loc, lhs, rhs); } else if (op_attr.compare("*") == 0) { - res_val = rewriter.create(loc, rhs, lhs); + rewriter.create(loc, ValueRange{lhs, rhs}, ValueRange(res)); + // res_val = rewriter.create(loc, lhs, rhs); } else if (op_attr.compare("/") == 0) { - res_val = rewriter.create(loc, lhs, rhs); + rewriter.create(loc, ValueRange{lhs, rhs}, ValueRange(res)); + // res_val = rewriter.create(loc, lhs, rhs); } else { @@ -693,7 +728,8 @@ namespace comet_vdump(res_val); /// store res_val to res - [[maybe_unused]] auto storeOp = rewriter.create(loc, res_val, res, alloc_zero_loc); + // rewriter.create(loc, res_val, res); + // [[maybe_unused]] auto storeOp = rewriter.create(loc, res_val, res, alloc_zero_loc); comet_vdump(storeOp); op.replaceAllUsesWith(res); diff --git a/lib/Dialect/TensorAlgebra/IR/TADialect.cpp b/lib/Dialect/TensorAlgebra/IR/TADialect.cpp index 48402fa4..7eba8c86 100644 --- a/lib/Dialect/TensorAlgebra/IR/TADialect.cpp +++ b/lib/Dialect/TensorAlgebra/IR/TADialect.cpp @@ -60,26 +60,26 @@ void DenseConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &stat /// or `false` on success. This allows for easily chaining together a set of /// parser rules. These rules are used to populate an `mlir::OperationState` /// similarly to the `build` methods described above. -mlir::ParseResult DenseConstantOp::parse(mlir::OpAsmParser &parser, - mlir::OperationState &result) -{ - mlir::DenseElementsAttr value; - if (parser.parseOptionalAttrDict(result.attributes) || - parser.parseAttribute(value, "value", result.attributes)) - return failure(); - - result.addTypes(value.getType()); - return success(); -} - -/// The 'OpAsmPrinter' class is a stream that allows for formatting -/// strings, attributes, operands, types, etc. -void DenseConstantOp::print(mlir::OpAsmPrinter &printer) -{ - printer << " "; - printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"}); - printer << getValue(); -} +// mlir::ParseResult DenseConstantOp::parse(mlir::OpAsmParser &parser, +// mlir::OperationState &result) +// { +// mlir::DenseElementsAttr value; +// if (parser.parseOptionalAttrDict(result.attributes) || +// parser.parseAttribute(value, "value", result.attributes)) +// return failure(); + +// result.addTypes(value.getType()); +// return success(); +// } + +// /// The 'OpAsmPrinter' class is a stream that allows for formatting +// /// strings, attributes, operands, types, etc. +// void DenseConstantOp::print(mlir::OpAsmPrinter &printer) +// { +// printer << " "; +// printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"}); +// printer << getValue(); +// } /// Verifier for the constant operation. This corresponds to the /// `let hasVerifier = 1` in the op definition. @@ -96,9 +96,16 @@ mlir::LogicalResult DenseConstantOp::verify() auto attrType = getValue().getType().cast(); if (attrType.getRank() != resultType.getRank()) { - return emitOpError("return type must match the one of the attached value " - "attribute: ") - << attrType.getRank() << " != " << resultType.getRank(); + if(!(attrType.getRank() == 1 && attrType.getDimSize(0) == 1)) + { + return emitOpError("return type must match the one of the attached value " + "attribute: ") + << attrType.getRank() << " != " << resultType.getRank(); + } + else + { + return mlir::success(); + } } /// Check that each of the dimensions match between the two types. diff --git a/lib/Dialect/TensorAlgebra/Transforms/TensorDeclLowering.cpp b/lib/Dialect/TensorAlgebra/Transforms/TensorDeclLowering.cpp index 90d198d9..df5cc2c5 100644 --- a/lib/Dialect/TensorAlgebra/Transforms/TensorDeclLowering.cpp +++ b/lib/Dialect/TensorAlgebra/Transforms/TensorDeclLowering.cpp @@ -1700,6 +1700,7 @@ namespace tensorAlgebra::SparseOutputTensorDeclOp, tensorAlgebra::TempSparseOutputTensorDeclOp, tensorAlgebra::IndexLabelOp, + tensorAlgebra::ScalarOp, tensorAlgebra::SparseTensorConstructOp>(); if (failed(applyPartialConversion(function, target, std::move(patterns)))) @@ -1793,6 +1794,7 @@ namespace tensorAlgebra::IndexLabelOp, tensorAlgebra::DenseConstantOp, tensorAlgebra::TensorDimOp, + tensorAlgebra::ScalarOp, func::CallOp>(); if (failed(applyPartialConversion(function, target, std::move(patterns)))) @@ -1842,6 +1844,7 @@ namespace tensorAlgebra::IndexLabelOp, tensorAlgebra::DenseConstantOp, tensorAlgebra::TensorDimOp, + tensorAlgebra::ScalarOp, func::CallOp>(); if (failed(applyPartialConversion(function, target, std::move(patterns))))