Commit
Merge branch 'fix-scalars' into triton-integration
pthomadakis committed Sep 5, 2024
2 parents 2d3c53c + f14c977 commit 4774f1e
Showing 13 changed files with 319 additions and 116 deletions.
19 changes: 16 additions & 3 deletions frontends/comet_dsl/mlir/MLIRGen.cpp
@@ -46,6 +46,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include <map>
#include <numeric>
#include <cstdlib> /// for random num generation
@@ -467,9 +468,21 @@ namespace
default:
comet_debug() << "ERROR: unsupported operator type: ASCII Code(" << binop.getOp() << ")\n";
}

mlir::StringAttr opAttr = builder.getStringAttr(op);
mlir::Type elementType = builder.getF64Type();
auto returnDataType = mlir::RankedTensorType::get(1, elementType);
mlir::RankedTensorType returnDataType;
if(lhs.getType().cast<mlir::RankedTensorType>().getShape() != rhs.getType().cast<mlir::RankedTensorType>().getShape())
{
returnDataType = lhs.getType().cast<mlir::RankedTensorType>();
auto bcastRhs = builder.create<DenseConstantOp>(location, returnDataType, mlir::cast<DenseConstantOp>(rhs.getDefiningOp()).getValueAttr());
comet_vdump(bcastRhs);
rhs.replaceAllUsesWith(bcastRhs);
rhs = bcastRhs;
}
else {
mlir::Type elementType = builder.getF64Type();
returnDataType = mlir::RankedTensorType::get(1, elementType);
}
comet_vdump(rhs);
comet_vdump(lhs);

@@ -481,7 +494,7 @@ namespace
comet_debug() << "creating a new variable declaration, since the user did not declare it\n";

double data = 0.0;
auto dataAttribute = mlir::DenseElementsAttr::get(returnDataType, llvm::ArrayRef(data));
auto dataAttribute = mlir::DenseElementsAttr::get(mlir::RankedTensorType::get({1}, builder.getF64Type()), llvm::ArrayRef(data));
auto denseConst = builder.create<DenseConstantOp>(location, returnDataType, dataAttribute);

theOutput = denseConst;
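The MLIRGen.cpp change above re-materializes the 1-element scalar constant on the right-hand side with the left-hand side's ranked tensor type whenever the operand shapes differ, so the binary op sees two operands of matching shape. Below is a minimal NumPy sketch of that splat-broadcast behaviour; the names `lhs`/`rhs` are illustrative, not taken from the source.

```python
# Minimal sketch (assumed semantics): when operand shapes differ, the
# 1-element constant on the rhs is rebuilt as a splat at the lhs's shape,
# mirroring the DenseConstantOp re-creation in MLIRGen.cpp above.
import numpy as np

lhs = np.arange(6.0).reshape(2, 3)   # stands in for the lhs tensor
rhs = np.array([2.0])                # stands in for the tensor<1xf64> constant

if lhs.shape != rhs.shape:
    rhs = np.full(lhs.shape, rhs[0])  # splat the constant to the lhs's shape

print(lhs + rhs)                      # elementwise op on matching shapes
```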
4 changes: 2 additions & 2 deletions frontends/numpy-scipy/README.md
@@ -7,9 +7,9 @@ Latest Version : v0.2

Requirements:
1. COMET backend installed on your computer
2. python 3.6 and above
2. python 3.8 and above
3. numpy
4. scipy >=1.9
4. scipy >=1.14
5. jinja2


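A quick way to confirm the tightened requirements before using the numpy-scipy frontend; this is a sketch, and the `packaging` module is assumed to be available alongside the listed dependencies.

```python
# Sanity check for the updated requirements (Python >= 3.8, scipy >= 1.14).
import sys
import numpy, scipy, jinja2
from packaging.version import Version

assert sys.version_info >= (3, 8), "Python 3.8 or newer is required"
assert Version(scipy.__version__) >= Version("1.14"), "scipy >= 1.14 is required"
print("numpy", numpy.__version__, "| scipy", scipy.__version__, "| jinja2", jinja2.__version__)
```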
41 changes: 39 additions & 2 deletions frontends/numpy-scipy/cometpy/MLIRGen/builders.py
@@ -238,6 +238,43 @@ def build_op(self):

)

class ScalarOp_Builder:
indentation_size = 4

scalar_op_wrapper_text = jinja2.Template (
("" * indentation_size)
+'%t{{dest}} = "ta.scalar"({{operators}})'
+' <{op = "{{op}}"}> '
+' : ({{inputtype}})'
+' -> ({{outputtype}})'
+ "\n",
undefined=jinja2.StrictUndefined,
)

def __init__(self, data):

self.dest = data["out_id"]
self.operators = "{}".format(",".join("%t"+str(v) for v in data["operands"]))
self.tensors_shapes =[]
for l in data["shapes"]:
if isinstance(l, int):
self.tensors_shapes.append('f64')
else:
self.tensors_shapes.append('tensor<1xf64>')

self.op = data["op"]

def build_op(self):
input_type = []

return self.scalar_op_wrapper_text.render(
dest = self.dest,
operators = self.operators,
op = self.op,
inputtype = ",".join(self.tensors_shapes[:2]),
outputtype = self.tensors_shapes[-1],
)

class ArithOp_Builder:
formats_str = ['Dense', 'CSR', 'COO', 'CSC']
indentation_size = 4
@@ -583,10 +620,10 @@ class PrintBuilder:

def __init__(self, data): #operand, input_labels, dtype, label_map):
self.operand = data["operands"][0]
self.outtype = "x".join(str(v) for v in data["shapes"][0])
if len(data["shapes"][0])==1 and data["shapes"][0][0] == 1:
if data["shapes"] == 1 or data["shapes"] == [1]:
self.outtype = data["value_type"]
else:
self.outtype = "x".join(str(v) for v in data["shapes"][0])
self.outtype = "tensor<{}x{}>".format(self.outtype, data["value_type"])

def build_op(self):
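To see what `ScalarOp_Builder` emits, its template can be rendered standalone with representative data. The sketch below re-creates just the Jinja template from the class added above; the sample `data` dict is hypothetical.

```python
# Standalone render of the ScalarOp_Builder template to show the emitted IR text.
import jinja2

scalar_op_wrapper_text = jinja2.Template(
    '%t{{dest}} = "ta.scalar"({{operators}})'
    ' <{op = "{{op}}"}> '
    ' : ({{inputtype}})'
    ' -> ({{outputtype}})\n',
    undefined=jinja2.StrictUndefined,
)

# Hypothetical op descriptor: %t2 = %t0 + %t1, all three 1-element f64 tensors.
data = {"out_id": 2, "operands": [0, 1], "op": "+", "shapes": [[1], [1], [1]]}

shapes = ['f64' if isinstance(s, int) else 'tensor<1xf64>' for s in data["shapes"]]
print(scalar_op_wrapper_text.render(
    dest=data["out_id"],
    operators=",".join("%t" + str(v) for v in data["operands"]),
    op=data["op"],
    inputtype=",".join(shapes[:2]),
    outputtype=shapes[-1],
))
# Expected output (roughly):
# %t2 = "ta.scalar"(%t0,%t1) <{op = "+"}> : (tensor<1xf64>,tensor<1xf64>) -> (tensor<1xf64>)
```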
16 changes: 8 additions & 8 deletions frontends/numpy-scipy/cometpy/MLIRGen/lowering.py
@@ -472,16 +472,16 @@ def generate_llvm_args_from_ndarrays(num_in, *ndargs):
A2tile_crd = np.array([-1], dtype=np.int64)
# CSR
if ndarray.format == 'csr':
A1pos = np.array([ndarray.get_shape()[0]], dtype=np.int64)
A1pos = np.array([ndarray.shape[0]], dtype=np.int64)
A1crd = np.array([-1], dtype=np.int64)
A2pos = ndarray.indptr.astype('int64')
A2crd = ndarray.indices.astype('int64')

# Based on the desc_sizes array in SparseUtils.cpp:read_input_sizes_2D

# llvm_args += [*np_array_to_memref(np.array([1, 1, ndarray.get_shape()[0] + 1, ndarray.getnnz(), ndarray.getnnz(), ndarray.get_shape()[0], ndarray.get_shape()[1]], dtype='int64'))]
# llvm_args += [*np_array_to_memref(np.array([1, 1, ndarray.shape[0] + 1, ndarray.nnz, ndarray.nnz, ndarray.shape[0], ndarray.shape[1]], dtype='int64'))]
# With tiles
llvm_args += [*np_array_to_memref(np.array([1, 1, 0, 0, ndarray.get_shape()[0] + 1, ndarray.getnnz(), 0, 0, ndarray.getnnz(), ndarray.get_shape()[0], ndarray.get_shape()[1]], dtype='int64'))]
llvm_args += [*np_array_to_memref(np.array([1, 1, 0, 0, ndarray.shape[0] + 1, ndarray.nnz, 0, 0, ndarray.nnz, ndarray.shape[0], ndarray.shape[1]], dtype='int64'))]
# COO
elif ndarray.format == 'coo':
A1pos = np.array([0, ndarray.nnz], dtype=np.int64)
@@ -491,21 +491,21 @@ def generate_llvm_args_from_ndarrays(num_in, *ndargs):

# Based on the desc_sizes array in SparseUtils.cpp:read_input_sizes_2D

# llvm_args += [*np_array_to_memref(np.array([2, ndarray.nnz, 1, ndarray.getnnz(), ndarray.getnnz(), ndarray.get_shape()[0], ndarray.get_shape()[1]], dtype='int64'))]
# llvm_args += [*np_array_to_memref(np.array([2, ndarray.nnz, 1, ndarray.nnz, ndarray.nnz, ndarray.shape[0], ndarray.shape[1]], dtype='int64'))]
# With tiles
llvm_args += [*np_array_to_memref(np.array([2, ndarray.nnz, 0, 0, 1, ndarray.getnnz(), 0, 0, ndarray.getnnz(), ndarray.get_shape()[0], ndarray.get_shape()[1]], dtype='int64'))]
llvm_args += [*np_array_to_memref(np.array([2, ndarray.nnz, 0, 0, 1, ndarray.nnz, 0, 0, ndarray.nnz, ndarray.shape[0], ndarray.shape[1]], dtype='int64'))]

# CSC
elif ndarray.format == 'csc':
A1pos = ndarray.indptr.astype('int64')
A1crd = ndarray.indices.astype('int64')
A2pos = np.array([ndarray.get_shape()[1]], dtype=np.int64)
A2pos = np.array([ndarray.shape[1]], dtype=np.int64)

# Based on the desc_sizes array in SparseUtils.cpp:read_input_sizes_2D

# llvm_args += [*np_array_to_memref(np.array([ndarray.get_shape()[1] + 1, ndarray.nnz, 1, 1, ndarray.getnnz(), ndarray.get_shape()[0], ndarray.get_shape()[1]], dtype='int64'))]
# llvm_args += [*np_array_to_memref(np.array([ndarray.shape[1] + 1, ndarray.nnz, 1, 1, ndarray.nnz, ndarray.shape[0], ndarray.shape[1]], dtype='int64'))]
# With tiles
llvm_args += [*np_array_to_memref(np.array([ndarray.get_shape()[1] + 1, ndarray.nnz, 0, 0, 1, 1, 0, 0, ndarray.getnnz(), ndarray.get_shape()[0], ndarray.get_shape()[1]], dtype='int64'))]
llvm_args += [*np_array_to_memref(np.array([ndarray.shape[1] + 1, ndarray.nnz, 0, 0, 1, 1, 0, 0, ndarray.nnz, ndarray.shape[0], ndarray.shape[1]], dtype='int64'))]

Aval = ndarray.data.astype('float64')
# Based on the desc_A1pos/crd, desc_A2pos/crd, desc_Aval arrays in SparseUtils.cpp: read_input_2D
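The lowering.py change swaps the legacy `spmatrix` accessors (`get_shape()`, `getnnz()`) for the `shape`/`nnz` attributes exposed by the newer sparse-array API, in line with the `scipy >= 1.14` requirement. A small sketch of the attributes the updated code now reads for a CSR input:

```python
# The fields generate_llvm_args_from_ndarrays now reads for a CSR input,
# using attribute access instead of the deprecated get_shape()/getnnz() methods.
import numpy as np
import scipy.sparse as sp

A = sp.csr_array(np.array([[0.0, 1.0, 0.0],
                           [2.0, 0.0, 3.0]]))

A1pos = np.array([A.shape[0]], dtype=np.int64)   # number of rows
A2pos = A.indptr.astype('int64')                 # row pointers
A2crd = A.indices.astype('int64')                # column indices
Aval  = A.data.astype('float64')                 # nonzero values

print(A.shape, A.nnz)        # (2, 3) 3
print(A1pos, A2pos, A2crd, Aval)
```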
111 changes: 86 additions & 25 deletions frontends/numpy-scipy/cometpy/comet.py
@@ -198,7 +198,8 @@ def visit_FunctionDef(self, node):
self.tsemantics[self.tcurr] = {
'shape': list(self.inputs[i].shape),
'format': format,
'dimsSSA': [self.get_index_constant(d) for d in self.inputs[i].shape]
'dimsSSA': [self.get_index_constant(d) for d in self.inputs[i].shape],
'scalar': False,
}
self.declarations.append(
{
@@ -321,9 +322,11 @@ def visit_Name(self, node: ast.Name):

def visit_Constant(self, node: Constant) :
out_id = self.tcurr
self.tsemantics[self.tcurr] = {'shape': [1,], 'format': DENSE, 'scalar': True}
self.declarations.append(
{
"type": "V",
"value": f"{node.value:e}",
"todo": "l",
"id": out_id,
})
@@ -333,6 +336,52 @@ def create_binOp(self, node, operands, no_assign):
def create_binOp(self, node, operands, no_assign):
op0_sems = self.tsemantics[operands[0]]
op1_sems = self.tsemantics[operands[1]]
if op0_sems['scalar'] and op1_sems['scalar']:
op = ''
if isinstance(node.op, ast.Add):
op = '+'
elif isinstance(node.op, ast.Sub):
op = '-'
elif isinstance(node.op, ast.Mult):
op = '*'
elif isinstance(node.op, ast.Div):
op = '/'
else:
raise "Unexpected operator {}".format(node.op)
self.ops.append(
{
"op_type": "scalar",
"op": op,
"operands": operands[::-1],
"shapes": [op1_sems['shape'], op0_sems['shape'], [1,]],
"out_id": self.tcurr,
}
)
self.tsemantics[self.tcurr] = {'shape': [1,], 'format': DENSE, 'scalar': True}

self.tcurr += 1
self.declarations.append({
"type": "V",
"value": f"{0:e}",
"is_input": False,
"todo": "l",
"format": DENSE,
"shape": [1,],
# "dimsSSA": [self.get_index_constant(d) for d in op_semantics['shape']],
"id": self.tcurr,
})
self.ops.append(
{
"op_type": "=",
"shapes": [[1,]]*2,
"lhs": self.tcurr,
"rhs": self.tcurr-1,
"beta": no_assign,
})
self.tsemantics[self.tcurr] = self.tsemantics[self.tcurr-1]

self.tcurr +=1
return self.tcurr-1

format = self.sp_elw_add_sub_conversions[op0_sems['format']][op1_sems['format']]
if self.tsemantics[operands[0]]['format'] != DENSE:
@@ -343,19 +392,19 @@ def create_binOp(self, node, operands, no_assign):
self.need_opt_comp_workspace = op0_sems['format'] == CSR and op1_sems['format'] == CSR
if isinstance(node.op, ast.Add):
self.ops.append(
{
"op_type": "+",
"shapes": [op_semantics['shape']] * 3,
"operands": operands,
"op_ilabels": [[self.get_next_indexlabel_with_val(d) for d in op_semantics['shape']]] * 3,
"beta": 0,
"out_id": self.tcurr,
})
{
"op_type": "+",
"shapes": [op_semantics['shape']] * 3,
"operands": operands,
"op_ilabels": [[self.get_next_indexlabel_with_val(d) for d in op_semantics['shape']]] * 3,
"beta": 0,
"out_id": self.tcurr,
})

for d in op_semantics['shape']:
self.reset_indexlabel_with_val(d)

self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'], 'format': format, 'dimsSSA': [self.get_index_constant(d) for d in op_semantics['shape']]}
self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'], 'format': format, 'dimsSSA': [self.get_index_constant(d) for d in op_semantics['shape']], 'scalar': False}
# if not no_assign:
self.tcurr += 1
self.declarations.append({
@@ -389,7 +438,7 @@ def create_binOp(self, node, operands, no_assign):

for d in op_semantics['shape']:
self.reset_indexlabel_with_val(d)
self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'], 'format': format, 'dimsSSA': [self.get_index_constant(d) for d in op_semantics['shape']]}
self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'], 'format': format, 'dimsSSA': [self.get_index_constant(d) for d in op_semantics['shape']], 'scalar': False}
# if not no_assign:
self.tcurr += 1
self.declarations.append({
@@ -428,7 +477,7 @@ def create_binOp(self, node, operands, no_assign):
format = self.sp_elw_mult_conversions[op0_sems['format']][op1_sems['format']]
for d in op_semantics['shape']:
self.reset_indexlabel_with_val(d)
self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'], 'format': format, 'dimsSSA': [self.get_index_constant(d) for d in op_semantics['shape']]}
self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'], 'format': format, 'dimsSSA': [self.get_index_constant(d) for d in op_semantics['shape']], 'scalar': False}
# if not no_assign:
self.tcurr += 1

@@ -487,7 +536,7 @@ def visit_Method_Call(self, node: Call, obj):

if node.func.attr == "transpose":
out_format = self.sp_mattr_conversions[op_semantics['format']]
self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'][::-1], 'format': out_format, 'dimsSSA': [self.get_index_constant(d) for d in op_semantics['shape'][::-1]]}
self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'][::-1], 'format': out_format, 'dimsSSA': [self.get_index_constant(d) for d in op_semantics['shape'][::-1]], 'scalar': False}

in_ilabels = [self.get_next_indexlabel_with_val(d) for d in op_semantics['shape']]
for d in op_semantics['shape']:
@@ -529,7 +578,7 @@ def visit_Method_Call(self, node: Call, obj):
self.tsemantics[self.tcurr] = self.tsemantics[self.tcurr-1]

elif node.func.attr == "sum":
self.tsemantics[self.tcurr] = {'shape': [1,], 'format': DENSE}
self.tsemantics[self.tcurr] = {'shape': 1, 'format': DENSE, 'scalar': True}
self.ops.append(
{
"op_type": "s",
@@ -538,13 +587,13 @@ def visit_Method_Call(self, node: Call, obj):
"out_id": self.tcurr,
})
# if not no_assign:
self.declarations.append(
{
"type": "V",
"todo": "l",
"format": DENSE,
"id": self.tcurr,
})
# self.declarations.append(
# {
# "type": "V",
# "todo": "l",
# "format": DENSE,
# "id": self.tcurr,
# })
elif node.func.attr == "multiply":
op1 = NewVisitor.visit(self, node.args[0])
op1_sems = self.tsemantics[op1]
@@ -566,7 +615,7 @@ def visit_Method_Call(self, node: Call, obj):
})

format = self.sp_elw_mult_conversions[op_semantics['format']][op1_sems['format']]
self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'], 'format': format, 'dimsSSA': [self.get_index_constant(d) for d in op_semantics['shape'][::-1]]}
self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'], 'format': format, 'dimsSSA': [self.get_index_constant(d) for d in op_semantics['shape'][::-1]], 'scalar': False}
# if not no_assign:
self.tcurr += 1
self.declarations.append(
@@ -690,7 +739,7 @@ def visit_Bin_Einsum_Call(self, operands, llabels, mask,semiring, beta, no_assig
"op_ilabels": [in_ilabels, out_ilabels]
})

self.tsemantics[self.tcurr] = {'shape': shape, 'format': format, 'dimsSSA': [self.get_index_constant(d) for d in shape]}
self.tsemantics[self.tcurr] = {'shape': shape, 'format': format, 'dimsSSA': [self.get_index_constant(d) for d in shape], 'scalar': False}
if not no_assign:
self.tcurr += 1
self.declarations.append(
@@ -775,7 +824,7 @@ def visit_Einsum_Call(self, node: Call):

# ("*", operands, indices+','+indices+'->'+indices, self.tcurr, semiring))
format = self.sp_elw_mult_conversions[op0_sems['format']][op1_sems['format']]
self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'], 'format': format, 'dimsSSA': [self.get_index_constant(d) for d in op_semantics['shape']] }
self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'], 'format': format, 'dimsSSA': [self.get_index_constant(d) for d in op_semantics['shape']], 'scalar': False }
self.tcurr += 1
self.declarations.append(
{
@@ -856,7 +905,10 @@ def wrapper(*pos_args, **kwargs):
v.visit(parsed_func)
in_types = []
for arg in v.in_args:
in_types.append(("%t"+str(arg), "tensor<{}xf64>".format("x".join(str(d) for d in v.tsemantics[arg]['shape']))))
if isinstance(v.tsemantics[arg]['shape'], int):
in_types.append(("%t"+str(arg), "tensor<1xf64>"))
else:
in_types.append(("%t"+str(arg), "tensor<{}xf64>".format("x".join(str(d) for d in v.tsemantics[arg]['shape']))))
irb = builders.MLIRFunctionBuilder(
func_def.name,
input_types=in_types,
@@ -868,6 +920,7 @@ def wrapper(*pos_args, **kwargs):


dense_tensors = []
scalars = []
for dec in v.declarations:

if dec["type"] == "T":
@@ -879,11 +932,17 @@ def wrapper(*pos_args, **kwargs):
irb.add_statement(t.build_tensor())
elif dec["type"] == "C":
irb.add_statement('%d{} = arith.constant {} : index '.format(dec["id"], dec["value"]))
elif dec["type"] == "V":
scalars.append('%t{} = ta.constant dense<{}> : tensor<1xf64> '.format(dec["id"], dec["value"]))
# irb.add_statement('%t{} = ta.constant dense<{}> : tensor<1xf64> '.format(dec["id"], dec["value"]))


for t in dense_tensors:
irb.add_statement(t.build_tensor())

for t in scalars:
irb.add_statement(t)

for op in v.ops:
if op["op_type"] == 'c':
op["formats"] = [v.tsemantics[t]['format'] for t in op["operands"]] + [v.tsemantics[op["out_id"]]['format']]
@@ -893,6 +952,8 @@ def wrapper(*pos_args, **kwargs):
else:
op["mask"] = (op["mask"][0], op["mask"][1], None)
irb.add_statement(builders.ArithOp_Builder(op).build_op())
elif op["op_type"] == 'scalar':
irb.add_statement(builders.ScalarOp_Builder(op).build_op())
elif op["op_type"] == 's':
irb.add_statement(builders.TensorSumBuilder(op).build_op())
elif op["op_type"] == 'p':
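With the scalar path in place, the tracer in comet.py can now handle functions that mix tensor reductions with plain scalar arithmetic. Below is a hypothetical usage sketch, assuming the `@comet.compile` decorator exposed by `cometpy`; the function body and names are illustrative and not part of this commit.

```python
# Hypothetical end-to-end use of the new scalar support in the numpy-scipy frontend.
# Assumes cometpy is installed and exposes comet.compile as its entry point.
import numpy as np
from cometpy import comet

@comet.compile(flags=None)
def scaled_total(A):
    s = A.sum()           # lowers to the "s" (sum) op, now marked 'scalar': True
    return s * 2.0 + 1.0  # lowers to "ta.scalar" ops via ScalarOp_Builder

A = np.ones((4, 4))
print(scaled_total(A))    # expected 33.0 if the sketch matches the real API
```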