Skip to content

Commit

Permalink
In numpy-scipy, added support for A.multiply(B) as a match for scipy'…
Browse files Browse the repository at this point in the history
…s method #22
  • Loading branch information
pthomadakis committed Oct 22, 2023
1 parent 5a5a521 commit e1010e2
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 5 deletions.
19 changes: 18 additions & 1 deletion frontends/numpy-scipy/comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,24 @@ def visit_Method_Call(self, node: Call, obj):
self.tsemantics[out_id] = {'shape': [1,], 'format': DENSE, 'labels': []}
self.ops.append(("s", [obj], out_id))
self.declarations.append(('d', 'v', 'l', out_id))

elif node.func.attr == "multiply":
op1 = NewVisitor.visit(self, node.args[0])
op1_sems = self.tsemantics[op1]
if 'labels' not in op1_sems:
op1_sems['labels'] = op_semantics['labels']
if self.tsemantics[obj]['format'] != DENSE:
op_semantics = self.tsemantics[obj]
self.tsemantics[op1]['labels'] = op_semantics['labels']
else:
op_semantics = self.tsemantics[op1]
if self.tsemantics[op1]['format'] != DENSE:
self.tsemantics[obj]['labels'] = op_semantics['labels']
s = 'a'
indices = "".join(chr(ord(s)+i) for i in range(len(op_semantics['labels'])))
self.ops.append(("*", [obj, op1], indices+','+indices+'->'+indices, self.tcurr, None))
format = self.sp_elw_mult_conversions[op_semantics['format']][op1_sems['format']]
self.tsemantics[self.tcurr] = {'shape': op_semantics['shape'], 'labels': op_semantics['labels'], 'format': format}
self.declarations.append(('d', 'T', 'l', self.tcurr))
self.tcurr +=1

return out_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import comet

def run_numpy(A,B):
C = A * B
C = A.multiply( B)

return C

@comet.compile(flags=None)
def run_comet_with_jit(A,B):
C = A * B
C = A.multiply( B)

return C

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import comet

def run_numpy(A,B):
C = A * B
C = A.multiply( B)

return C

@comet.compile(flags=None)
def run_comet_with_jit(A,B):
C = A * B
C = A.multiply( B)

return C

Expand Down

0 comments on commit e1010e2

Please sign in to comment.