diff --git a/examples/BuddyBert/bert-main.cpp b/examples/BuddyBert/bert-main.cpp
index c75ea9d8a9..d3f0075491 100644
--- a/examples/BuddyBert/bert-main.cpp
+++ b/examples/BuddyBert/bert-main.cpp
@@ -93,7 +93,7 @@ int main() {
 
   /// Execute forward inference of the model.
   _mlir_ciface_forward(&result, &arg0, &arg1, &pureStrContainer,
-                       &attention_mask, &token_type_ids);
+                       &token_type_ids, &attention_mask);
 
   const auto inferenceEnd = std::chrono::high_resolution_clock::now();
   const std::chrono::duration<double, std::milli> inferenceTime =
diff --git a/examples/BuddyLeNet/buddy-lenet-import.py b/examples/BuddyLeNet/buddy-lenet-import.py
index 95e76de253..c787061a55 100644
--- a/examples/BuddyLeNet/buddy-lenet-import.py
+++ b/examples/BuddyLeNet/buddy-lenet-import.py
@@ -23,7 +23,6 @@
 
 import numpy as np
 import torch
-from torch._inductor.decomposition import decompositions as inductor_decomp
 
 from buddy.compiler.frontend import DynamoCompiler
 from buddy.compiler.graph import GraphDriver
@@ -39,13 +38,12 @@
 )
 
 model = LeNet()
-model = torch.load(model_path + "/lenet-model.pth")
+model = torch.load(model_path + "/lenet-model.pth", weights_only=False)
 model = model.eval()
 
 # Initialize Dynamo Compiler with specific configurations as an importer.
 dynamo_compiler = DynamoCompiler(
     primary_registry=tosa.ops_registry,
-    aot_autograd_decomposition=inductor_decomp,
 )
 
 data = torch.randn([1, 1, 28, 28])
diff --git a/examples/BuddyLlama/CMakeLists.txt b/examples/BuddyLlama/CMakeLists.txt
index a6bfc2f742..6953b7de7d 100644
--- a/examples/BuddyLlama/CMakeLists.txt
+++ b/examples/BuddyLlama/CMakeLists.txt
@@ -53,6 +53,7 @@ add_custom_command(
   COMMAND ${LLVM_TOOLS_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyLlama/subgraph0.mlir
             -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" |
           ${BUDDY_BINARY_DIR}/buddy-opt
+            -convert-elementwise-to-linalg
             -arith-expand
             -eliminate-empty-tensors
             -empty-tensor-to-alloc-tensor
diff --git a/examples/BuddyLlama/import-llama2.py b/examples/BuddyLlama/import-llama2.py
index 2903d6bd81..d893ee87f6 100644
--- a/examples/BuddyLlama/import-llama2.py
+++ b/examples/BuddyLlama/import-llama2.py
@@ -38,7 +38,7 @@
 )
 
 # Initialize the tokenizer and model from the specified model path.
-tokenizer = LlamaTokenizer.from_pretrained(model_path)
+tokenizer = LlamaTokenizer.from_pretrained(model_path, legacy=True)
 model = LlamaForCausalLM.from_pretrained(model_path, torchscript=True)
 model.config.use_cache = False
diff --git a/examples/BuddyLlama/llama-main.cpp b/examples/BuddyLlama/llama-main.cpp
index 0bfc1e5d2f..61c42f0db2 100644
--- a/examples/BuddyLlama/llama-main.cpp
+++ b/examples/BuddyLlama/llama-main.cpp
@@ -24,7 +24,7 @@
 
 using namespace buddy;
 
-constexpr size_t ParamsSize = 6755192832;
+constexpr size_t ParamsSize = 6738415680;
 constexpr size_t MaxVocabSize = 32000;
 constexpr size_t MaxTokenLength = 40;
 constexpr size_t HiddenSize = 4096;
diff --git a/examples/BuddyWhisper/whisper-main.cpp b/examples/BuddyWhisper/whisper-main.cpp
index 7d69ea3074..011b5c847e 100644
--- a/examples/BuddyWhisper/whisper-main.cpp
+++ b/examples/BuddyWhisper/whisper-main.cpp
@@ -33,7 +33,7 @@ using namespace std;
 using namespace buddy;
 using namespace dap;
 
-constexpr size_t ParamsSize = 99148800;
+constexpr size_t ParamsSize = 72593920;
 constexpr size_t MaxVocabSize = 51865;
 constexpr size_t MaxTokenLength = 448;
 
@@ -180,4 +180,4 @@ int main() {
             << std::endl;
 
   return 0;
-}
+}
\ No newline at end of file
diff --git a/frontend/Python/frontend.py b/frontend/Python/frontend.py
index 9d8c80f014..6441b23de3 100644
--- a/frontend/Python/frontend.py
+++ b/frontend/Python/frontend.py
@@ -124,6 +124,7 @@ def __init__(
             "mean.dim": MeanOp,
             "rsqrt.default": RsqrtOp,
             "mul.Tensor": MulOp,
+            "mul.Scalar": MulOp,
             "t.default": TOp,
             "mm.default": MatmulOp,
             "transpose.int": TransposeOp,
@@ -167,6 +168,10 @@ def __init__(
             "split.Tensor":SplitOp,
             "max.default":MaxOp,
             "gt.Scalar":GtOp,
+            "_scaled_dot_product_flash_attention_for_cpu.default": ScaledDotProductFlashAttentionForCpuOp,
+            "ge.Scalar": GeOp,
+            "gt.Tensor": GreaterThanOp,
+            "_unsafe_index.Tensor": UnsafeIndexOp,
         }
 
     @property
@@ -257,11 +262,26 @@ def _compile_fx(
            return for torchdynamo's call.
""" - params = { - **dict(gm.named_parameters(remove_duplicate=False)), - **dict(gm.named_buffers(remove_duplicate=False)), - } - params_flat, _ = pytree.tree_flatten(params) + # params = { + # # **dict(gm.named_parameters(remove_duplicate=False)), + # **dict(gm.named_buffers(remove_duplicate=False)), + # } + # print(len(params)) + # params_flat, _ = pytree.tree_flatten(params) + inputs_pos = [] + params_pos = [] + buffers_pos = [] + for i, node in enumerate(gm.graph.nodes): + if i >= len(inputs): + break + if not str(node).startswith("l_self"): + inputs_pos.append(i) + elif "buffer" in str(node): + buffers_pos.append(i) + else: + params_pos.append(i) + + params_flat = [inputs[i] for i in params_pos + buffers_pos] if self._verbose: print("Graph in tabular form:") @@ -271,7 +291,9 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): """Compile a FX graph in Aten/Prims IR to MLIR.""" nonlocal params_flat func_inputs = [] - for inp in _inputs[len(params_flat) :]: + for i in inputs_pos: + # for inp in _inputs[len(params_flat) :]: + inp = _inputs[i] inp_shape = inp.shape inp_dtype = self._torch_dtype_translate(str(inp.dtype)) func_inputs.append(TensorMeta(inp_shape, inp_dtype)) @@ -286,7 +308,22 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): self._func_name, self._verbose ) - for gm_node in _gm.graph.nodes: + param_nodes = [] + buffers_nodes = [] + input_nodes = [] + other_nodes = [] + for i, node in enumerate(_gm.graph.nodes): + if i in params_pos: + param_nodes.append(node) + elif i in buffers_pos: + buffers_nodes.append(node) + elif i in inputs_pos: + input_nodes.append(node) + else: + other_nodes.append(node) + gm_nodes = param_nodes + buffers_nodes + input_nodes + other_nodes + + for gm_node in gm_nodes: node_users = [] for user in gm_node.users.keys(): node_users.append(str(user)) diff --git a/frontend/Python/graph/operation.py b/frontend/Python/graph/operation.py index 0eb31fd961..c1a7b09746 100644 --- a/frontend/Python/graph/operation.py +++ b/frontend/Python/graph/operation.py @@ -534,3 +534,27 @@ class GtOp(Op): def __init__(self) -> None: super().__init__() self._op_type = OpType.ElementwiseType + + +class ScaledDotProductFlashAttentionForCpuOp(Op): + def __init__(self) -> None: + super().__init__() + self._op_type = OpType.ElementwiseType + + +class GeOp(Op): + def __init__(self) -> None: + super().__init__() + self._op_type = OpType.ElementwiseType + + +class GreaterThanOp(Op): + def __init__(self) -> None: + super().__init__() + self._op_type = OpType.BroadcastType + + +class UnsafeIndexOp(Op): + def __init__(self) -> None: + super().__init__() + self._op_type = OpType.ReshapeType diff --git a/frontend/Python/ops/linalg.py b/frontend/Python/ops/linalg.py index b561b3433a..ec6c827e6c 100644 --- a/frontend/Python/ops/linalg.py +++ b/frontend/Python/ops/linalg.py @@ -1073,6 +1073,26 @@ def mul_op( element = mlir_element_attr_get(dtype, node.args[1]) attr = ir.DenseElementsAttr.get_splat(tensor_type, element) input2 = arith.ConstantOp(tensor_type, attr).result + + input1_dtype = ir.RankedTensorType(input1.type).element_type + input2_dtype = ir.RankedTensorType(input2.type).element_type + if input1_dtype != mlir_dtype: + input1 = tosa.CastOp( + ir.RankedTensorType.get( + ir.RankedTensorType(input1.type).shape, + mlir_dtype, + ), + input1, + ) + if input2_dtype != mlir_dtype: + input2 = tosa.CastOp( + ir.RankedTensorType.get( + ir.RankedTensorType(input2.type).shape, + mlir_dtype, + ), + input2, + ) + if input1 is None or input2 is None: 
         return
 
     mul_result_tensor_type = ir.RankedTensorType.get(shape, mlir_dtype)
@@ -1211,28 +1231,51 @@ def index_op(
         return
     input1_shape = ir.RankedTensorType(input1.type).shape
     input2 = node.args[1]
+    input2_dim_sum = 0
+    for i in range(len(input2)):
+        input2_dim_sum += len(symbol_table.get((str(input2[i]), 0)).type.shape)
     output_shape = list(node.tensor_meta["shape"])
+    input_shape = input1.type.shape
     dtype = node.tensor_meta["dtype"]
     mlir_dtype = mlir_element_type_get(dtype)
     if len(input2) < len(input1_shape):
         tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype)
         output = tensor.EmptyOp(output_shape, mlir_dtype)
-        loops = ir.RankedTensorType(
-            symbol_table.get((str(input2[0]), 0)).type
-        ).shape
         generic_map = ir.AffineMap.get_permutation(
-            [i for i in range(len(output_shape))]
+            [i for i in range(max(len(output_shape), len(input_shape)))]
         )
-        input_map = [
-            ir.AffineMapAttr.get(
-                generic_map.get_submap([j for j in range(len(loops))])
+        input_map = []
+        for i in range(len(input2)):
+            input2_shape = symbol_table.get((str(input2[i]), 0)).type.shape
+            input_map.append(
+                ir.AffineMapAttr.get(
+                    generic_map.get_submap(
+                        [j for j in range(i, i + len(input2_shape))]
+                    )
+                )
+            )
+        if len(input_shape) > len(output_shape):
+            input_map.append(
+                ir.AffineMapAttr.get(
+                    generic_map.get_submap(
+                        [
+                            j
+                            for j in range(
+                                len(input_shape) - len(output_shape),
+                                len(input_shape),
+                            )
+                        ]
+                    )
+                )
             )
-            for i in range(len(input2))
-        ] + [
-            ir.AffineMapAttr.get(
-                generic_map.get_submap([j for j in range(len(output_shape))])
+        else:
+            input_map.append(
+                ir.AffineMapAttr.get(
+                    generic_map.get_submap(
+                        [j for j in range(len(output_shape))]
+                    )
+                )
             )
-        ]
         operands = [symbol_table.get((str(i), 0)) for i in input2]
         op = linalg.GenericOp(
             [tensor_type],
@@ -1241,7 +1284,7 @@
             ir.ArrayAttr.get(input_map),
             ir.ArrayAttr.get(
                 [ir.Attribute.parse("#linalg.iterator_type<parallel>")]
-                * len(output_shape)
+                * max(len(output_shape), len(input_shape))
             ),
         )
         arguments = [
@@ -1253,7 +1296,9 @@
             indexcast_op = arith.IndexCastOp(ir.IndexType.get(), i)
             block.append(indexcast_op)
             index.append(indexcast_op.result)
-        for i in range(len(loops), len(output_shape) - len(input2) + 1):
+        for i in range(
+            input2_dim_sum, max(len(input_shape), len(output_shape))
+        ):
             index_op = linalg.IndexOp(ir._i64Attr(i, None))
             block.append(index_op)
             index.append(index_op.result)
@@ -1553,6 +1598,9 @@ def softmax_op(
     if dim < 0:
         dim += len(output_shape)
     mlir_dtype = mlir_element_type_get(dtype)
+    max_vals = tosa.ReduceMaxOp(input1, dim)
+    sub_op_output = ir.RankedTensorType.get(input1.type.shape, mlir_dtype)
+    input1 = tosa.SubOp(sub_op_output, input1, max_vals)
     # tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype)
     # output = tensor.EmptyOp(output_shape, mlir_dtype)
     # op = linalg.softmax(
@@ -1781,18 +1829,23 @@ def where_op(
     input3 = symbol_table.get((str(node.args[2]), 0))
     if input1 is None or input2 is None or input3 is None:
         return
-
     output_shape = list(node.tensor_meta["shape"])
     dtype = node.tensor_meta["dtype"]
     mlir_dtype = mlir_element_type_get(dtype)
     tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype)
     output = tensor.EmptyOp(output_shape, mlir_dtype)
+
+    if not isinstance(input2.type, ir.RankedTensorType):
+        input2 = tensor.SplatOp(tensor_type, input2).result
+    if not isinstance(input3.type, ir.RankedTensorType):
+        input3 = tensor.SplatOp(tensor_type, input3).result
+
     generic_map = ir.AffineMap.get_permutation(
         [i for i in range(len(output_shape))]
     )
     op = linalg.GenericOp(
         [tensor_type],
-        [input1, input3],
+        [input1, input2, input3],
         [output],
         ir.ArrayAttr.get(
             [
@@ -1811,6 +1864,11 @@ def where_op(
                         [i for i in range(len(output_shape))]
                     )
                 ),
+                ir.AffineMapAttr.get(
+                    generic_map.get_submap(
+                        [i for i in range(len(output_shape))]
+                    )
+                ),
             ]
         ),
         ir.ArrayAttr.get(
@@ -1822,11 +1880,14 @@ def where_op(
         op.region,
         [
             ir.RankedTensorType(input1.type).element_type,
+            ir.RankedTensorType(input2.type).element_type,
             ir.RankedTensorType(input3.type).element_type,
             ir.RankedTensorType(output.result.type).element_type,
         ],
     )
-    select_op = arith.SelectOp(block.arguments[0], input2, block.arguments[1])
+    select_op = arith.SelectOp(
+        block.arguments[0], block.arguments[1], block.arguments[2]
+    )
     block.append(select_op)
     block.append(linalg.YieldOp([select_op.result]))
 
@@ -1966,6 +2027,321 @@ def gt_op(node: GtOp, symbol_table):
     return cmp_op
 
 
+def ge_op(
+    node: GeOp,
+    symbol_table: Dict[Tuple[str, int], ir.Operation],
+):
+    """
+    Import the tensor greater-equal operation.
+    From buddy GeOp to MLIR arith `cmpi`/`cmpf` operation.
+    Note: This op compares the input tensor with a scalar and outputs a bool
+    tensor representing the comparison result.
+    Args:
+        node: Containing information from the input graph node.
+        symbol_table: A dictionary mapping symbols to their corresponding
+        operations.
+    Returns:
+        op: The comparison operation.
+    """
+    input_tensor = symbol_table.get((str(node.args[0]), 0), node.args[0])
+    input_dtype = ir.RankedTensorType(input_tensor.type).element_type
+    input_shape = ir.RankedTensorType(input_tensor.type).shape
+    tensor_type = ir.RankedTensorType.get(input_shape, input_dtype)
+
+    scalar = arith.ConstantOp(input_dtype, node.args[1])
+    rhs = tensor.SplatOp(tensor_type, scalar)
+
+    # Predicate 5 is "sge" for cmpi; predicate 3 is "oge" for cmpf.
+    if str(input_dtype).find("i") != -1:
+        cmp_op = arith.CmpIOp(5, input_tensor, rhs)
+    else:
+        cmp_op = arith.CmpFOp(3, input_tensor, rhs)
+
+    return cmp_op
+
+
+def greater_than_op(
+    node: GreaterThanOp,
+    symbol_table: Dict[Tuple[str, int], ir.Operation],
+):
+    """
+    Import the tensor greater-than operation.
+    From buddy GreaterThanOp to MLIR linalg `generic` operation.
+    Note: This op compares two input tensors and outputs a bool tensor
+    representing the comparison result.
+    Args:
+        node: Containing information from the input graph node.
+        symbol_table: A dictionary mapping symbols to their corresponding
+        operations.
+    Returns:
+        op: The operation returns the linalg.generic op.
+ """ + input1 = symbol_table.get((str(node.args[0]), 0)) + input2 = symbol_table.get((str(node.args[1]), 0)) + output_shape = list(node.tensor_meta["shape"]) + dtype = node.tensor_meta["dtype"] + # value = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 4) + shp1 = list(ir.RankedTensorType(ir.Value(input1).type).shape) + shp2 = list(ir.RankedTensorType(ir.Value(input2).type).shape) + dtype = mlir_element_type_get(dtype) + tensor_type = ir.RankedTensorType.get(output_shape, dtype) + output = tensor.EmptyOp(output_shape, dtype) + if len(shp1) < len(shp2): + if int(shp1[-1]) > 1 and shp2[-1] == 1: + generic_map = ir.AffineMap.get_permutation( + [i for i in range(len(shp2) + 1)] + ) + op = linalg.GenericOp( + [tensor_type], + [input1, input2], + [output], + ir.ArrayAttr.get( + [ + ir.AffineMapAttr.get( + generic_map.get_submap( + [ + i + for i in range( + len(shp2) - len(shp1), len(shp2) + ) + ] + ) + ), + ir.AffineMapAttr.get( + generic_map.get_submap( + [i for i in range(0, len(shp2) - 1)] + + [len(shp2)] + ) + ), + ir.AffineMapAttr.get( + generic_map.get_submap( + [i for i in range(0, len(shp2))] + ) + ), + ] + ), + ir.ArrayAttr.get( + [ir.Attribute.parse("#linalg.iterator_type")] + * len(shp2) + + [ir.Attribute.parse("#linalg.iterator_type")] + ), + ) + block = ir.Block.create_at_start( + op.region, + [ + ir.RankedTensorType(input2.type).element_type, + ir.RankedTensorType(input2.type).element_type, + dtype, + ], + ) + if ( + str(ir.RankedTensorType(input2.type).element_type).find("i") + != -1 + ): + cmpop = arith.CmpIOp(4, block.arguments[0], block.arguments[1]) + else: + cmpop = arith.CmpFOp(2, block.arguments[0], block.arguments[1]) + block.append(cmpop) + block.append(linalg.YieldOp([cmpop.result])) + + return op + + +def unsafe_index_op( + node: UnsafeIndexOp, + symbol_table: Dict[Tuple[str, int], ir.Operation], +): + """ + Import the tensor _unsafe_index operation. + From buddy UnsafeIndexOp to MLIR linalg `generic` + operation. + Note: This op, get input node slice result by input index. + Args: + node: Containing information from the input graph node. + symbol_table: A dictionary mapping symbols to their corresponding + operations. + Returns: + op: The operation return the linalg.generic op. 
+ """ + assert len(node.args) == 2 + input1 = symbol_table.get((str(node.args[0]), 0)) + if input1 is None: + return + input1_shape = ir.RankedTensorType(input1.type).shape + input2 = node.args[1] + have_none = False + for i in input2: + if i == None: + have_none = True + break + input2_dim_sum = 0 + for i in range(len(input2)): + input2_dim_sum += ( + len(symbol_table.get((str(input2[i]), 0)).type.shape) + if input2[i] != None + else 0 + ) + output_shape = list(node.tensor_meta["shape"]) + input_shape = input1.type.shape + dtype = node.tensor_meta["dtype"] + mlir_dtype = mlir_element_type_get(dtype) + if len(input2) < len(input1_shape): + tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype) + output = tensor.EmptyOp(output_shape, mlir_dtype) + generic_map = ir.AffineMap.get_permutation( + [i for i in range(max(len(output_shape), len(input_shape)))] + ) + input_map = [] + for i in range(len(input2)): + input2_shape = symbol_table.get((str(input2[i]), 0)).type.shape + input_map.append( + ir.AffineMapAttr.get( + generic_map.get_submap( + [j for j in range(i, i + len(input2_shape))] + ) + ) + ) + if len(input_shape) > len(output_shape): + input_map.append( + ir.AffineMapAttr.get( + generic_map.get_submap( + [ + j + for j in range( + len(input_shape) - len(output_shape), + len(input_shape), + ) + ] + ) + ) + ) + else: + input_map.append( + ir.AffineMapAttr.get( + generic_map.get_submap( + [j for j in range(len(output_shape))] + ) + ) + ) + operands = [symbol_table.get((str(i), 0)) for i in input2] + op = linalg.GenericOp( + [tensor_type], + operands, + [output], + ir.ArrayAttr.get(input_map), + ir.ArrayAttr.get( + [ir.Attribute.parse("#linalg.iterator_type")] + * max(len(output_shape), len(input_shape)) + ), + ) + arguments = [ + ir.RankedTensorType(i.type).element_type for i in operands + ] + [ir.RankedTensorType(output.result.type).element_type] + block = ir.Block.create_at_start(op.region, arguments) + index = [] + for i in block.arguments[:-1]: + indexcast_op = arith.IndexCastOp(ir.IndexType.get(), i) + block.append(indexcast_op) + index.append(indexcast_op.result) + for i in range( + input2_dim_sum, max(len(input_shape), len(output_shape)) + ): + index_op = linalg.IndexOp(ir._i64Attr(i, None)) + block.append(index_op) + index.append(index_op.result) + value = tensor.ExtractOp(input1, index) + block.append(value) + block.append(linalg.YieldOp([value.result])) + else: + tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype) + output = tensor.EmptyOp(output_shape, mlir_dtype) + generic_map = ir.AffineMap.get_permutation( + [i for i in range(max(len(output_shape), len(input_shape)))] + ) + input_map = [] + for i in range(len(input2)): + if input2[i] == None: + continue + input2_shape = symbol_table.get((str(input2[i]), 0)).type.shape + if have_none: + input_map.append( + ir.AffineMapAttr.get( + generic_map.get_submap([j for j in range(i, i + 1)]) + ) + ) + if len(input_shape) > len(output_shape): + input_map.append( + ir.AffineMapAttr.get( + generic_map.get_submap( + [ + j + for j in range( + len(input_shape) - len(output_shape), + len(input_shape), + ) + ] + ) + ) + ) + else: + input_map.append( + ir.AffineMapAttr.get( + generic_map.get_submap( + [j for j in range(len(output_shape))] + ) + ) + ) + if have_none: + operands = [] + for i in input2: + if i == None: + continue + input2_ = symbol_table.get((str(i), 0)) + input2_shape = input2_.type.shape + if i != None and len(input2_shape) > 1: + total_size = 1 + for x in input2_shape: + total_size *= x + reshape_op = 
+                        input2_, memoryview(array.array("i", [total_size]))
+                    )
+                    operands.append(reshape_op.result)
+
+        else:
+            operands = [symbol_table.get((str(i), 0)) for i in input2]
+        op = linalg.GenericOp(
+            [tensor_type],
+            operands,
+            [output],
+            ir.ArrayAttr.get(input_map),
+            ir.ArrayAttr.get(
+                [ir.Attribute.parse("#linalg.iterator_type<parallel>")]
+                * max(len(output_shape), len(input_shape))
+            ),
+        )
+        arguments = [
+            ir.RankedTensorType(i.type).element_type for i in operands
+        ] + [ir.RankedTensorType(output.result.type).element_type]
+        block = ir.Block.create_at_start(op.region, arguments)
+        index = []
+        None_count = 0
+        for i in range(len(input2)):
+            if input2[i] is None:
+                None_count += 1
+                index_op = linalg.IndexOp(ir._i64Attr(i, None))
+                block.append(index_op)
+                index.append(index_op.result)
+            else:
+                indexcast_op = arith.IndexCastOp(
+                    ir.IndexType.get(), block.arguments[i - None_count]
+                )
+                block.append(indexcast_op)
+                index.append(indexcast_op.result)
+        value = tensor.ExtractOp(input1, index)
+        block.append(value)
+        block.append(linalg.YieldOp([value.result]))
+    return op
+
+
 ops_registry = {
     "MatmulOp": matmul_op,
     "ArangeOp": arange_op,
@@ -2001,4 +2377,7 @@ def gt_op(node: GtOp, symbol_table):
     "SplitOp": split_op,
     "MaxOp": max_op,
     "GtOp": gt_op,
+    "GeOp": ge_op,
+    "GreaterThanOp": greater_than_op,
+    "UnsafeIndexOp": unsafe_index_op,
 }
diff --git a/frontend/Python/ops/tosa.py b/frontend/Python/ops/tosa.py
index 797fdfd6d2..8ba1a834ec 100644
--- a/frontend/Python/ops/tosa.py
+++ b/frontend/Python/ops/tosa.py
@@ -18,13 +18,13 @@
 #
 # ===---------------------------------------------------------------------------
 
-import array
+import array, copy
 from typing import Dict, List, Tuple, Union
 import numpy
 import sys
 
 import mlir.ir as ir
-from mlir.dialects import tensor, tosa, arith, linalg
+from mlir.dialects import tensor, tosa, arith, linalg, math
 
 from ..graph import TensorDType
 from ..graph import (
@@ -62,6 +62,7 @@
     ClampMaxOp,
     RandIntLowOp,
     ArgMaxOp,
+    ScaledDotProductFlashAttentionForCpuOp,
 )
 from .utils import *
 
@@ -273,8 +274,48 @@ def _inner_op(result_type, input1, input2):
             ir.IntegerAttr.get(ir.IntegerType.get_signless(8), 0),
         )
 
-    input1 = symbol_table.get((str(node.args[0]), 0), node.args[0])
-    input2 = symbol_table.get((str(node.args[1]), 0), node.args[1])
+    output_shape = list(node.tensor_meta["shape"])
+    dtype = node.tensor_meta["dtype"]
+    mlir_dtype = mlir_element_type_get(dtype)
+
+    if isinstance(node.args[0], str):
+        input1 = symbol_table.get((str(node.args[0]), 0), node.args[0])
+    else:
+        data = [node.args[0]]
+        input1_shape = numpy.array(data).shape
+        tensor_type = ir.RankedTensorType.get(input1_shape, mlir_dtype)
+        element = mlir_element_attr_get(dtype, node.args[0])
+        attr = ir.DenseElementsAttr.get_splat(tensor_type, element)
+        input1 = arith.ConstantOp(tensor_type, attr).result
+
+    if isinstance(node.args[1], str):
+        input2 = symbol_table.get((str(node.args[1]), 0), node.args[1])
+    else:
+        data = [node.args[1]]
+        input2_shape = numpy.array(data).shape
+        tensor_type = ir.RankedTensorType.get(input2_shape, mlir_dtype)
+        element = mlir_element_attr_get(dtype, node.args[1])
+        attr = ir.DenseElementsAttr.get_splat(tensor_type, element)
+        input2 = arith.ConstantOp(tensor_type, attr).result
+
+    input1_dtype = ir.RankedTensorType(input1.type).element_type
+    input2_dtype = ir.RankedTensorType(input2.type).element_type
+    if input1_dtype != mlir_dtype:
+        input1 = tosa.CastOp(
+            ir.RankedTensorType.get(
+                ir.RankedTensorType(input1.type).shape,
+                mlir_dtype,
+            ),
+            input1,
+        ).result
+    if input2_dtype != mlir_dtype:
+        input2 = tosa.CastOp(
+            ir.RankedTensorType.get(
+                ir.RankedTensorType(input2.type).shape,
+                mlir_dtype,
+            ),
+            input2,
+        ).result
 
     return _gen_arith_binary_op(input1, input2, _inner_op)
 
@@ -522,9 +563,67 @@ def convert_element_type_op(node: ConvertElementTypeOp, symbol_table):
     }
     input_tensor = symbol_table.get((str(node.args[0]), 0))
     to_cast_type = types_mapping[node.args[1]]
-    sizes = ir.RankedTensorType(input_tensor.type).shape
-    output_type = ir.RankedTensorType.get(sizes, to_cast_type)
-    return tosa.CastOp(output_type, input_tensor)
+    input_type = ir.RankedTensorType(input_tensor.type).element_type
+    # When converting float to int, tosa.cast lowers to math.roundeven, but we don't need rounding.
+    if str(to_cast_type).find("i") != -1 and str(input_type).find("f") != -1:
+        output_shape = list(node.tensor_meta["shape"])
+        tensor_type = ir.RankedTensorType.get(output_shape, to_cast_type)
+        output = tensor.EmptyOp(output_shape, to_cast_type)
+
+        if str(to_cast_type) == "i1":
+            false_val = arith.ConstantOp(to_cast_type, 0)
+            true_val = arith.ConstantOp(to_cast_type, 1)
+            zero_val = arith.ConstantOp(input_type, 0.0)
+
+        generic_map = ir.AffineMap.get_permutation(
+            [i for i in range(len(output_shape))]
+        )
+        op = linalg.GenericOp(
+            [tensor_type],
+            [input_tensor],
+            [output],
+            ir.ArrayAttr.get(
+                [
+                    ir.AffineMapAttr.get(
+                        generic_map.get_submap(
+                            [i for i in range(len(output_shape))]
+                        )
+                    ),
+                    ir.AffineMapAttr.get(
+                        generic_map.get_submap(
+                            [i for i in range(len(output_shape))]
+                        )
+                    ),
+                ]
+            ),
+            ir.ArrayAttr.get(
+                [ir.Attribute.parse("#linalg.iterator_type<parallel>")]
+                * len(output_shape)
+            ),
+        )
+        block = ir.Block.create_at_start(
+            op.region,
+            [
+                input_type,
+                to_cast_type,
+            ],
+        )
+        if str(to_cast_type) == "i1":
+            # Map zero to false and any non-zero value to true.
+            is_zero = arith.CmpFOp(1, block.arguments[0], zero_val)
+            result = arith.SelectOp(is_zero, false_val, true_val)
+            block.append(is_zero)
+            block.append(result)
+            block.append(linalg.YieldOp([result.result]))
+        else:
+            fptosi_op = arith.FPToSIOp(to_cast_type, block.arguments[0])
+            block.append(fptosi_op)
+            block.append(linalg.YieldOp([fptosi_op.result]))
+    else:
+        sizes = ir.RankedTensorType(input_tensor.type).shape
+        output_type = ir.RankedTensorType.get(sizes, to_cast_type)
+        op = tosa.CastOp(output_type, input_tensor)
+
+    return op
 
 
 def clone_op(node: CloneOp, symbol_table):
@@ -800,6 +899,7 @@ def expand_op(node: ExpandOp, symbol_table) -> ir.Operation:
         the result.
     """
     to_expand_tensor = symbol_table.get((str(node.args[0]), 0))
+    original_size = to_expand_tensor.type.shape
     new_size = node.args[1]
     result_element_type = ir.RankedTensorType(
         to_expand_tensor.type
     ).element_type
@@ -813,8 +913,14 @@ def expand_op(node: ExpandOp, symbol_table) -> ir.Operation:
         element = ir.FloatAttr.get(result_element_type, 0.0)
     else:
         raise NotImplementedError("Unsupported element type!")
+    # A -1 in the requested size means "keep the original dimension".
+    expanded_size = []
+    for dim, size in zip(original_size, new_size):
+        if size == -1:
+            expanded_size.append(dim)
+        else:
+            expanded_size.append(size)
     new_size_tensor_type = ir.RankedTensorType.get(
-        new_size, result_element_type
+        expanded_size, result_element_type
     )
     new_size_attr = ir.DenseElementsAttr.get_splat(
         new_size_tensor_type, element
@@ -1479,6 +1585,196 @@ def argmax_op(node: ArgMaxOp, symbol_table):
     return op
 
 
+def scaled_dot_product_flash_attention_for_cpu_op(
+    node: ScaledDotProductFlashAttentionForCpuOp, symbol_table
+):
+    """
+    Perform scaled dot-product attention computation.
+    Args:
+        node (ScaledDotProductFlashAttentionForCpuOp): The scaled dot-product
+        attention operation node with metadata.
+        symbol_table: Mapping of variable names to tensor references.
+    Returns:
+        result_reshape_op: Reshaped result tensor of the attention operation.
+        log_sumexp: Log-sum-exp of the attention scores.
+    """
+    query = symbol_table.get((str(node.args[0]), 0), node.args[0])
+    key = symbol_table.get((str(node.args[1]), 0), node.args[1])
+    value = symbol_table.get((str(node.args[2]), 0), node.args[2])
+
+    # Dropout is not supported in this lowering.
+    if len(node.args) == 4:
+        dropout_p = node.args[3]
+        assert dropout_p == 0.0
+    if len(node.args) == 5:
+        dropout_p = node.args[3]
+        is_causal = node.args[4]
+        assert dropout_p == 0.0
+        assert is_causal == True
+
+    attn_mask = node.kwargs.get("attn_mask", None)
+    scale = node.kwargs.get("scale", None)
+
+    query_shape = query.type.shape
+    key_shape = key.type.shape
+    value_shape = value.type.shape
+    output_shape = list(node.tensor_meta["shape"])
+    L, S = query_shape[-2], key_shape[-2]
+    scale_factor = (
+        1 / numpy.sqrt(query.type.shape[-1]) if scale is None else scale
+    )
+
+    # Initialize the attention bias.
+    dtype = node.tensor_meta["dtype"][0]
+    attn_bias_shape = [L, S]
+    mlir_dtype = mlir_element_type_get(dtype)
+    attn_bias_type = ir.RankedTensorType.get(attn_bias_shape, mlir_dtype)
+    zero_constant = arith.ConstantOp(mlir_dtype, 0.0)
+    attn_bias = tensor.SplatOp(attn_bias_type, zero_constant)
+    if attn_mask is not None:
+        attn_mask = symbol_table.get((str(attn_mask), 0), attn_mask)
+        if attn_mask.type.element_type == ir.IntegerType.get_signless(1):
+            # A boolean mask marks the positions to keep; flip it and fill
+            # the masked-out positions with -inf.
+            tensor_type = ir.RankedTensorType.get(
+                attn_mask.type.shape, ir.IntegerType.get_signless(1)
+            )
+            true_tensor = arith.ConstantOp(
+                tensor_type,
+                ir.DenseElementsAttr.get_splat(
+                    tensor_type, ir.BoolAttr.get(True)
+                ),
+            )
+            attn_mask = arith.XOrIOp(attn_mask, true_tensor)
+            minus_inf_tensor = arith.ConstantOp(
+                attn_mask.type,
+                ir.DenseElementsAttr.get_splat(
+                    attn_mask.type, ir.FloatAttr.get(f32_type, float("-inf"))
+                ),
+            )
+            attn_bias = tensor.SelectOp(attn_mask, minus_inf_tensor, attn_bias)
+        else:
+            if attn_mask.type.shape != attn_bias.result.type.shape:
+                attn_mask = tosa.ReshapeOp(
+                    attn_mask,
+                    memoryview(array.array("i", attn_bias.result.type.shape)),
+                )
+            attn_bias = tosa.AddOp(attn_bias.result.type, attn_bias, attn_mask)
+
+    # Transpose the key tensor.
+    key_shape = list(key.type.shape)
+    perm_list = list(range(len(key_shape)))
+    perm_list[-1], perm_list[-2] = perm_list[-2], perm_list[-1]
+    perm_const_op = tosa.ConstOp(
+        ir.DenseElementsAttr.get(memoryview(array.array("i", perm_list)))
+    )
+    perm_shape = []
+    perm_shape.append(key_shape[0])
+    perm_shape.append(key_shape[1])
+    perm_shape.append(key_shape[3])
+    perm_shape.append(key_shape[2])
+    permute_result_type = ir.RankedTensorType.get(perm_shape, mlir_dtype)
+    key = tosa.TransposeOp(
+        permute_result_type, key, perm_const_op.results[0]
+    ).result
+
+    # Matrix multiplication of query and key.
+    query_reshape_op = tosa.ReshapeOp(
+        query,
+        memoryview(
+            array.array(
+                "i",
+                [
+                    query_shape[0] * query_shape[1],
+                    query_shape[2],
+                    query_shape[3],
+                ],
+            )
+        ),
+    )
+    key_reshape_op = tosa.ReshapeOp(
+        key,
+        memoryview(
+            array.array(
+                "i", [key_shape[0] * key_shape[1], key_shape[3], key_shape[2]]
+            )
+        ),
+    )
+    matmul_result_shp = [
+        key_shape[0] * key_shape[1],
+        query_shape[2],
+        key_shape[2],
+    ]
+    matmul_result_type = ir.RankedTensorType.get(matmul_result_shp, mlir_dtype)
+    matmul_op = tosa.MatMulOp(
+        matmul_result_type, query_reshape_op.result, key_reshape_op.result
+    )
+
+    # Multiply the result by the scale factor.
+    scale_factor_constant = arith.ConstantOp(mlir_dtype, scale_factor)
+    scale_factor = tensor.SplatOp(matmul_result_type, scale_factor_constant)
+    mul_op = tosa.MulOp(
+        matmul_result_type,
+        matmul_op,
+        scale_factor,
+        ir.IntegerAttr.get(ir.IntegerType.get_signless(8), 0),
+    )
+
+    # Add the attention bias to the result.
+    add_op = tosa.AddOp(matmul_result_type, mul_op.result, attn_bias)
+
+    # Apply softmax to the result.
+    softmax_output_shape = list(add_op.result.type.shape)
+    softmax_dim = len(softmax_output_shape) - 1
+
+    # Subtract the maximum value along the dimension where softmax is applied
+    # to prevent overflow during the exp operation.
+    max_vals = tosa.ReduceMaxOp(add_op.result, softmax_dim)
+    sub_op = tosa.SubOp(add_op.result.type, add_op, max_vals)
+    exp_op = math.ExpOp(sub_op)
+    reduce_sum_op = tosa.ReduceSumOp(exp_op, softmax_dim)
+    log_op = tosa.LogOp(reduce_sum_op.result.type, reduce_sum_op)
+    log_sumexp = tosa.AddOp(max_vals.result.type, max_vals, log_op)
+    log_weights = tosa.SubOp(add_op.result.type, add_op, log_sumexp)
+    softmax_result = math.ExpOp(log_weights)
+    log_sumexp = tosa.ReshapeOp(
+        log_sumexp,
+        memoryview(
+            array.array(
+                "i",
+                output_shape[1],
+            )
+        ),
+    )
+
+    # During training, dropout would be applied at this step.
+    # Multiply the result by the value tensor.
+    value_reshape_op = tosa.ReshapeOp(
+        value,
+        memoryview(
+            array.array(
+                "i",
+                [key_shape[0] * key_shape[1], value_shape[2], value_shape[3]],
+            )
+        ),
+    )
+    matmul_result_shp = [
+        key_shape[0] * key_shape[1],
+        query_shape[2],
+        value_shape[3],
+    ]
+    matmul_result_type = ir.RankedTensorType.get(matmul_result_shp, mlir_dtype)
+    matmul_op = tosa.MatMulOp(
+        matmul_result_type, softmax_result.result, value_reshape_op.result
+    )
+
+    result_reshape_op = tosa.ReshapeOp(
+        matmul_op.result,
+        memoryview(
+            array.array(
+                "i",
+                [key_shape[0], key_shape[1], query_shape[2], value_shape[3]],
+            )
+        ),
+    )
+
+    return result_reshape_op, log_sumexp
+
+
 ops_registry = {
     "AddOp": add_op,
     "MulOp": mul_op,
@@ -1515,4 +1811,5 @@ def argmax_op(node: ArgMaxOp, symbol_table):
     "ClampMaxOp": clamp_max_op,
     "RandIntLowOp": randint_low_op,
     "ArgMaxOp": argmax_op,
+    "ScaledDotProductFlashAttentionForCpuOp": scaled_dot_product_flash_attention_for_cpu_op,
 }
diff --git a/requirements.txt b/requirements.txt
index 9818b8ec74..6b2fd250c0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,9 +1,9 @@
 --pre --extra-index-url https://download.pytorch.org/whl/cpu
-torch == 2.1.2
+torch == 2.5.1
 numpy < 2
-transformers == 4.33.1
-tokenizers == 0.13.3
-sentencepiece == 0.1.99
+transformers == 4.46.2
+tokenizers >= 0.20
+sentencepiece == 0.2.0
 accelerate
 protobuf
 pybind11 == 2.11.1
@@ -12,3 +12,6 @@ tabulate
 datasets
 soundfile
 librosa
+PyYAML
+certifi
+idna
\ No newline at end of file
diff --git a/tests/Python/test_convert_element_type.py b/tests/Python/test_convert_element_type.py
index ca88384633..cf3cc7e941 100644
--- a/tests/Python/test_convert_element_type.py
+++ b/tests/Python/test_convert_element_type.py
@@ -29,7 +29,8 @@ def foo(x, to_cast_type):
 
 # CHECK: module {
 # CHECK-LABEL: func.func @forward
-# CHECK: %{{.*}} = tosa.cast
+# CHECK: %{{.*}} = tensor.empty
+# CHECK: %{{.*}} = linalg.generic
 # CHECK: return %{{.*}}
 # CHECK: }
 # CHECK: }
diff --git a/tests/Python/test_max_pool2d.py b/tests/Python/test_max_pool2d.py
index eecfc73d93..cac892761d 100644
--- a/tests/Python/test_max_pool2d.py
+++ b/tests/Python/test_max_pool2d.py
@@ -1,7 +1,6 @@
 # RUN: %PYTHON %s 2>&1 | FileCheck %s
 
 import torch
-from torch._inductor.decomposition import decompositions as inductor_decomp
 
 from buddy.compiler.frontend import DynamoCompiler
 from buddy.compiler.ops import tosa
@@ -19,7 +18,6 @@ def forward(self, a):
 model = TestModule()
 dynamo_compiler = DynamoCompiler(
     primary_registry=tosa.ops_registry,
-    aot_autograd_decomposition=inductor_decomp,
 )
 
 in1 = torch.randn((1, 3, 640, 480))
@@ -27,7 +25,7 @@ def forward(self, a):
 model_opt = torch.compile(model, backend=dynamo_compiler)
 assert torch.allclose(model_opt(in1), model(in1), equal_nan=True)
 
-graphs = dynamo_compiler.importer(model, in1)
+graphs = dynamo_compiler._imported_graphs
 assert len(graphs) == 1
 graph = graphs[0]
 graph.lower_to_top_level_ir()
diff --git a/tests/Python/test_mean.py b/tests/Python/test_mean.py
index 0595619d18..54cc092b48 100644
--- a/tests/Python/test_mean.py
+++ b/tests/Python/test_mean.py
@@ -24,7 +24,7 @@ def foo(x, y, keepdim):
 assert torch.allclose(
     foo_mlir(in1, in2, keepdim=in3), foo(in1, in2, keepdim=in3), equal_nan=True
 )
-graphs = dynamo_compiler.importer(foo, in1, in2, in3)
+graphs = dynamo_compiler._imported_graphs
 assert len(graphs) == 1
 graph = graphs[0]
 graph.lower_to_top_level_ir()
diff --git a/tests/Python/test_reciprocal.py b/tests/Python/test_reciprocal.py
index 9c31fb8b5b..427927c315 100644
--- a/tests/Python/test_reciprocal.py
+++ b/tests/Python/test_reciprocal.py
@@ -22,7 +22,7 @@ def foo(x):
 foo_mlir = torch.compile(foo, backend=dynamo_compiler)
 assert torch.allclose(foo_mlir(x), foo(x), equal_nan=True)
 
-graphs = dynamo_compiler.importer(foo, x)
+graphs = dynamo_compiler._imported_graphs
 assert len(graphs) == 1
 graph = graphs[0]
 graph.lower_to_top_level_ir()
diff --git a/tests/Python/test_sqrt.py b/tests/Python/test_sqrt.py
index b929d11075..bd76a77928 100644
--- a/tests/Python/test_sqrt.py
+++ b/tests/Python/test_sqrt.py
@@ -22,7 +22,7 @@ def foo(x):
 foo_mlir = torch.compile(foo, backend=dynamo_compiler)
 assert torch.allclose(foo_mlir(x), foo(x), equal_nan=True)
 
-graphs = dynamo_compiler.importer(foo, x)
+graphs = dynamo_compiler._imported_graphs
 assert len(graphs) == 1
 graph = graphs[0]
 graph.lower_to_top_level_ir()
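
For reference, the max-subtraction trick used by the new `softmax_op` lowering and by `scaled_dot_product_flash_attention_for_cpu_op` follows the standard log-sum-exp identity: softmax(x) = exp(x - logsumexp(x)), with logsumexp(x) = max(x) + log(sum(exp(x - max(x)))). The standalone NumPy sketch below is illustrative only (`stable_softmax_with_lse` is not a function in this repository); it shows why shifting by the row maximum keeps `exp` finite without changing the result:

import numpy as np

def stable_softmax_with_lse(scores):
    # scores: [..., S] attention logits.
    m = scores.max(axis=-1, keepdims=True)  # row maximum
    # The log-sum-exp computed against the shifted logits stays finite.
    lse = m + np.log(np.exp(scores - m).sum(axis=-1, keepdims=True))
    probs = np.exp(scores - lse)  # equals softmax(scores), overflow-free
    return probs, lse

scores = np.array([[1000.0, 1001.0, 1002.0]])  # naive exp() overflows here
probs, lse = stable_softmax_with_lse(scores)
assert np.allclose(probs.sum(axis=-1), 1.0)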