Skip to content

Commit

Permalink
Added new notation for masking in numpy-scipy interface
Browse files Browse the repository at this point in the history
  • Loading branch information
pthomadakis committed Oct 22, 2023
1 parent c6d4940 commit 5a5a521
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 14 deletions.
54 changes: 44 additions & 10 deletions frontends/numpy-scipy/comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class NewVisitor(ast.NodeVisitor):
uniqueLabels = []
in_args = []
returns = []
mask = None
need_opt_comp_workspace = False
def __init__(self,inputs):
self.inputs = inputs
Expand Down Expand Up @@ -172,9 +173,16 @@ def visit_FunctionDef(self, node):
NewVisitor.visit(self, stmt)

def visit_Assign(self, node):
vals = [NewVisitor.visit(self, node.value)]
for l,v in zip(node.targets, vals):
self.tsymbols[l.id] = v
if isinstance(node.targets[0], ast.Subscript): # We do not support multiple targets currently
id = node.targets[0].value.id
mask = node.targets[0].slice
self.mask = mask
v = NewVisitor.visit(self, node.value)
self.tsymbols[id] = v
else:
vals = [NewVisitor.visit(self, node.value)]
for l,v in zip(node.targets, vals):
self.tsymbols[l.id] = v

def visit_Call(self, node: Call) -> Any:
obj = None
Expand Down Expand Up @@ -245,6 +253,20 @@ def visit_BinOp(self, node: BinOp) -> Any:
self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'], 'labels': op_semantics['labels'], 'format': format}
self.declarations.append(('d', 'T', 'l', self.tcurr))
elif isinstance(node.op, ast.MatMult):
mask = (None,None)
if self.mask != None:
mask = (NewVisitor.visit(self, self.mask.elts[0]), self.mask.elts[1].value)
mask_sems = self.tsemantics[mask[0]]
if 'labels' not in mask_sems:
labels = []
for d in mask_sems['shape']:
self.iLabelsToVals[self.icurr] = (d, mask_sems['format'])
labels.append(self.icurr)
self.icurr += 1
mask_sems['labels'] = labels
self.mask = None


op1 = self.tsemantics[operands[0]]
op2 = self.tsemantics[operands[1]]
if len(op2['shape']) > 2 or len(op1['shape'])> 2:
Expand Down Expand Up @@ -307,12 +329,12 @@ def visit_BinOp(self, node: BinOp) -> Any:
# elif len(op1['shape']) == 1 and len(op2['shape']) == 1:
# shape = [1,0]
self.need_opt_comp_workspace = op1['format'] or op2['format']
self.ops.append(("c", operands, indices, self.tcurr, None, None, None))
# self.ops.append(("c", operands, iLabels, self.tcurr, mask, mask_type, semiring))
self.ops.append(("c", operands, indices, self.tcurr, mask[0], mask[1], None))
format = self.sp_matmult_conversions[op1['format']][op2['format']]
self.tsemantics[self.tcurr] = {'shape': shape, 'labels': labels, 'format': format}
self.declarations.append(('d', 'T', 'l', self.tcurr))
self.tcurr +=1

return out_id

def visit_Method_Call(self, node: Call, obj):
Expand Down Expand Up @@ -351,9 +373,20 @@ def visit_Return(self, node):

def visit_Einsum_Call(self, node: Call):
out_id = self.tcurr
mask = None
mask_type = "none"
mask = (None, None)
# mask_type = "none"
semiring = None
if self.mask != None:
mask = (NewVisitor.visit(self, self.mask.elts[0]), self.mask.elts[1].value)
mask_sems = self.tsemantics[mask[0]]
if 'labels' not in mask_sems:
labels = []
for d in mask_sems['shape']:
self.iLabelsToVals[self.icurr] = (d, mask_sems['format'])
labels.append(self.icurr)
self.icurr += 1
mask_sems['labels'] = labels
self.mask = None
iLabels = node.args[0].value
ops, res = iLabels.split('->')
ops = ops.split(',')
Expand All @@ -370,11 +403,12 @@ def visit_Einsum_Call(self, node: Call):
labels.append(self.icurr)
self.icurr += 1
mask_sems['labels'] = labels
mask_type = "pull"
mask = (mask, "pull")
elif key.arg == 'semiring':
semiring = key.value.value
elif key.arg == 'mask_type':
mask_type = key.value.value
mask = (mask[0], key.value.value)

operands = []
if len(node.args) > 1:
for arg in node.args[1:]:
Expand Down Expand Up @@ -529,7 +563,7 @@ def visit_Einsum_Call(self, node: Call):
format = self.sp_matmult_conversions[format][self.tsemantics[op]['format']]
if format != DENSE:
self.need_opt_comp_workspace = True
self.ops.append(("c", operands, iLabels, self.tcurr, mask, mask_type, semiring))
self.ops.append(("c", operands, iLabels, self.tcurr, mask[0], mask[1], semiring))
self.tsemantics[self.tcurr] = {'shape': shape, 'labels': labels, 'format': format}
self.tcurr += 1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ def run_numpy(L0,L1,L2):

@comet.compile(flags=None)
def run_comet_with_jit(L0,L1,L2):
C = comet.einsum('ij,jk->ik', L1,L2, mask=L0,mask_type="push")
C[L0,"push"] = L1 @ L2 # Performs masking. Currently, only works on a single matmul operation
#or C = comet.einsum('ij,jk->ik', L1,L2, mask=L0, mask_type="push")
#or C[L0,"push"] = comet.einsum('ij,jk->ik', L1,L2)
D = C.sum()

return D
Expand Down
6 changes: 3 additions & 3 deletions frontends/numpy-scipy/integration_tests/numpy_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

categories = ['ops', 'opts', 'kernels', 'compound_exps', 'semiring']
files = []
if not os.path.exists("./llvm/"):
if not os.path.exists("../llvm/"):
os.symlink("../../llvm", "../llvm")
if not os.path.exists("./build/"):
if not os.path.exists("../build/"):
os.symlink("../../build", "../build")

for c in categories:
Expand Down Expand Up @@ -56,4 +56,4 @@
os.unlink("./"+c+"/MLIRGen")

os.unlink("../llvm")
os.unlink("../build")
os.unlink("../build")

0 comments on commit 5a5a521

Please sign in to comment.