Skip to content

Commit

Permalink
Correct the format and remove some unnecessary changes to the pass.
Browse files Browse the repository at this point in the history
  • Loading branch information
“username” committed Oct 22, 2024
1 parent 1730830 commit 8a0989d
Show file tree
Hide file tree
Showing 7 changed files with 2 additions and 121 deletions.
2 changes: 1 addition & 1 deletion examples/BuddyMatmul/linalg-transposematmulb-f32.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,4 @@ func.func @main(){
call @printMemrefF32(%printed_m5) : (memref<*xf32>) -> ()

return
}
}
54 changes: 0 additions & 54 deletions frontend/Python/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,50 +164,6 @@ def add_node(self, node: Op):
self._body.append(node)
self.node_table[node.name] = node

def check_deletenode(self, node: Op) -> bool:
    """Return True if *node* can be safely deleted from the graph.

    A node is deletable once no other node lists it as a parent,
    i.e. its ``_children`` list is empty.

    Args:
        node (Op): The node to check; must be registered in
            ``self.node_table``.

    Returns:
        bool: True if the node has no remaining users, False otherwise.

    Raises:
        KeyError: If the node is not part of this graph.
    """
    if node.name not in self.node_table:
        raise KeyError("node {0} not in graph".format(node.name))
    # A node with no users left is safe to remove.
    return len(node._children) == 0

def delete_node(self, node: Op, parents: List[Op]):
    """Remove *node* from the graph and detach it from its parents.

    Args:
        node (Op): The node to remove.
        parents (List[Op]): Parent nodes whose child lists still
            reference *node* by name.
    """
    # First drop the back-references held by every parent.
    for parent in parents:
        parent._children.remove(node.name)
    # Clear the node's own bookkeeping so nothing dangles.
    node.args.clear()
    node.kwargs.clear()
    node._children.clear()
    # Finally unregister the node from the graph structures.
    self._body.remove(node)
    self.node_table.pop(node.name)

def displace_node(self, node: Op, newnode: Op):
    """Replace *node* with *newnode* in place, rewiring all edges.

    Copies the old node's arguments, keyword arguments, tensor metadata
    and op type onto *newnode*, then renames every reference to the old
    node (in users' parent/arg lists and in parents' child lists) and
    swaps the entry in ``self._body`` and ``self.node_table``.

    Args:
        node (Op): The node being replaced; its edge lists are cleared.
        newnode (Op): The replacement; assumes it starts with empty
            parent/child lists — TODO confirm callers guarantee this.
    """
    # NOTE(review): these assign references, not copies — newnode shares
    # the args/kwargs containers with the old node afterwards.
    newnode._arguments = node.args
    newnode._keyword_arguments = node.kwargs
    newnode._tensor_meta = node.tensor_meta
    newnode._op_type = node._op_type

    # Rewire users (children): each user's parent list and arg list
    # must now reference the new node's name instead of the old one.
    for i in node._children:
        newnode.add_children(i)
    users = [self.node_table[i] for i in node._children]
    for user in users:
        user._parents[user._parents.index(node.name)]=newnode.name
        user.args[user.args.index(node.name)]=newnode.name
    node._children.clear()
    # Rewire parents: each parent's child list must point at the new name.
    for i in node._parents:
        newnode.add_parent(i)
    parents = [self.node_table[i] for i in node._parents]
    for parent in parents:
        parent._children[parent._children.index(node.name)]=newnode.name
    node._parents.clear()
    # Swap the node in the body list and re-key the node table.
    self._body[self._body.index(node)] = newnode
    self.node_table.pop(node.name)
    self.node_table[newnode.name] = newnode

def init_op_group(self):
"""
Initializes operation groups within the graph.
Expand All @@ -223,16 +179,6 @@ def init_op_group(self):
self.group_map_device[subgraph_name] = DeviceType.UNKNOW
self.op_groups[subgraph_name] = group

def check_classicfusetype(self, op: Op):
    """Check whether *op* anchors a known "classic" fusion pattern.

    Currently only recognizes a MatmulOp with a PermuteOp parent
    (transpose + matmul). If several parents are PermuteOps, the last
    one found is reported, matching the original scan order.

    Args:
        op (Op): The candidate anchor node.

    Returns:
        tuple | None: ``(permute_parent, all_parents, pattern_tag)``
        when a pattern is found, otherwise None.
    """
    found = None
    if isinstance(op, MatmulOp):
        parent_nodes = [self.node_table[str(name)] for name in op._parents]
        for candidate in parent_nodes:
            if isinstance(candidate, PermuteOp):
                # Tag must match the classicfuse_register key exactly.
                found = candidate, parent_nodes, "transpose+mamtmul2D"
    # TODO: recognize additional fusable patterns here.
    return found

def fuse_ops(self, pattern_list: List[FunctionType]):
"""
Fuse operations in the graph based on provided fusion patterns.
Expand Down
4 changes: 0 additions & 4 deletions frontend/Python/graph/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,6 @@ def __init__(self) -> None:
super().__init__()
self._op_type = OpType.ReduceType

class transpose_Matmul_fusedOp(Op):
    """Fused node standing in for a transpose followed by a 2-D matmul."""
    def __init__(self) -> None:
        super().__init__()
        # NOTE(review): ReduceType is copied from the neighboring op
        # definitions — confirm it is the intended category for a fused
        # transpose+matmul rather than a reduce-specific one.
        self._op_type = OpType.ReduceType

class GetItemOp(Op):
def __init__(self) -> None:
Expand Down
41 changes: 0 additions & 41 deletions frontend/Python/graph/transform/fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,45 +28,6 @@
# OP_TYPE_FUSABLE_BY_SPECIFIC_PASS = []
# ANCHOR_OP_TYPE = []

# Registry mapping a fusion pattern tag (as produced by
# Graph.check_classicfusetype) to the fused Op class that replaces
# the matched anchor node. Keys must match the tags byte-for-byte.
classicfuse_register = {
    "transpose+mamtmul2D": transpose_Matmul_fusedOp
}

def classic_fuse(graph: Graph):
    """Scan the graph and fuse every recognized classic pattern in place.

    Args:
        graph (Graph): The graph to transform; modified in place.

    Returns:
        None: The graph is rewritten as patterns are found.
    """
    for op in graph.body:
        pattern = graph.check_classicfusetype(op)
        # pattern is (anchor target op, its parents, pattern tag) or None.
        if pattern:
            do_classicfusion(graph, op, pattern[0], pattern[1], pattern[2])

def do_classicfusion(graph : Graph, node: Op, target : Op, parents : List[Op], pattern : str):
    """
    Fuse a matched pair of operations into one fused operation.
    Such as transpose + matmul.

    Args:
    - graph (Graph): The input graph to be simplified; modified in place.
    - node (Op): The anchor op (e.g. the matmul) to be displaced by the
      fused op.
    - target (Op): The parent op absorbed by the fusion (e.g. the
      transpose); deleted if it has no remaining users.
    - parents (List[Op]): All parents of *node* — presently unused here;
      kept for interface compatibility.
    - pattern (str): Key into classicfuse_register selecting the fused
      op class.
    Returns:
    - None: Modifies the input graph in place.
    """

    # Instantiate the fused op class registered for this pattern and
    # swap it into the graph where the anchor node was.
    fusedop = classicfuse_register.get(pattern)()
    fusedop.name = "fused"+node.name
    graph.displace_node(node,fusedop)
    # Splice out the absorbed target: drop it from the fused op's
    # args/parents and inherit the target's own args/parents instead.
    fusedop.args.pop(fusedop.args.index(target.name))
    fusedop._parents.pop(fusedop._parents.index(target.name))
    fusedop.args.extend(target.args)
    fusedop._parents.extend(target._parents)
    # The target's parents now feed the fused op directly.
    targets_parent = [graph.node_table[i] for i in target._parents]
    for i in targets_parent:
        i.add_children(fusedop.name)
    target._children.pop(target._children.index(fusedop.name))

    # Delete the absorbed target only if no other user still needs it.
    if(graph.check_deletenode(target)):
        graph.delete_node(target,targets_parent)

def simply_fuse(graph: Graph):
"""
Function to fuse all operations into one graph.
Expand All @@ -79,8 +40,6 @@ def simply_fuse(graph: Graph):
"""
new_op_group = []
device = DeviceType.UNKNOW
classic_fuse(graph)

for op in graph.body:
if isinstance(op, PlaceholderOp):
continue
Expand Down
19 changes: 0 additions & 19 deletions frontend/Python/ops/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1150,24 +1150,6 @@ def matmul_op(
op = linalg.matmul(input1, input2, outs=[matmul_result_buffer])
return op

def transpose_matmul_fused_op(
    node: transpose_Matmul_fusedOp,
    symbol_table: Dict[Tuple[str, int], ir.Operation]
):
    """Lower a fused transpose+matmul node to ``linalg.matmul_transpose_b``.

    Args:
        node: The fused graph node carrying arg names and tensor metadata.
        symbol_table: Maps (value name, result index) to emitted MLIR values.

    Returns:
        The linalg.matmul_transpose_b op, or None if either input has
        not been emitted yet.
    """
    lhs = symbol_table.get((str(node.args[0]), 0))
    rhs = symbol_table.get((str(node.args[1]), 0))
    if lhs is None or rhs is None:
        return

    # Materialize a zero-splat tensor of the output shape to serve as
    # the accumulator (outs operand) of the matmul.
    dtype = node.tensor_meta["dtype"]
    output_shape = list(node.tensor_meta["shape"])
    element_type = mlir_element_type_get(dtype)
    result_type = ir.RankedTensorType.get(output_shape, element_type)
    zero_attr = ir.DenseElementsAttr.get_splat(
        result_type, mlir_element_attr_get(dtype, 0.0)
    )
    accumulator = arith.ConstantOp(result_type, zero_attr).result
    return linalg.matmul_transpose_b(lhs, rhs, outs=[accumulator])

def transpose_op(
node: TransposeOp,
Expand Down Expand Up @@ -1986,7 +1968,6 @@ def gt_op(node: GtOp, symbol_table):

ops_registry = {
"MatmulOp": matmul_op,
"transpose_Matmul_fusedOp": transpose_matmul_fused_op,
"ArangeOp": arange_op,
"UnsqueezeOp": unsqueeze_op,
"ViewOp": view_op,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class MatMulTransposeBVecPattern : public ConversionPattern{
Value B = op->getOperand(1);
Value C = op->getOperand(2);

// Get shape of input and output
// Get shape of input and output.
ShapedType ATy = A.getType().cast<ShapedType>();
Type eleTy = ATy.getElementType();

Expand Down
1 change: 0 additions & 1 deletion tools/buddy-opt/buddy-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ void registerLowerSchePass();
void registerFuncBufferizeDynamicOffsetPass();
void registerConvertMemcpyToGPUPass();
void registerLegalizeShmemOutliningPass();
void registerMatMul_TransposeB_VecPass();
void registerMatMulTransposeBVecPass();
void registerConvertMemcpyToGPUPass();
void registerLegalizeShmemOutliningPass();
Expand Down

0 comments on commit 8a0989d

Please sign in to comment.