diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index b718d0832..6f2e388d1 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -21,7 +21,7 @@ jobs: run: git fetch --no-tags --prune --depth=1 origin "${GITHUB_BASE_REF?}:${GITHUB_BASE_REF?}" - name: Install black run: | - python3 -m pip install black==23.3 + python3 -m pip install black - name: Check if modified files are formatted run: | # The filter lowercase `d` means to exclude deleted files. diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index b7facb903..03872dea3 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -50,7 +50,7 @@ jobs: # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. pip install --no-compile -r ${{ github.workspace }}/iree-turbine/pytorch-cpu-requirements.txt - pip install --no-compile --pre --upgrade -r ${{ github.workspace }}/iree-turbine/requirements.txt + pip install --pre --upgrade -r ${{ github.workspace }}/iree-turbine/requirements.txt pip install --no-compile --pre -e ${{ github.workspace }}/iree-turbine[testing] pip install --upgrade --pre --no-cache-dir iree-compiler iree-runtime -f https://iree.dev/pip-release-links.html pip install --no-compile --pre --upgrade -e models -r models/requirements.txt @@ -69,7 +69,8 @@ jobs: source turbine_venv/bin/activate pytest -v models/turbine_models/tests/sd_test.py - pytest -v models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu + pytest -v models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 5 pytest -v models/turbine_models/tests/sdxl_test.py --device vulkan --rt_device vulkan --iree_target_triple rdna3-unknown-linux - pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip 
--iree_target_triple gfx90a --precision fp16 - + pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec default + pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec default --batch_size 2 + pytest -v models/turbine_models/tests/sd3_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 5 \ No newline at end of file diff --git a/.github/workflows/test_shark.yml b/.github/workflows/test_shark.yml index a60a098bd..6f2e4b4ed 100644 --- a/.github/workflows/test_shark.yml +++ b/.github/workflows/test_shark.yml @@ -20,7 +20,7 @@ jobs: strategy: matrix: version: [3.11] - os: [nodai-ubuntu-builder-large] + os: [nodai-amdgpu-mi250-x86-64] runs-on: ${{matrix.os}} steps: @@ -49,7 +49,6 @@ jobs: cd $GITHUB_WORKSPACE/SHARK python${{ matrix.version }} -m venv shark.venv source shark.venv/bin/activate - sed -i 's/SHARK-Turbine#/SHARK-Turbine.git@${{github.sha}}#/g' requirements.txt pip install -r requirements.txt --no-cache-dir pip install -e . 
python apps/shark_studio/tests/api_test.py diff --git a/.gitignore b/.gitignore index d85c8598b..54f4c40cc 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,7 @@ wheelhouse *.safetensors *.gguf *.vmfb +*.mlir +*.npy +*.png +*tmp* diff --git a/models/requirements.txt b/models/requirements.txt index ed2a0b0c1..0aed40159 100644 --- a/models/requirements.txt +++ b/models/requirements.txt @@ -1,11 +1,16 @@ protobuf -sentencepiece -shark_turbine +gguf transformers==4.37.1 +torchsde accelerate -diffusers @ git+https://github.com/nod-ai/diffusers@v0.24.0-release +peft +diffusers @ git+https://github.com/nod-ai/diffusers@0.29.0.dev0-shark brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b # turbine tank downloading/uploading azure-storage-blob # microsoft/phi model einops +pytest +scipy +shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main +-e git+https://github.com/nod-ai/sharktank.git@main#egg=sharktank&subdirectory=sharktank diff --git a/models/setup.py b/models/setup.py index fae7c4a61..09d60cfe3 100644 --- a/models/setup.py +++ b/models/setup.py @@ -55,12 +55,11 @@ def load_version_info(): ), install_requires=[ "Shark-Turbine", - "brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b", "protobuf", "sentencepiece", - "transformers==4.37.1", + "transformers>=4.37.1", "accelerate", - "diffusers==0.24.0", + "diffusers==0.29.0.dev0", "azure-storage-blob", "einops", ], diff --git a/models/turbine_models/custom_models/llama_argmax_td_spec.mlir b/models/turbine_models/custom_models/llama_argmax_td_spec.mlir new file mode 100644 index 000000000..0ef957cb3 --- /dev/null +++ b/models/turbine_models/custom_models/llama_argmax_td_spec.mlir @@ -0,0 +1,169 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// The configuration used for executable compilation. +// This specifies the device configurations that support this custom kernel. +#rocm_target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {target_arch = "gfx1100", ukernels = "none"}> + +module attributes {transform.with_named_sequence} { + util.func private @argmax_1d_f32_entry_point(%arg0: tensor<1x?xf32>) -> tensor<1xi64> { + %c1 = arith.constant 1 : index + %dim = tensor.dim %arg0, %c1 : tensor<1x?xf32> + // Note: This is not safe if the dim size exceeds INT32_MAX. To pass a 64 + // bit value it must be broken down into two 32-bit values for the high and + // low bits. + %dim_i32 = arith.index_cast %dim : index to i32 + // Inline external dispatch that conforms to the ABI that the kernel + // requires. This is the primary reason for the surrounding function as + // details like tensor shape and push constants need to line up after + // splicing in the custom dispatch. This allows the kernel author to manage + // such details by hand without needing the rewrite patterns to worry about + // things like order of push constants. + %4 = hal.dispatch.extern "argmax_F32I64"[%dim](%dim_i32, %arg0) : (i32, tensor<1x?xf32>{%dim}) -> tensor<1xi64> + count(%device: !hal.device, %workload: index) -> (index, index, index) { + %c1_0 = arith.constant 1 : index + hal.return %c1_0, %c1_0, %c1_0 : index, index, index + } + layout(#hal.pipeline.layout, + <1, storage_buffer> + ]> + ]>) + bindings([ + #hal.interface.binding<0, 0>, + #hal.interface.binding<0, 1> + ]) + objects({ + #rocm_target ordinal(0) = [ + #hal.executable.object<{ + data = dense<""> : vector<37600xi8> + }> + ] + }) + attributes {subgroupSize = 32, workgroup_size = [32 : index, 1 : index, 1 : index]} + util.return %4 : tensor<1xi64> + } + // data = dense<"0x7f454c was generated by generate_hsaco.sh under filename.hex. 
It uses + // xxd -p -c 1000000 filename.hsaco > filename.hex to generate the hexdump. and the shape is + // vector. + + // Custom matcher for argmax operations equivalent to the custom kernel. This + // matcher will be run one-by-one on all operations contained within the + // target function. On success, it will return the handle to the matched + // argmax operation. + transform.named_sequence @match_argmax(%generic: !transform.any_op {transform.readonly}) -> (!transform.any_op) { + // Fail fast on non-linalg generics. + transform.match.operation_name %generic ["linalg.generic"] : !transform.any_op + %matched = transform.match.structured failures(propagate) %generic : (!transform.any_op) -> (!transform.any_op) { + ^bb1(%argmax: !transform.any_op): + // Verify that the rank (i.e. number of loops) of the linalg op is 2, + // with one parallel iterator and one reduction iterator. + // TODO: Add optionality for the parallel dimensions. + %c2 = transform.param.constant 2 : i64 -> !transform.param + %rank = transform.match.structured.rank %argmax : (!transform.any_op) -> !transform.param + transform.match.param.cmpi eq %rank, %c2 : !transform.param + transform.match.structured.dim %argmax[0] {parallel} : !transform.any_op + transform.match.structured.dim %argmax[-1] {reduction} : !transform.any_op + + // Verify a single input (target vector to compute the argmax of) and two + // outputs, one for the maximum value and one for the index. + %c1 = transform.param.constant 1 : i64 -> !transform.param + %n_inputs = transform.match.structured.num_inputs %argmax : (!transform.any_op) -> !transform.param + transform.match.param.cmpi eq %n_inputs, %c1 : !transform.param + %n_outputs = transform.match.structured.num_inits %argmax : (!transform.any_op) -> !transform.param + transform.match.param.cmpi eq %n_outputs, %c2 : !transform.param + + transform.match.structured.yield %argmax : !transform.any_op + } + + // Verify the operand shapes of the linalg op. 
For example, in the below, + // dim 0 must be statically 1, and dim 1 must be statically divisible by 64. + %in0 = transform.get_operand %matched[0] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %in0 = tensor<1x?xf32> : !transform.any_value + transform.iree.match.dim_is_multiple_of %in0[1], 64 : !transform.any_value + %out0 = transform.get_operand %matched[1] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %out0 = tensor<1xf32> : !transform.any_value + %out1 = transform.get_operand %matched[2] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %out1 = tensor<1xi64> : !transform.any_value + + // Verify the region of the argmax op. This does a structural comparison of + // region(s) of the payload operation against the single operation contained + // within the body of this operation. This does no verification of other + // input types/attributes. This is because typically for kernel matching, + // the most important part to get exactly right is the inner loop. Otherwise + // small variations to shape information and iterator counts and such are + // better suited for more general matchers. 
+ transform.iree.match.regions %matched : !transform.any_op { + ^bb0(%target: tensor<1x?xf32>, %empty_max: tensor<1xf32>, %empty_idx: tensor<1xi64>): + %5:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%target : tensor<1x?xf32>) + outs(%empty_max, %empty_idx : tensor<1xf32>, tensor<1xi64>) { + ^bb0(%in: f32, %out: f32, %out_0: i64): + %6 = linalg.index 1 : index + %7 = arith.index_cast %6 : index to i64 + %8 = arith.maximumf %in, %out : f32 + %9 = arith.cmpf ogt, %in, %out : f32 + %10 = arith.select %9, %7, %out_0 : i64 + linalg.yield %8, %10 : f32, i64 + } -> (tensor<1xf32>, tensor<1xi64>) + } + transform.yield %generic : !transform.any_op + } + + // Rewrite callback for `transform.foreach_match`. The input signature for + // this sequence must match exactly with the outputs of the matcher. In this + // case we just take the argmax as an input, import the entry point for the + // custom kernel authored above, and replace the users of the argmax with a + // call to the function. + transform.named_sequence @cast_and_call_argmax(%argmax: !transform.any_op {transform.readonly}) { + %module = transform.util.get_nearest_symbol_table %argmax : (!transform.any_op) -> !transform.any_op + %func = transform.util.import_symbol @argmax_1d_f32_entry_point into %module if undefined : (!transform.any_op) -> !transform.any_op + %ins = transform.get_operand %argmax[0] : (!transform.any_op) -> !transform.any_value + %outs = transform.get_result %argmax[1] : (!transform.any_op) -> !transform.any_value + transform.util.cast_and_call %func(%ins) -> %outs before %argmax { + // This specifies how to resolve type mismatches between the arguments + // of the function and the inputs to the argmax. In this example, the + // only casts this will generate are same-rank tensor casts that drop + // static information. 
+ transform.type_conversion.tensor.cast_shape_dynamic_dims + } : (!transform.any_op, !transform.any_value, !transform.any_value, !transform.any_op) -> !transform.any_op + transform.yield + } + + // Entry point for the transform interpreter, nested on the full module. This + // is because the rewrites needed for importing the custom kernel needs to + // add a new symbol to the module's symbol table. + transform.named_sequence @__transform_main(%module: !transform.any_op) { + // Gather the set of functions within the module. + %funcs = transform.structured.match ops{["util.func"]} in %module : (!transform.any_op) -> !transform.any_op + // For each function in the module, run the matcher on all contained + // operations. + transform.foreach %funcs : !transform.any_op { + ^bb1(%func: !transform.any_op): + transform.foreach_match in %func + // -> + // Multiple matcher-action pairs can be specified comma separated, + // here we are only doing a single kind of match and replace. + // + // Note that the operations within the module are walked in + // post-order, meaning actions must be very careful in their + // replacements not to modify successors of operations. Nested + // regions and DAG roots will be visited last so it is safest to + // do matching + replacement on the root of the DAG rather than + // trying to look ahead. The other option is to avoid dce/cse until + // after the walk is complete. + @match_argmax -> @cast_and_call_argmax + : (!transform.any_op) -> (!transform.any_op) + } + // Cleanup now dead instances of argmax. + transform.apply_dce to %module : !transform.any_op + transform.yield + } +} \ No newline at end of file diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py new file mode 100644 index 000000000..3260d6f9c --- /dev/null +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -0,0 +1,725 @@ +# Copyright 2024 Advanced Micro Devices, inc. 
def merge_arg_into_map(model_map, arg, arg_name):
    """Fill ``arg_name`` into every submodel entry of ``model_map`` that lacks it.

    Args:
        model_map: Dict mapping submodel names to per-submodel settings dicts.
            It is modified in place.
        arg: Either one value applied to every submodel, or a dict keyed by
            submodel name that carries one value per submodel. Keys of such a
            dict that do not name a submodel in ``model_map`` are ignored.
        arg_name: The settings key to populate.

    Returns:
        The same ``model_map`` object that was passed in.

    An existing truthy value is never overwritten. Falsy values (``None``,
    ``False``, ``""``, ``0``) count as "unset" and are replaced.
    """
    if isinstance(arg, dict):
        for key, value in arg.items():
            if key in model_map and not model_map[key].get(arg_name):
                model_map[key][arg_name] = value
    else:
        for settings in model_map.values():
            if not settings.get(arg_name):
                settings[arg_name] = arg
    return model_map


def merge_export_arg(model_map, arg, arg_name):
    """Merge ``arg`` into the ``"export_args"`` dict of each submodel entry.

    Args:
        model_map: Dict mapping submodel names to per-submodel settings dicts.
            It is modified in place.
        arg: Either one value for all submodels, or a dict keyed by submodel
            name. Keys that do not name a submodel in ``model_map`` are
            ignored.
        arg_name: The export-argument key to merge.

    Returns:
        The same ``model_map`` object that was passed in.

    The two input forms deliberately keep their original, asymmetric rules:

    * dict ``arg``: the value is added only when ``arg_name`` is not yet a key
      of the submodel's export args (existing entries win).
    * scalar ``arg``: the value replaces the entry only when the submodel's
      export args already hold a truthy ``arg_name`` (it never adds a key).

    NOTE(review): the asymmetry looks intentional (a scalar only overrides
    arguments a submodel already declares) but is not documented anywhere in
    view -- confirm.

    Submodels without an ``"export_args"`` dict are skipped in both forms.
    Previously the dict form raised ``KeyError`` for them.
    """
    if isinstance(arg, dict):
        for key, value in arg.items():
            export_args = model_map.get(key, {}).get("export_args")
            if export_args is None:
                # Unknown submodel, or a submodel that exports nothing.
                continue
            if arg_name not in export_args:
                export_args[arg_name] = value
    else:
        for settings in model_map.values():
            export_args = settings.get("export_args") or {}
            if export_args.get(arg_name):
                export_args[arg_name] = arg
    return model_map
def get_metadata(self):
    """Collect the reflection metadata of every callable function in the module.

    Rebuilds ``self.metadata`` as a dict keyed by function name, using the
    names listed in ``self.module.vm_module.function_names``. Names containing
    ``"$async"`` or ``"__init"`` are skipped. When the reflection data of a
    function cannot be read, a warning is logged and ``None`` is stored for
    that function so later lookups still find the key.
    """
    self.metadata = {}
    for function_name in self.module.vm_module.function_names:
        if "$async" in function_name or "__init" in function_name:
            continue
        try:
            reflection = self.module[function_name].vm_function.reflection
        except Exception:
            # Only ordinary errors are treated as "no metadata"; a bare
            # ``except`` would also swallow KeyboardInterrupt / SystemExit.
            logging.warning(
                f"Could not get metadata for {self.module_name}['{function_name}']."
            )
            reflection = None
        self.metadata[function_name] = reflection
def _output_cast(self, output):
    """Convert a module output to the container and dtype this component targets.

    A tuple is converted member by member (recursively) and returned as a new
    tuple. For any other value the result depends on ``self.dest_type``:

    * ``"devicearray"``: the value is returned as is when its dtype already
      equals ``self.dest_dtype``, otherwise ``astype(self.dest_dtype)`` is used.
    * ``"torch"``: the value is copied to the host and wrapped in a
      ``torch.tensor`` with the dtype looked up in ``torch_dtypes``.
    * ``"numpy"``: the value is copied to the host and cast to the dtype
      looked up in ``np_dtypes``.
    * anything else: the value is returned untouched.
    """
    if isinstance(output, tuple):
        return tuple(self._output_cast(member) for member in output)

    destination = self.dest_type
    if destination == "devicearray":
        if output.dtype != self.dest_dtype:
            return output.astype(self.dest_dtype)
        return output
    if destination == "torch":
        host_copy = output.to_host()
        return torch.tensor(host_copy, dtype=torch_dtypes[self.dest_dtype])
    if destination == "numpy":
        host_copy = output.to_host()
        return host_copy.astype(np_dtypes[self.dest_dtype])
    return output
+ 2. Preparation: Check if all necessary files are present, and generate them if not. (prepare_all() / prepare_submodel()) + - This is done by submodel, so that users can generate a new submodel with the same pipeline. + - If vmfb not found, first check turbine tank for matching .vmfb file. + - If vmfb not downloadable, try downloading .mlir. + - If neither on Azure, run the export function in model map to export to torch IR and compile with IREE. + - If weights not found, run the export function in model map with weights_only=True. + - Apps should populate the weights with custom weights by now so they can be managed and converted if needed here. + 3. Load the pipeline: Load the prepared files onto devices as vmfbRunners. (load_pipeline() / load_submodel() / reload_submodel()) + 4. Run Inference: + + + + Arguments: + model_map: dict + A dictionary mapping submodel names to their export functions and hf model ids. This is used throughout the pipeline. + It also should provide I/O information for the submodels. + height: int + The height of the image to be generated + width: int + The width of the image to be generated + precision: str + The precision of the image latents. This usually decides the precision of all models in the pipeline. + max_length: int + The maximum sequence length for text encoders and diffusion models. + batch_size: int + The number of images to generate from each inference batch. This changes the shapes in all submodels. + device: str | dict[str] + Either a string i.e. "rocm://0", or a dictionary of such with keys matching the submodels of a given pipeline. + If a string, a dictionary will be created based on the pipeline's model map and the same device will be used for all submodels. + target: str | dict[str] + Either a string i.e. "gfx1100", or a dictionary with keys matching the submodels of a given pipeline. 
def __init__(
    self,
    model_map: dict,
    device: str | dict[str],
    target: str | dict[str],
    ireec_flags: str | dict[str] = None,
    precision: str | dict[str] = "fp16",
    td_spec: str | dict[str] = None,
    decomp_attn: bool | dict[bool] = False,
    external_weights: str | dict[str] = None,
    pipeline_dir: str = "./shark_vmfbs",
    external_weights_dir: str = "./shark_weights",
    hf_model_name: str | dict[str] = None,
    common_export_args: dict = None,
):
    """Fill the model map from the constructor arguments and create output dirs.

    Args:
        model_map: Dict of submodel name -> settings dict. It is stored as
            ``self.map`` and modified in place.
        device: One device string for all submodels, or a dict with one entry
            per submodel.
        target: One target string, or a dict with one entry per submodel. It
            must be a dict exactly when ``device`` is a dict.
        ireec_flags, precision, td_spec, decomp_attn, external_weights,
        hf_model_name: A single value or a per-submodel dict. Each is merged
            into the submodel settings (existing truthy settings win) and
            afterwards into the submodels' ``"export_args"``.
        pipeline_dir: Directory for compiled artifacts; created if missing.
        external_weights_dir: Directory for weight files; created if missing.
        common_export_args: Export arguments offered to every submodel that
            has an ``"export_args"`` dict. ``None`` means none.

    Raises:
        AssertionError: If ``device`` and ``target`` are not both dicts or
            both strings, or a submodel is missing from either dict.
    """
    # ``None`` replaces the former shared ``{}`` default.
    if common_export_args is None:
        common_export_args = {}

    self.map = model_map
    if isinstance(device, dict):
        assert isinstance(
            target, dict
        ), "Device and target triple must be both dicts or both strings."
        for submodel in self.map.keys():
            assert submodel in device.keys(), f"Device for {submodel} not found."
            assert (
                submodel in target.keys()
            ), f"Target arch for {submodel} not found."
            self.map[submodel]["device"] = utils.iree_backend_map(device[submodel])
            self.map[submodel]["driver"] = utils.iree_device_map(device[submodel])
            self.map[submodel]["target"] = target[submodel]
    else:
        assert isinstance(
            target, str
        ), "Device and target triple must be both dicts or both strings."
        for submodel in self.map.keys():
            self.map[submodel]["device"] = utils.iree_backend_map(device)
            self.map[submodel]["driver"] = utils.iree_device_map(device)
            self.map[submodel]["target"] = target

    map_arguments = {
        "ireec_flags": ireec_flags,
        "precision": precision,
        "td_spec": td_spec,
        "decomp_attn": decomp_attn,
        "external_weights": external_weights,
        "hf_model_name": hf_model_name,
    }
    for arg_name, arg_value in map_arguments.items():
        self.map = merge_arg_into_map(self.map, arg_value, arg_name)

    # The dtype entries are derived from the precision of the LAST submodel in
    # the map and offered to every submodel that has none yet. This used to
    # rely on the loop variable leaking out of the loops above, which failed
    # with an unbound name on an empty map.
    # NOTE(review): with a per-submodel ``precision`` dict every submodel still
    # receives the last submodel's dtype -- confirm whether that is intended.
    submodel_names = list(self.map)
    if submodel_names:
        reference_precision = self.map[submodel_names[-1]]["precision"]
        self.map = merge_arg_into_map(
            self.map, np_dtypes[reference_precision], "np_dtype"
        )
        self.map = merge_arg_into_map(
            self.map, torch_dtypes[reference_precision], "torch_dtype"
        )

    # A submodel-level setting of the same name wins over the common value.
    # Submodels without an "export_args" dict are skipped.
    for submodel in self.map.keys():
        export_args = self.map[submodel].get("export_args")
        if export_args is None:
            continue
        for arg_name, common_value in common_export_args.items():
            export_args[arg_name] = self.map[submodel].get(arg_name, common_value)

    # merge_export_arg already walks the whole map, so one pass is enough.
    for arg_name, arg_value in map_arguments.items():
        self.map = merge_export_arg(self.map, arg_value, arg_name)

    # Turn selected export arguments into filename keywords. A leading "!"
    # marks a keyword that must NOT appear in a matching filename.
    for submodel in self.map.keys():
        export_args = self.map[submodel].get("export_args", {})
        for key, value in export_args.items():
            if key == "hf_model_name":
                keyword = utils.create_safe_name(value.split("/")[-1], "")
            elif key == "decomp_attn":
                keyword = "decomp_attn" if value else "!decomp_attn"
            elif key == "batch_size":
                keyword = f"bs{value}"
            elif key in ["height"]:
                keyword = f"{export_args['width']}x{export_args['height']}"
            elif key in ["max_length", "precision"]:
                keyword = str(value)
            else:
                continue
            self.map[submodel].setdefault("keywords", []).append(keyword)

    self.pipeline_dir = pipeline_dir
    os.makedirs(self.pipeline_dir, exist_ok=True)
    self.external_weights_dir = external_weights_dir
    os.makedirs(self.external_weights_dir, exist_ok=True)

    # Disabled for now -- enable through option when turbine tank is ready.
    self.download = False

    # These arguments are set at run or load time.
    self.compiled_pipeline = False
    self.split_scheduler = False
    self.cpu_scheduling = False

    # TODO: set this based on user-inputted guidance scale and negative prompt.
    self.do_classifier_free_guidance = True
    self._interrupt = False
def is_prepared(self, vmfbs, weights):
    """Report whether every submodel has its compiled module and weights.

    For each submodel without a ``"vmfb"`` entry, ``self.pipeline_dir`` is
    searched for a file whose name contains all of the submodel's keywords
    plus ``"vmfb"`` and the submodel's target, and none of its negated
    (``"!"``-prefixed) keywords. A file matching the keywords plus ``"mlir"``
    is recorded under ``"mlir"``. If the submodel's export args request
    external weights, ``self.external_weights_dir`` is searched the same way
    for the weight file name.

    Args:
        vmfbs: Dict of submodel name -> vmfb path supplied by the caller.
        weights: Dict of submodel name -> weight path supplied by the caller.

    Returns:
        True when nothing is missing, otherwise False (the missing items are
        printed).

    Side effects:
        Sets ``"vmfb"``, ``"mlir"`` and ``"weights"`` on the entries of
        ``self.map`` when a file is found. The stored ``"keywords"`` list is
        no longer modified. Directory listings are sorted, so the same file is
        picked on every run when several match.
    """
    missing = {}
    pipeline_dir = self.pipeline_dir
    for key in self.map:
        missing[key] = []
        submodel = self.map[key]
        # NOTE(review): both early exits below also skip the weights check for
        # this submodel, as in the original code -- confirm that is intended.
        # vmfb is already present in model map
        if submodel.get("vmfb"):
            continue
        # vmfb is passed in to this function
        if vmfbs.get(key):
            submodel["vmfb"] = vmfbs[key]
            continue

        # Split the keywords into required and negated ones using new lists.
        # Removing items from the list being iterated skipped the keyword
        # after every negated one, and editing the stored list dropped the
        # negated keywords for any later call.
        declared = submodel.get("keywords", [])
        neg_keywords = []
        required = []
        for kw in declared:
            if isinstance(kw, str) and kw.startswith("!"):
                neg_keywords.append(kw.strip("!"))
            else:
                required.append(kw)
        keywords = required + ["vmfb", submodel["target"]]
        mlir_keywords = required + ["mlir"]

        # search self.pipeline_dir for key-specific vmfb
        candidates = []
        for filename in sorted(os.listdir(pipeline_dir)):
            if any(x in filename for x in neg_keywords):
                continue
            path = os.path.join(pipeline_dir, filename)
            if all(str(x) in filename for x in keywords):
                candidates.append(path)
            if all(str(x) in filename for x in mlir_keywords):
                submodel["mlir"] = path
        if candidates:
            if len(candidates) > 1:
                print(f"Multiple files found for {key}: {candidates}")
                print(f"Choosing {candidates[0]} for {key}.")
            submodel["vmfb"] = candidates[0]
        else:
            # vmfb not found in pipeline_dir. Add to list of files to generate.
            missing[key].append("vmfb")

        # Make sure vmfb needs external weights, as they may be inlined.
        if not submodel.get("export_args", {}).get("external_weights"):
            continue
        if not submodel.get("external_weights"):
            continue
        if submodel.get("weights"):
            # weights already found in model map
            continue
        if weights.get(key):
            # weights passed in to this function
            submodel["weights"] = weights[key]
            continue

        # search self.external_weights_dir for key-specific weights.
        # export_submodel may already have rewritten external_weight_path to a
        # full path, which can never be a substring of a bare filename, so
        # only its final component is compared.
        weight_path = str(submodel["export_args"]["external_weight_path"])
        weight_name = os.path.basename(weight_path) or weight_path
        candidates = [
            os.path.join(self.external_weights_dir, filename)
            for filename in sorted(os.listdir(self.external_weights_dir))
            if weight_name in filename
        ]
        if candidates:
            if len(candidates) > 1:
                print(f"Multiple weight files found for {key}: {candidates}")
                print(f"Choosing {candidates[0]} for {key}.")
            # The key is the string "weights"; indexing with the ``weights``
            # argument (a dict) raised TypeError.
            submodel["weights"] = candidates[0]
        else:
            # weights not found in external_weights_dir. Add to list of files to generate.
            missing[key].append("weights")

    ready = not any(missing.values())
    if not ready:
        print("Missing files: ", missing)
    return ready
self.map[submodel]["export_args"]["width"], + self.map[submodel]["export_args"]["height"], + self.map[submodel]["export_args"]["precision"], + self.map[submodel]["export_args"]["batch_size"], + self.map[submodel]["export_args"]["max_length"], + "unet_loop", + ) + dims = [ + self.map[submodel]["export_args"]["width"], + self.map[submodel]["export_args"]["height"], + ] + dims = "x".join([str(x) for x in dims]) + pipeline_keys = [ + utils.create_safe_name( + self.map[submodel]["export_args"]["hf_model_name"].split("/")[ + -1 + ], + "", + ), + "bs" + str(self.map[submodel]["export_args"]["batch_size"]), + dims, + self.map[submodel]["export_args"]["precision"], + str(self.map[submodel]["export_args"]["max_length"]), + "unetloop", + ] + vmfb_path = utils.compile_to_vmfb( + pipeline_file, + self.map["unet"]["device"], + self.map["unet"]["target"], + None, + os.path.join(self.pipeline_dir, "_".join(pipeline_keys)), + return_path=True, + mlir_source="str", + ) + self.map[submodel]["vmfb"] = vmfb_path + self.map[submodel]["weights"] = None + case "fullpipeline": # SDXL ONLY FOR NOW + pipeline_file = get_pipeline_ir( + self.map[submodel]["export_args"]["width"], + self.map[submodel]["export_args"]["height"], + self.map[submodel]["export_args"]["precision"], + self.map[submodel]["export_args"]["batch_size"], + self.map[submodel]["export_args"]["max_length"], + "tokens_to_image", + ) + dims = [ + self.map[submodel]["export_args"]["width"], + self.map[submodel]["export_args"]["height"], + ] + dims = "x".join([str(x) for x in dims]) + pipeline_keys = [ + utils.create_safe_name( + self.map[submodel]["export_args"]["hf_model_name"].split("/")[ + -1 + ], + "", + ), + "bs" + str(self.map[submodel]["export_args"]["batch_size"]), + dims, + self.map[submodel]["export_args"]["precision"], + str(self.map[submodel]["export_args"]["max_length"]), + "fullpipeline", + ] + vmfb_path = utils.compile_to_vmfb( + pipeline_file, + self.map["unet"]["device"], + self.map["unet"]["target"], + None, + 
def numpy_to_pil_image(images):
    """
    Convert a numpy image or a batch of images to a list of PIL images.

    Args:
        images: numpy array with 3 dimensions (a single image) or with a
            leading batch dimension. The last axis is treated as the channel
            axis. Values are multiplied by 255, rounded and cast to uint8
            (assumes inputs are scaled to [0, 1] -- not validated here).

    Returns:
        A list with one ``PIL.Image`` per image in the batch. Single-channel
        inputs produce mode ``"L"`` (grayscale) images.
    """
    if images.ndim == 3:
        # Promote a single image to a batch of one.
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        # The previous implementation called ``.size(dim=0)`` (a torch idiom)
        # on a numpy array and then ``.squeeze()`` on a loop integer, so this
        # branch always raised. Drop only the trailing channel axis so that
        # images with a height or width of 1 keep their 2-D shape.
        return [
            Image.fromarray(image.squeeze(axis=-1), mode="L") for image in images
        ]
    return [Image.fromarray(image) for image in images]
def is_valid_file(arg):
    """Return ``arg`` unchanged if it names an existing path, otherwise ``None``.

    Existence is checked with ``os.path.exists``, so directories are accepted
    as well as regular files.
    """
    return arg if os.path.exists(arg) else None
+ +# We should consider separating out the options that are "model configs" from +# the options that control the compiler, runtime, and script behavior, +# when applicable, as the formermost would best be kept in a separate +# config or imported from huggingface. + +p = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter +) + +############################################################################## +# SD3 Source Options +############################################################################## + +p.add_argument( + "--hf_auth_token", + type=str, + help="The Hugging Face auth token, if required", + default=None, +) +p.add_argument( + "--hf_model_name", + type=str, + help="HF model name", + default="stabilityai/stable-diffusion-3-medium-diffusers", +) +p.add_argument( + "--scheduler_id", + type=str, + help="Scheduler ID", + default="EulerDiscrete", +) +p.add_argument( + "--model_path", + type=str, + help="Path to model .safetensors from which the model is defined.", + default=None, +) +p.add_argument( + "--vae_model_path", + type=str, + help="Path to vae model .safetensors from which the model is defined.", + default=None, +) + +############################################################################## +# SD3 Inference Options +# These options are used to control runtime parameters for SD3 inference. 
+############################################################################## + +p.add_argument( + "--prompt", + type=str, + default=" a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + help="Prompt input to stable diffusion.", +) + +p.add_argument( + "--negative_prompt", + type=str, + default="Watermark, blurry, oversaturated, low resolution, pollution", + help="Negative prompt input to stable diffusion.", +) + +p.add_argument( + "--num_inference_steps", type=int, default=30, help="Number of UNet inference steps" +) + +p.add_argument( + "--batch_count", + type=int, + default=1, + help="Number of batches to run for a single prompt", +) + +p.add_argument( + "--guidance_scale", + type=float, + default=4, + help="Scale by which to adjust prompt guidance to the unconditional noise prediction output of UNet after each iteration.", +) + +p.add_argument( + "--seed", type=float, default=0, help="Seed for random number/latents generation." +) + +p.add_argument( + "--denoise", + type=float, + default=1.0, + help="Denoising factor for image to image", +) + +p.add_argument( + "--external_weight_path", + type=str, + default="", + help="Path to external weights file, for jobs with one weights filepath. When importing, this is used to specify where to save the model weights, and at runtime, this is used to specify where to load the model weights from.", +) + +p.add_argument( + "--external_weights_dir", + type=str, + default="", + help="Directory containing external weights for a job that requires more than one weights file. When importing, this is used to specify where to save the model weights, and at runtime, this is used to specify where to load the model weights from. Files will then be saved according to the parameters that make them unique, i.e. 
___.", +) + +p.add_argument( + "--vmfb_path", type=str, default="", help="path to vmfb containing compiled module" +) + +p.add_argument( + "--pipeline_vmfb_path", + type=str, + default="", + help="path to vmfb containing compiled meta-module", +) + +p.add_argument( + "--scheduler_vmfb_path", + type=str, + default="", + help="path to vmfb containing compiled scheduler", +) + +p.add_argument( + "--split_scheduler", + default=False, + action="store_true", + help="Use a decoupled unet and scheduler for better QOL.", +) + +p.add_argument( + "--cpu_scheduling", + default=False, + action="store_true", + help="Run scheduling on torch cpu (will be slower due to data movement costs).", +) + +p.add_argument( + "--external_weight_file", + type=str, + default=None, + help="Path to external weights, used in benchmark scripts.", +) + +p.add_argument( + "--pipeline_dir", + type=str, + default=None, + help="Directory to save pipeline artifacts", +) + +p.add_argument( + "--compiled_pipeline", + default=False, + action="store_true", + help="Do one-shot inference from tokens to image in a shrink-wrapped pipeline binary.", +) + +p.add_argument( + "--npu_delegate_path", + type=str, + default=None, + help="Path to npu executable plugin .dll for running VAE on NPU.", +) + + +p.add_argument( + "--clip_device", + default=None, + type=str, + help="Device to run CLIP on. If None, defaults to the device specified in args.device.", +) + +p.add_argument( + "--mmdit_device", + default=None, + type=str, + help="Device to run MMDiT on. If None, defaults to the device specified in args.device.", +) + +p.add_argument( + "--vae_device", + default=None, + type=str, + help="Device to run VAE on. If None, defaults to the device specified in args.device.", +) + +p.add_argument( + "--clip_target", + default=None, + type=str, + help="IREE target for CLIP compilation. 
If None, defaults to the target specified by --iree_target_triple.", +) + +p.add_argument( + "--mmdit_target", + default=None, + type=str, + help="IREE target for mmdit compilation. If None, defaults to the target specified by --iree_target_triple.", +) + +p.add_argument( + "--vae_target", + default=None, + type=str, + help="IREE target for vae compilation. If None, defaults to the target specified by --iree_target_triple.", +) + +############################################################################## +# SD3 Modelling Options +# These options are used to control model defining parameters for SD3. +# These are MLIR - changing variables! If you change them, you will need +# to import/download and recompile the model. +############################################################################## + +p.add_argument("--batch_size", type=int, default=1, help="Batch size for inference") +p.add_argument( + "--height", type=int, default=1024, help="Height of Stable Diffusion output image." 
+) +p.add_argument( + "--width", type=int, default=1024, help="Width of Stable Diffusion output image" +) +p.add_argument( + "--precision", + type=str, + default="fp16", + help="Precision of Stable Diffusion weights and graph.", +) +p.add_argument( + "--vae_precision", + type=str, + default=None, + help="Precision of Stable Diffusion VAE weights and graph.", +) +p.add_argument( + "--max_length", type=int, default=77, help="Sequence Length of Stable Diffusion" +) +p.add_argument("--vae_variant", type=str, default="decode", help="encode, decode") +p.add_argument( + "--shift", type=float, default=3, help="Sampling shift value for sd3 scheduling" +) +p.add_argument( + "--vae_decomp_attn", + type=bool, + default=False, + help="Decompose attention for VAE decode only at fx graph level", +) +p.add_argument( + "--vae_dtype", + type=str, + default="fp32", + help="Precision of VAE graph.", +) + +############################################################################## +# SD3 script general options. +############################################################################## + +p.add_argument("--compile_to", type=str, default="mlir", help="torch, linalg, vmfb") +p.add_argument( + "--init_image", + type=str, + default=None, + help="Path to initial image for inference", +) +p.add_argument( + "--external_weights", + type=str, + default=None, + choices=["safetensors", "irpa", "gguf", None], + help="Externalizes model weights from the torch dialect IR and its successors", +) + +# See --external_weight_path and --external_weight_dir to specify where to save the model weights. 
+p.add_argument( + "--weights_only", + action="store_true", + help="Just grab the weights for your model and exit instead of exporting any IR.", +) +p.add_argument( + "--compare_vs_torch", + action="store_true", + help="Runs both turbine vmfb and a torch model to compare results", +) +p.add_argument( + "--decomp_attn", + default=False, + action="store_true", + help="Decompose attention at fx graph level", +) +p.add_argument( + "--exit_on_vmfb", + default=True, + action="store_false", + help="Exit program on vmfb compilation completion. Most scripts will also save .mlir if this is disabled.", +) +p.add_argument( + "--input_mlir", + type=str, + default=None, + help="Path to input mlir file to compile. Comma-separate paths to provide more than one input to pipelines.", +) +p.add_argument( + "--download_mlir", + default=False, + action="store_true", + help="Download missing mlir files from Azure storage.", +) +p.add_argument( + "--container_name", + type=str, + default=None, + help="Azure storage container name to download mlir files from.", +) +p.add_argument("--export", type=str, default="all", help="clip, mmdit, vae, all") +p.add_argument( + "--output", + type=str, + default="SD3_output.png", + help="Path to output file for generated images.", +) +p.add_argument( + "--attn_repro", + default=False, + action="store_true", + help="Just compile attention reproducer for mmdit.", +) +p.add_argument( + "--vae_input_path", + type=str, + default=None, + help="Path to input latents for VAE inference numerics validation.", +) + + +############################################################################## +# IREE Compiler Options +############################################################################## + +p.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm") + +p.add_argument( + "--rt_device", + type=str, + default="local-task", + help="local-task, local-sync, vulkan://0, rocm://0, cuda://0, etc.", +) + +# TODO: Bring in detection for 
target triple +p.add_argument( + "--iree_target_triple", + type=str, + default="", + help="Specify vulkan target triple or rocm/cuda target device.", +) + +p.add_argument("--ireec_flags", type=str, default="", help="extra iree-compile options") + +p.add_argument( + "--attn_flags", + type=str, + default="", + help="extra iree-compile options for models with iree_linalg_ext.attention ops.", +) + +p.add_argument( + "--attn_spec", + type=str, + default=None, + help="extra iree-compile options for models with iree_linalg_ext.attention ops. Set this to 'default' if you are using mfma-capable hardware with ROCM.", +) + +p.add_argument( + "--clip_flags", + type=str, + default="", + help="extra iree-compile options to send for compiling CLIP/prompt_encoder. Only use this for testing bleeding edge flags! Any default options should be added to sd_inference/utils.py", +) + +p.add_argument( + "--vae_flags", + type=str, + default="", + help="extra iree-compile options to send for compiling VAE. Only use this for testing bleeding edge flags! Any default options should be added to sd_inference/utils.py", +) + +p.add_argument( + "--mmdit_flags", + type=str, + default="", + help="extra iree-compile options to send for compiling unet. Only use this for testing bleeding edge flags! Any default options should be added to sd_inference/utils.py", +) + + +args, unknown = p.parse_known_args() diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_full.py b/models/turbine_models/custom_models/sd3_inference/sd3_full.py new file mode 100644 index 000000000..f88cda03f --- /dev/null +++ b/models/turbine_models/custom_models/sd3_inference/sd3_full.py @@ -0,0 +1,277 @@ +# Copyrigh 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
def export_unet_dynamic(
    unet_model,
    height,
    width,
    compile_to="torch",
    external_weight_path=None,
    device=None,
    target_triple=None,
    max_alloc="",
    upload_ir=False,
    dtype=torch.float32,
    external_weights=None,
    weights_only=False,
):
    """Export the ``init_dynamic`` / ``do_sampling`` entry points of ``unet_model``.

    Args:
        unet_model: Module exposing ``init_dynamic`` and ``do_sampling``.
        height: Output image height; latents use ``height // 8``.
        width: Output image width; latents use ``width // 8``.
        compile_to: When equal to ``"vmfb"`` the module is also compiled with
            ``utils.compile_to_vmfb``.
        external_weight_path: Path forwarded to ``utils.save_external_weights``
            and, when ``external_weights`` is truthy, to
            ``save_module_parameters``.
        device: Forwarded to ``utils.compile_to_vmfb``.
        target_triple: Forwarded to ``utils.compile_to_vmfb``.
        max_alloc: Forwarded to ``utils.compile_to_vmfb``.
        upload_ir: Accepted for interface compatibility; unused here.
        dtype: Torch dtype of the example inputs; ``torch.float16`` also
            converts the model with ``.half()``.
        external_weights: When truthy, module parameters are externalized and
            saved to ``external_weight_path``. This was previously an
            undefined name inside the function (NameError).
        weights_only: When true, return ``external_weight_path`` right after
            saving weights, without exporting IR. This was previously an
            undefined name inside the function (NameError).

    Returns:
        The MLIR module as a string, or ``external_weight_path`` when
        ``weights_only`` is set.
    """
    cond_shape = [1, 154, 4096]  # 77, 4096]
    pool_shape = [1, 2048]
    latent_shape = [1, 16, height // 8, width // 8]
    if dtype == torch.float16:
        unet_model = unet_model.half()
    mapper = {}
    utils.save_external_weights(mapper, unet_model, "safetensors", external_weight_path)

    if weights_only:
        return external_weight_path

    fxb = FxProgramsBuilder(unet_model)

    # Only the first example argument of the init program has a dynamic dim.
    sigmas = torch.export.Dim("sigmas")
    dynamic_shapes = {"sigmas": {0: sigmas}, "latent": {}, "noise": {}}
    example_init_args = [
        torch.empty([19], dtype=dtype),
        torch.empty(latent_shape, dtype=dtype),
        torch.empty(latent_shape, dtype=dtype),
    ]
    example_sampling_args = [
        torch.empty(latent_shape, dtype=dtype),
        torch.empty(1, dtype=dtype),
        torch.empty(1, dtype=dtype),
        torch.empty(cond_shape, dtype=dtype),
        torch.empty(pool_shape, dtype=dtype),
        torch.empty(cond_shape, dtype=dtype),
        torch.empty(pool_shape, dtype=dtype),
        torch.empty(1, dtype=dtype),
    ]

    @fxb.export_program(args=(example_init_args,), dynamic_shapes=dynamic_shapes)
    def _initialize(module, inputs):
        # 1.0 is denoise currently symfloat not supported in fx_importer
        return module.init_dynamic(*inputs)

    @fxb.export_program(args=(example_sampling_args,))
    def _do_sampling(module, inputs):
        return module.do_sampling(*inputs)

    class CompiledTresleches(CompiledModule):
        initialize = _initialize
        do_sampling = _do_sampling

    if external_weights:
        externalize_module_parameters(unet_model)
        save_module_parameters(external_weight_path, unet_model)

    inst = CompiledTresleches(context=Context(), import_to="IMPORT")
    module_str = str(CompiledModule.get_mlir_module(inst))
    print("exported model")

    # ``str.lstrip("torch.")`` strips a *set of characters*, not a prefix, and
    # would mangle dtype names starting with any of those letters; take the
    # part after the module qualifier instead.
    dtype_name = str(dtype).split(".")[-1]
    safe_name = utils.create_safe_name(dtype_name, "_mmdit")
    if compile_to == "vmfb":
        print("compiling to vmfb")
        utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name)
    return module_str
@torch.no_grad()
def main(args):
    """Export the SD3 submodels selected by ``args.export`` to ``.mlir`` files.

    ``args.export`` selects what is exported: ``"dynamic"`` for the dynamic
    unet, ``"clip"`` for the preprocessor, ``"vae"`` for the VAE, and
    ``"all"`` for preprocessor plus VAE. Each exported module string is
    written to ``<safe_name>.mlir`` in the current working directory.

    Args:
        args: Parsed command-line namespace; reads ``device``, ``precision``,
            ``export``, ``height``, ``width``, ``compile_to``,
            ``external_weight_path`` and ``iree_target_triple``.
    """
    import turbine_sd3
    from safetensors import safe_open

    vulkan_max_allocation = "4294967296" if args.device == "vulkan" else ""
    # st_file = "/mnt2/tresleches/models/sd3_8b_beta.safetensors"
    st_file = "/mnt2/tresleches/models/sd3_2b_512_alpha.safetensors"

    # The command-line default for --precision is "fp16", but only "f16" was
    # recognised here, so fp16 requests silently exported in float32. Accept
    # both spellings; anything unrecognised still falls back to float32.
    precision_to_dtype = {
        "fp16": torch.float16,
        "f16": torch.float16,
        "bf16": torch.bfloat16,
    }
    dtype = precision_to_dtype.get(args.precision, torch.float32)
    print(args.export)

    if args.export in ["dynamic"]:
        print("exporting dynamic")
        unet_model = turbine_sd3.SD3Inferencer(
            model=st_file, vae=turbine_sd3.VAEFile, shift=1.0, dtype=dtype
        ).eval()
        mod_str = export_unet_dynamic(
            unet_model=unet_model,
            height=args.height,
            width=args.width,
            compile_to=args.compile_to,
            external_weight_path=args.external_weight_path,
            device=args.device,
            target_triple=args.iree_target_triple,
            max_alloc=vulkan_max_allocation,
            upload_ir=False,
            dtype=dtype,
        )
        safe_name = utils.create_safe_name("hc_sd3", "-unet")
        with open(f"{safe_name}.mlir", "w+") as f:
            f.write(mod_str)
        print("Saved to", safe_name + ".mlir")

    export_pre = args.export in ["all", "clip"]
    print(export_pre)
    if export_pre:
        print("exporting preprocessor")
        pre = turbine_sd3.Preprocess()
        mod_str = export_preprocessor(
            model=pre,
            compile_to=args.compile_to,
            external_weight_path=args.external_weight_path,
            device=args.device,
            target_triple=args.iree_target_triple,
            max_alloc=vulkan_max_allocation,
            dtype=dtype,
            height=args.height,
            width=args.width,
        )
        safe_name = utils.create_safe_name("hc_sd3", "_preprocess")
        with open(f"{safe_name}.mlir", "w+") as f:
            f.write(mod_str)
        print("Saved to", safe_name + ".mlir")

    should_export_vae = args.export in ["all", "vae"]
    if should_export_vae:
        print("exporting vae")
        from turbine_impls import SDVAE

        with turbine_sd3.safe_open(
            turbine_sd3.VAEFile, framework="pt", device="cpu"
        ) as f:
            vae = SDVAE(device="cpu", dtype=dtype).eval().cpu()
            # Some checkpoints nest the VAE weights under this key prefix.
            prefix = ""
            if any(k.startswith("first_stage_model.") for k in f.keys()):
                prefix = "first_stage_model."
            turbine_sd3.load_into(f, vae, prefix, "cpu", dtype)
        print("Something")
        mod_str = export_vae(
            model=vae,
            height=args.height,
            width=args.width,
            compile_to=args.compile_to,
            external_weight_prefix=args.external_weight_path,
            device=args.device,
            target_triple=args.iree_target_triple,
            max_alloc=vulkan_max_allocation,
            dtype=dtype,
        )
        safe_name = utils.create_safe_name("hc_sd3", "_vae")
        with open(f"{safe_name}.mlir", "w+") as f:
            f.write(mod_str)
        print("Saved to", safe_name + ".mlir")
class MMDiTAttention(torch.nn.Module):
    """Parameter-free module that applies scaled dot-product attention."""

    def __init__(self):
        super().__init__()

    def forward(self, q, k, v):
        """Return non-causal, dropout-free attention of ``q``, ``k`` and ``v``."""
        attn_out = torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=0.0,
            is_causal=False,
        )
        return attn_out
@torch.no_grad()
def export_mmdit_model(
    mmdit_model,
    hf_model_name,
    batch_size,
    height,
    width,
    precision="fp32",
    max_length=77,
    hf_auth_token=None,
    compile_to="torch",
    external_weights=None,
    external_weight_path=None,
    device=None,
    target_triple=None,
    ireec_flags=None,
    decomp_attn=False,
    exit_on_vmfb=False,
    pipeline_dir=None,
    attn_spec=None,
    input_mlir=None,
    weights_only=False,
):
    """Export (and optionally compile) the SD3 MMDiT transformer.

    Args:
        mmdit_model: Module to export; unused when ``input_mlir`` is given.
        hf_model_name: Model identifier, used to build the artifact name.
        batch_size: Batch size before doubling for classifier-free guidance.
        height: Output image height; latents use ``height // 8``.
        width: Output image width; latents use ``width // 8``.
        precision: ``"fp16"`` selects float16; anything else uses float32.
        max_length: Only used in the artifact name.
        hf_auth_token: Accepted for interface compatibility; unused here.
        compile_to: When equal to ``"vmfb"`` the module is compiled and a
            vmfb path is returned; otherwise the MLIR string is returned.
        external_weights: Weight format forwarded to
            ``utils.save_external_weights``; truthy values also externalize
            the module parameters.
        external_weight_path: Destination for externalized weights.
        device: Forwarded to ``utils.compile_to_vmfb``.
        target_triple: Forwarded to ``utils.compile_to_vmfb``.
        ireec_flags: Extra compiler flags as a string, or ``None``.
        decomp_attn: Decompose attention ops at the fx graph level and add a
            matching compiler flag.
        exit_on_vmfb: Exit the process once the vmfb has been compiled.
        pipeline_dir: Optional directory prepended to the artifact name.
        attn_spec: Forwarded to ``utils.compile_to_vmfb``.
        input_mlir: Path to an existing MLIR file to compile, skipping export.
        weights_only: Save weights and return ``external_weight_path`` only.

    Returns:
        The vmfb path, the MLIR module string, or ``external_weight_path``,
        depending on ``input_mlir``, ``compile_to`` and ``weights_only``.
    """
    dtype = torch.float16 if precision == "fp16" else torch.float32
    np_dtype = "float16" if precision == "fp16" else "float32"
    safe_name = utils.create_safe_name(
        hf_model_name,
        f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_mmdit",
    )
    if pipeline_dir:
        safe_name = os.path.join(pipeline_dir, safe_name)
    if decomp_attn:
        # ``ireec_flags`` defaults to None; ``None += str`` raised TypeError.
        ireec_flags = (
            ireec_flags or ""
        ) + ",--iree-opt-aggressively-propagate-transposes=False"

    if input_mlir:
        vmfb_path = utils.compile_to_vmfb(
            input_mlir,
            device,
            target_triple,
            ireec_flags,
            safe_name,
            mlir_source="file",
            return_path=not exit_on_vmfb,
            attn_spec=attn_spec,
        )
        return vmfb_path

    mapper = {}

    utils.save_external_weights(
        mapper, mmdit_model, external_weights, external_weight_path
    )

    if weights_only:
        return external_weight_path

    # Classifier-free guidance doubles the batch seen by the model.
    do_classifier_free_guidance = True
    init_batch_dim = 2 if do_classifier_free_guidance else 1
    batch_size = batch_size * init_batch_dim
    hidden_states_shape = (
        batch_size,
        16,
        height // 8,
        width // 8,
    )
    encoder_hidden_states_shape = (batch_size, 154, 4096)
    pooled_projections_shape = (batch_size, 2048)
    example_forward_args = [
        torch.empty(hidden_states_shape, dtype=dtype),
        torch.empty(encoder_hidden_states_shape, dtype=dtype),
        torch.empty(pooled_projections_shape, dtype=dtype),
        torch.empty(init_batch_dim, dtype=dtype),
    ]

    decomp_list = []
    if decomp_attn:
        decomp_list = [
            torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
            torch.ops.aten._scaled_dot_product_flash_attention.default,
            torch.ops.aten.scaled_dot_product_attention,
        ]
    with decompositions.extend_aot_decompositions(
        from_current=True,
        add_ops=decomp_list,
    ):
        fxb = FxProgramsBuilder(mmdit_model)

        @fxb.export_program(
            args=(example_forward_args,),
        )
        def _forward(
            module,
            inputs,
        ):
            return module.forward(*inputs)

        class CompiledMmdit(CompiledModule):
            run_forward = _forward

        if external_weights:
            externalize_module_parameters(mmdit_model)

        inst = CompiledMmdit(context=Context(), import_to="IMPORT")

        module = CompiledModule.get_mlir_module(inst)

    # NOTE(review): the last "input_shapes" entry is the bare int
    # ``init_batch_dim`` while the example timestep tensor has shape
    # ``(init_batch_dim,)``; left as-is because metadata consumers are not
    # visible here -- confirm which form they expect.
    model_metadata_run_forward = {
        "model_name": "sd3_mmdit",
        "input_shapes": [
            hidden_states_shape,
            encoder_hidden_states_shape,
            pooled_projections_shape,
            init_batch_dim,
        ],
        "input_dtypes": [np_dtype for x in range(4)],
        "output_shapes": [hidden_states_shape],
        "output_dtypes": [np_dtype],
    }
    module = AddMetadataPass(module, model_metadata_run_forward, "run_forward").run()
    module_str = str(module)
    if compile_to != "vmfb":
        return module_str

    vmfb_path = utils.compile_to_vmfb(
        module_str,
        device,
        target_triple,
        ireec_flags,
        safe_name,
        return_path=True,
        attn_spec=attn_spec,
    )
    if exit_on_vmfb:
        exit()
    return vmfb_path
b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py @@ -0,0 +1,189 @@ +import argparse +from turbine_models.model_runner import vmfbRunner +from turbine_models.custom_models.sd_inference import utils, schedulers +from iree import runtime as ireert +import torch +import numpy as np +from tqdm.auto import tqdm +from shark_turbine.ops.iree import trace_tensor + +torch.random.manual_seed(0) + + +def run_mmdit_turbine( + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep, + args, +): + mmdit_runner = vmfbRunner( + args.device, + args.vmfb_path, + args.external_weight_path, + ) + iree_inputs = [ + ireert.asdevicearray(mmdit_runner.config.device, hidden_states), + ireert.asdevicearray(mmdit_runner.config.device, encoder_hidden_states), + ireert.asdevicearray(mmdit_runner.config.device, pooled_projections), + ireert.asdevicearray(mmdit_runner.config.device, timestep), + ] + noise_pred = mmdit_runner.ctx.modules.compiled_mmdit["run_forward"]( + *iree_inputs + ).to_host() + return noise_pred + + +@torch.no_grad() +def run_diffusers_mmdit( + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep, + args, +): + from turbine_models.custom_models.sd3_inference.sd3_mmdit import MMDiTModel + + mmdit_model = MMDiTModel( + args.hf_model_name, + dtype=torch.float32, + ) + noise_pred = mmdit_model.forward( + hidden_states.float(), + encoder_hidden_states.float(), + pooled_projections.float(), + timestep.float(), + ) + + return noise_pred.numpy() + + +def run_attn_turbine(q, k, v, args): + attn_runner = vmfbRunner( + args.device, + args.vmfb_path, + None, + ) + iree_inputs = [ + ireert.asdevicearray(attn_runner.config.device, q), + ireert.asdevicearray(attn_runner.config.device, k), + ireert.asdevicearray(attn_runner.config.device, v), + ] + attn_output = attn_runner.ctx.modules.compiled_attn["run_forward"]( + *iree_inputs + ).to_host() + return attn_output + + +@torch.no_grad() +def run_attn_torch(q, k, v, args): + from 
turbine_models.custom_models.sd3_inference.sd3_mmdit import MMDiTAttention + + mmdit_attn = MMDiTAttention() + attn_output = mmdit_attn.forward( + torch.tensor(q, dtype=torch.float32), + torch.tensor(k, dtype=torch.float32), + torch.tensor(v, dtype=torch.float32), + ) + + return attn_output.numpy() + + +def find_errs(turbine_output, torch_output, dim=[], failed_dims=[], errs=[]): + if not np.allclose(turbine_output, torch_output, rtol=4e-2, atol=4e-2): + if turbine_output.ndim > 0: + orig_dim = dim + for idx, i in enumerate(torch_output): + dim = [*orig_dim, idx] + try: + np.testing.assert_allclose( + turbine_output[idx], torch_output[idx], rtol=4e-2, atol=4e-2 + ) + except Exception as e: + err = np.abs(turbine_output[idx] - torch_output[idx]) + failed_dims.append(dim) + errs.append([err, turbine_output[idx], torch_output[idx]]) + failed_dims, errs = find_errs( + turbine_output[idx], torch_output[idx], dim, failed_dims, errs + ) + return (failed_dims, errs) + + +if __name__ == "__main__": + from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args + import numpy as np + import os + + torch.random.manual_seed(0) + + if args.precision == "fp16": + dtype = torch.float16 + np_dtype = np.float16 + else: + dtype = torch.float32 + np_dtype = np.float32 + + if args.attn_repro: + qkv_shape = (2, 24, 4250, 64) + example_qkv = [ + np.load("q.npy").astype(np_dtype), + np.load("k.npy").astype(np_dtype), + np.load("v.npy").astype(np_dtype), + ] + turbine_output = run_attn_turbine( + *example_qkv, + args, + ) + torch_output = run_attn_torch(*example_qkv, args).astype(np.float16) + np.save("turbine_attn_output.npy", turbine_output) + np.save("torch_attn_output.npy", torch_output) + failed_dims, errs = find_errs(turbine_output, torch_output) + for idx, dim in enumerate(failed_dims): + if len(dim) == len(torch_output.shape): + print("Failed dimension: ", dim, " with error: ", errs[idx][0]) + print("Turbine output: ", errs[idx][1]) + print("Torch output: ", 
errs[idx][2]) + print(torch_output.shape) + exit() + + batch_size = args.batch_size * 2 # do classifier free guidance + hidden_states = torch.randn( + (batch_size, 16, args.height // 8, args.width // 8), dtype=dtype + ) + encoder_hidden_states = torch.randn( + (batch_size, args.max_length * 2, 4096), dtype=dtype + ) + pooled_projections = torch.randn((batch_size, 2048), dtype=dtype) + timestep = torch.tensor([0, 0], dtype=dtype) + + turbine_output = run_mmdit_turbine( + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep, + args, + ) + print( + "TURBINE SPLIT OUTPUT:", + turbine_output, + turbine_output.shape, + turbine_output.dtype, + ) + turbine_output = turbine_output + + if args.compare_vs_torch: + print("generating torch output: ") + torch_output = run_diffusers_mmdit( + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep, + args, + ) + np.save("torch_mmdit_output.npy", torch_output.astype(np.float16)) + print("torch OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) + + print("\n(torch (comfy) image latents to iree image latents): ") + + np.testing.assert_allclose(turbine_output, torch_output, rtol=4e-2, atol=4e-2) + print("passed!") diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py b/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py new file mode 100644 index 000000000..1068d6b6c --- /dev/null +++ b/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py @@ -0,0 +1,915 @@ +# Copyright 2024 Advanced Micro Devices, inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import torch +from tqdm.auto import tqdm +from turbine_models.custom_models.sd3_inference import ( + sd3_text_encoders, + sd3_mmdit, + sd3_vae, + sd3_schedulers, +) +from turbine_models.custom_models.sd3_inference.text_encoder_impls import SD3Tokenizer +import iree.runtime as ireert +from turbine_models.custom_models.sd_inference import utils +from turbine_models.model_runner import vmfbRunner +from transformers import CLIPTokenizer +from diffusers import FlowMatchEulerDiscreteScheduler + +from PIL import Image +import os +import numpy as np +import time +import copy +from datetime import datetime as dt + +empty_pipe_dict = { + "clip": None, + "mmdit": None, + "scheduler": None, + "vae": None, +} + +EMPTY_FLAGS = { + "clip": None, + "mmdit": None, + "vae": None, + "pipeline": None, +} + + +class SharkSD3Pipeline: + def __init__( + self, + hf_model_name: str, + height: int, + width: int, + precision: str, + max_length: int, + batch_size: int, + num_inference_steps: int, + device: str | dict[str], + iree_target_triple: str | dict[str], + ireec_flags: dict = EMPTY_FLAGS, + attn_spec: str = None, + decomp_attn: bool = False, + pipeline_dir: str = "./shark_vmfbs", + external_weights_dir: str = "./shark_weights", + external_weights: str = "safetensors", + vae_decomp_attn: bool = False, + cpu_scheduling: bool = False, + vae_precision: str = "fp32", + scheduler_id: str = None, # compatibility only, always uses EulerFlowScheduler + shift: float = 1.0, + custom_vae: str = None, + ): + self.hf_model_name = hf_model_name + # self.scheduler_id = scheduler_id + self.height = height + self.width = width + self.shift = shift + self.precision = precision + self.max_length = max_length + self.batch_size = batch_size + self.num_inference_steps = None + self.devices = {} + if isinstance(device, dict): + assert isinstance( + iree_target_triple, dict + ), "Device and target triple must be both dicts or both 
strings." + self.devices["clip"] = { + "device": device["clip"], + "driver": utils.iree_device_map(device["clip"]), + "target": iree_target_triple["clip"], + } + self.devices["mmdit"] = { + "device": device["mmdit"], + "driver": utils.iree_device_map(device["mmdit"]), + "target": iree_target_triple["mmdit"], + } + self.devices["vae"] = { + "device": device["vae"], + "driver": utils.iree_device_map(device["vae"]), + "target": iree_target_triple["vae"], + } + else: + assert isinstance( + iree_target_triple, str + ), "Device and target triple must be both dicts or both strings." + self.devices["clip"] = { + "device": device, + "driver": utils.iree_device_map(device), + "target": iree_target_triple, + } + self.devices["mmdit"] = { + "device": device, + "driver": utils.iree_device_map(device), + "target": iree_target_triple, + } + self.devices["vae"] = { + "device": device, + "driver": utils.iree_device_map(device), + "target": iree_target_triple, + } + self.iree_target_triple = iree_target_triple + self.ireec_flags = ireec_flags if ireec_flags else EMPTY_FLAGS + self.attn_spec = attn_spec + self.decomp_attn = decomp_attn + self.pipeline_dir = pipeline_dir + self.external_weights_dir = external_weights_dir + self.external_weights = external_weights + self.vae_decomp_attn = vae_decomp_attn + self.custom_vae = None + self.cpu_scheduling = cpu_scheduling + self.torch_dtype = torch.float32 if self.precision == "fp32" else torch.float16 + self.vae_precision = vae_precision if vae_precision else self.precision + self.vae_dtype = torch.float32 if vae_precision == "fp32" else torch.float16 + # TODO: set this based on user-inputted guidance scale and negative prompt. 
+ self.do_classifier_free_guidance = True # False if any(x in hf_model_name for x in ["turbo", "lightning"]) else True + self._interrupt = False + + # FILE MANAGEMENT AND PIPELINE SETUP + + def check_prepared( + self, + mlirs: dict, + vmfbs: dict, + weights: dict, + interactive: bool = True, + ): + ready, vmfbs, weights = self.is_prepared(vmfbs, weights) + if not ready: + if interactive: + do_continue = input( + f"\nIt seems you are missing some necessary files. Would you like to generate them now? (y/n)" + ) + if do_continue.lower() != "y": + exit() + else: + do_continue = "y" + if do_continue.lower() == "y": + for submodel in vmfbs.keys(): + if vmfbs[submodel] == None: + print("Fetching: ", submodel) + vmfb, weight = self.export_submodel(submodel, input_mlir=mlirs) + vmfbs[submodel] = vmfb + if weights[submodel] is None: + weights[submodel] = weight + elif weights[submodel] is None and not any( + x in submodel for x in ["pipeline", "scheduler"] + ): + _, weight = self.export_submodel(submodel, weights_only=True) + weights[submodel] = weight + ready, vmfbs, weights = self.is_prepared(vmfbs, weights) + if ready: + print("All necessary files found.") + return vmfbs, weights + else: + print("There was an error generating the necessary files.") + exit() + else: + print("All necessary files found. 
Loading pipeline.") + return vmfbs, weights + + def is_prepared(self, vmfbs, weights): + missing = [] + dims = f"{str(self.width)}x{str(self.height)}" + for key in vmfbs: + if key == "scheduler": + continue + elif key == "vae": + keywords = ["vae", self.vae_precision, dims] + device_key = "vae" + elif key == "clip": + keywords = ["text_encoders", self.precision, self.max_length] + device_key = "clip" + else: + keywords = [key, self.precision, self.max_length, dims] + device_key = key + avail_files = os.listdir(self.pipeline_dir) + keywords.append("vmfb") + keywords.append(utils.create_safe_name(self.hf_model_name, "")) + keywords.append(self.devices[device_key]["target"]) + print(keywords) + for filename in avail_files: + if all(str(x) in filename for x in keywords): + vmfbs[key] = os.path.join(self.pipeline_dir, filename) + if not vmfbs[key]: + missing.append(key + " vmfb") + for w_key in weights: + if any(x in w_key for x in ["pipeline", "scheduler"]): + continue + if weights[w_key] is not None: + continue + if self.external_weights is None: + continue + default_name = os.path.join( + self.external_weights_dir, w_key + "." + self.external_weights + ) + if w_key == "clip": + default_name = os.path.join( + self.external_weights_dir, + f"sd3_text_encoders_{self.precision}.irpa", + ) + if w_key == "mmdit": + default_name = os.path.join( + self.external_weights_dir, + f"sd3_mmdit_{self.precision}." + self.external_weights, + ) + if weights[w_key] is None and os.path.exists(default_name): + weights[w_key] = os.path.join(default_name) + else: + missing.append(w_key + "." 
+ self.external_weights) + if len(missing) > 0: + print(f"Missing files: " + ", ".join(missing)) + return False, vmfbs, weights + else: + return True, vmfbs, weights + + def get_mlir_from_turbine_tank(self, submodel, container_name): + from turbine_models.turbine_tank import downloadModelArtifacts + + safe_name = utils.create_safe_name( + self.hf_model_name, + f"_{self.max_length}_{self.height}x{self.width}_{self.precision}_{submodel}.mlir", + ) + mlir_path = downloadModelArtifacts( + safe_name, + container_name, + ) + return mlir_path + + # IMPORT / COMPILE PHASE + + def get_torch_models(self, submodel): + match submodel: + case "vae": + vae_torch = sd3_vae.VaeModel( + # This is a public model, so no auth required + self.hf_model_name, + ) + return vae_torch + case "mmdit": + mmdit_torch = sd3_mmdit.MMDiTModel( + dtype=self.torch_dtype, + ) + return mmdit_torch + + def export_submodel( + self, + submodel: str, + input_mlir: str = None, + weights_only: bool = False, + ): + if not os.path.exists(self.pipeline_dir): + os.makedirs(self.pipeline_dir) + if self.external_weights and self.external_weights_dir: + if not os.path.exists(self.external_weights_dir): + os.makedirs(self.external_weights_dir, exist_ok=True) + vae_external_weight_path = os.path.join( + self.external_weights_dir, + f"sd3_vae_{self.vae_precision}." + self.external_weights, + ) + mmdit_external_weight_path = os.path.join( + self.external_weights_dir, + f"sd3_mmdit_{self.precision}." + self.external_weights, + ) + text_encoders_external_weight_path = os.path.join( + self.external_weights_dir, f"sd3_text_encoders_{self.precision}.irpa" + ) + elif self.external_weights is None: + print( + "No external weights type specified using --external_weights, weights for imported .mlir files will not be externalized." 
+ ) + vae_external_weight_path = None + mmdit_external_weight_path = None + text_encoders_external_weight_path = None + else: + print( + f"No external weights directory specified using --external_weights_dir, we assume you have your own weights in {self.pipeline_dir}." + ) + if not os.path.exists(self.pipeline_dir): + os.makedirs(self.pipeline_dir, exist_ok=True) + vae_external_weight_path = os.path.join( + self.pipeline_dir, + f"sd3_vae_{self.vae_precision}." + self.external_weights, + ) + mmdit_external_weight_path = os.path.join( + self.pipeline_dir, + f"sd3_mmdit_{self.precision}." + self.external_weights, + ) + text_encoders_external_weight_path = os.path.join( + self.pipeline_dir, f"sd3_text_encoders_{self.precision}.irpa" + ) + if weights_only: + input_mlir = { + "vae": None, + "clip": None, + "mmdit": None, + "scheduler": None, + } + match submodel: + case "mmdit": + if not input_mlir[submodel]: + mmdit_torch = self.get_torch_models("mmdit") + else: + mmdit_torch = None + mmdit_vmfb = sd3_mmdit.export_mmdit_model( + mmdit_torch, + self.hf_model_name, + self.batch_size, + self.height, + self.width, + self.precision, + self.max_length, + None, + "vmfb", + self.external_weights, + mmdit_external_weight_path, + self.devices["mmdit"]["device"], + self.devices["mmdit"]["target"], + self.ireec_flags["mmdit"], + self.decomp_attn, + exit_on_vmfb=False, + pipeline_dir=self.pipeline_dir, + attn_spec=self.attn_spec, + input_mlir=input_mlir["mmdit"], + weights_only=weights_only, + ) + del mmdit_torch + return mmdit_vmfb, mmdit_external_weight_path + case "scheduler": + scheduler_vmfb = sd3_schedulers.export_scheduler_model( + self.hf_model_name, + self.batch_size, + self.height, + self.width, + self.shift, + self.num_inference_steps, + self.precision, + "vmfb", + self.devices["mmdit"]["device"], + self.devices["mmdit"]["target"], + self.ireec_flags["scheduler"], + exit_on_vmfb=False, + pipeline_dir=self.pipeline_dir, + input_mlir=None, + ) + return scheduler_vmfb, None 
+ case "vae": + if not input_mlir[submodel]: + vae_torch = self.get_torch_models("vae") + else: + vae_torch = None + vae_vmfb = sd3_vae.export_vae_model( + vae_torch, + self.hf_model_name, + self.batch_size, + self.height, + self.width, + self.vae_precision, + "vmfb", + self.external_weights, + vae_external_weight_path, + self.devices["vae"]["device"], + self.devices["vae"]["target"], + self.ireec_flags["vae"], + self.vae_decomp_attn, + exit_on_vmfb=False, + pipeline_dir=self.pipeline_dir, + attn_spec=self.attn_spec, + input_mlir=input_mlir["vae"], + weights_only=weights_only, + ) + del vae_torch + return vae_vmfb, vae_external_weight_path + case "clip": + _, text_encoders_vmfb = sd3_text_encoders.export_text_encoders( + self.hf_model_name, + None, + self.max_length, + self.precision, + "vmfb", + self.external_weights, + text_encoders_external_weight_path, + self.devices["clip"]["device"], + self.devices["clip"]["target"], + self.ireec_flags["clip"], + exit_on_vmfb=False, + pipeline_dir=self.pipeline_dir, + input_mlir=input_mlir["clip"], + attn_spec=self.attn_spec, + output_batchsize=self.batch_size, + ) + return text_encoders_vmfb, text_encoders_external_weight_path + + # LOAD + + def load_pipeline( + self, + vmfbs: dict, + weights: dict, + compiled_pipeline: bool = False, + split_scheduler: bool = True, + extra_device_args: dict = {}, + ): + if "npu_delegate_path" in extra_device_args.keys(): + delegate = extra_device_args["npu_delegate_path"] + else: + delegate = None + + self.runners = {} + runners = {} + load_start = time.time() + runners["pipe"] = vmfbRunner( + self.devices["mmdit"]["driver"], + vmfbs["mmdit"], + weights["mmdit"], + ) + unet_loaded = time.time() + print("\n[LOG] MMDiT loaded in ", unet_loaded - load_start, "sec") + + # if not self.cpu_scheduling: + # runners["scheduler"] = sd3_schedulers.SharkSchedulerWrapper( + # self.devices["mmdit"]["driver"], + # vmfbs["scheduler"], + # ) + # else: + # print("Using torch CPU scheduler.") + # 
runners["scheduler"] = FlowMatchEulerDiscreteScheduler.from_pretrained( + # self.hf_model_name, subfolder="scheduler" + # ) + + # sched_loaded = time.time() + # print("\n[LOG] Scheduler loaded in ", sched_loaded - unet_loaded, "sec") + runners["vae"] = vmfbRunner( + self.devices["vae"]["driver"], + vmfbs["vae"], + weights["vae"], + extra_plugin=delegate, + ) + vae_loaded = time.time() + print("\n[LOG] VAE Decode loaded in ", vae_loaded - unet_loaded, "sec") + runners["clip"] = vmfbRunner( + self.devices["clip"]["driver"], + vmfbs["clip"], + weights["clip"], + ) + clip_loaded = time.time() + print("\n[LOG] Text Encoders loaded in ", clip_loaded - vae_loaded, "sec") + + tok_start = time.time() + self.tokenizer = SD3Tokenizer() + tok_loaded = time.time() + print("\n[LOG] Tokenizers loaded in ", tok_loaded - tok_start, "sec") + self.runners = runners + self.compiled_pipeline = compiled_pipeline + print("Successfully loaded pipeline.") + + # RUN + + def generate_images( + self, + prompt: str, + negative_prompt: str = "", + batch_count: int = 1, + guidance_scale: float = 4, + seed: float = -1, + return_imgs: bool = False, + steps: int = None, + cpu_scheduling: bool = False, + scheduler_id: str = None, + progress=None, + ): + needs_new_scheduler = ( + steps and steps != self.num_inference_steps + ) or cpu_scheduling != self.cpu_scheduling + self.cpu_scheduling = cpu_scheduling + if steps: + self.num_inference_steps = steps + if steps and not self.cpu_scheduling and needs_new_scheduler: + self.runners["scheduler"] = None + self.num_inference_steps = steps + scheduler_path = f"EulerFlowScheduler_{self.num_inference_steps}" + if not os.path.exists(scheduler_path): + scheduler_path, _ = self.export_submodel("scheduler") + try: + self.runners["scheduler"] = sd3_schedulers.SharkSchedulerWrapper( + self.devices["mmdit"]["driver"], + scheduler_path, + ) + except: + print("JIT export of scheduler failed. 
Loading CPU scheduler.") + self.cpu_scheduling = True + if self.cpu_scheduling and needs_new_scheduler: + self.runners["scheduler"] = FlowMatchEulerDiscreteScheduler.from_pretrained( + self.hf_model_name, subfolder="scheduler" + ) + + # TODO: implement case where this is false e.g. in SDXL Turbo + do_classifier_free_guidance = True + + # Workaround for turbo support (guidance_scale 0) + if guidance_scale == 0: + negative_prompt = prompt + prompt = "" + + iree_dtype = "float32" if self.precision == "fp32" else "float16" + torch_dtype = torch.float32 if self.precision == "fp32" else torch.float16 + + samples = [] + numpy_images = [] + + for i in range(batch_count): + generator = torch.Generator().manual_seed(int(seed)) + shape = ( + self.batch_size, + 16, + self.height // 8, + self.width // 8, + ) + rand_sample = torch.randn( + shape, + generator=generator, + dtype=torch.float32, + layout=torch.strided, + ) + samples.append( + ireert.asdevicearray( + self.runners["pipe"].config.device, rand_sample, dtype=iree_dtype + ) + ) + + if not self.cpu_scheduling: + guidance_scale = ireert.asdevicearray( + self.runners["pipe"].config.device, + np.asarray([guidance_scale]), + dtype=iree_dtype, + ) + + tokenize_start = time.time() + text_input_ids_dict = self.tokenizer.tokenize_with_weights(prompt) + uncond_input_ids_dict = self.tokenizer.tokenize_with_weights(negative_prompt) + text_input_ids_list = list(text_input_ids_dict.values()) + uncond_input_ids_list = list(uncond_input_ids_dict.values()) + text_encoders_inputs = [ + ireert.asdevicearray( + self.runners["clip"].config.device, text_input_ids_list[0] + ), + ireert.asdevicearray( + self.runners["clip"].config.device, text_input_ids_list[1] + ), + ireert.asdevicearray( + self.runners["clip"].config.device, text_input_ids_list[2] + ), + ireert.asdevicearray( + self.runners["clip"].config.device, uncond_input_ids_list[0] + ), + ireert.asdevicearray( + self.runners["clip"].config.device, uncond_input_ids_list[1] + ), + 
ireert.asdevicearray( + self.runners["clip"].config.device, uncond_input_ids_list[2] + ), + ] + + # Tokenize prompt and negative prompt. + encode_prompts_start = time.time() + prompt_embeds, pooled_prompt_embeds = self.runners[ + "clip" + ].ctx.modules.compiled_text_encoder["encode_tokens"](*text_encoders_inputs) + encode_prompts_end = time.time() + if self.cpu_scheduling: + timesteps, num_inference_steps = sd3_schedulers.retrieve_timesteps( + self.runners["scheduler"], + num_inference_steps=steps, + timesteps=None, + ) + steps = num_inference_steps + + for i in range(batch_count): + if self._interrupt: + self._interrupt = False + return + unet_start = time.time() + if not self.cpu_scheduling: + latents, steps, timesteps = self.runners["scheduler"].initialize( + samples[i] + ) + else: + latents = torch.tensor(samples[i].to_host(), dtype=self.torch_dtype) + iree_inputs = [ + latents, + ireert.asdevicearray( + self.runners["pipe"].config.device, prompt_embeds, dtype=iree_dtype + ), + ireert.asdevicearray( + self.runners["pipe"].config.device, + pooled_prompt_embeds, + dtype=iree_dtype, + ), + None, + ] + for s in tqdm( + iterable=range(steps), desc=f"Inference steps ({steps}), batch {i+1}" + ): + if self._interrupt: + self._interrupt = False + return + # print(f"step {s}") + if self.cpu_scheduling: + step_index = s + t = timesteps[s] + if self.do_classifier_free_guidance: + latent_model_input = torch.cat([latents] * 2) + timestep = ireert.asdevicearray( + self.runners["pipe"].config.device, + t.expand(latent_model_input.shape[0]), + dtype=iree_dtype, + ) + latent_model_input = ireert.asdevicearray( + self.runners["pipe"].config.device, + latent_model_input, + dtype=iree_dtype, + ) + else: + step_index = ireert.asdevicearray( + self.runners["scheduler"].runner.config.device, + torch.tensor([s]), + "int64", + ) + latent_model_input, timestep = self.runners["scheduler"].prep( + latents, + step_index, + timesteps, + ) + t = ireert.asdevicearray( + 
self.runners["scheduler"].runner.config.device, + timestep.to_host()[0], + ) + noise_pred = self.runners["pipe"].ctx.modules.compiled_mmdit[ + "run_forward" + ]( + latent_model_input, + iree_inputs[1], + iree_inputs[2], + timestep, + ) + if not self.cpu_scheduling: + latents = self.runners["scheduler"].step( + noise_pred, + t, + latents, + guidance_scale, + step_index, + ) + else: + noise_pred = torch.tensor( + noise_pred.to_host(), dtype=self.torch_dtype + ) + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + latents = self.runners["scheduler"].step( + noise_pred, + t, + latents, + return_dict=False, + )[0] + if isinstance(latents, torch.Tensor): + latents = latents.type(self.vae_dtype) + latents = ireert.asdevicearray( + self.runners["vae"].config.device, + latents, + ) + else: + vae_numpy_dtype = ( + np.float32 if self.vae_precision == "fp32" else np.float16 + ) + latents = latents.astype(vae_numpy_dtype) + + vae_start = time.time() + vae_out = self.runners["vae"].ctx.modules.compiled_vae["decode"](latents) + + pipe_end = time.time() + + image = vae_out.to_host() + + numpy_images.extend([image]) + print("Batch #", i + 1, "\n") + print( + "UNet time(", + self.num_inference_steps, + "): ", + vae_start - unet_start, + "sec,", + ) + print( + "Unet average step latency: ", + (vae_start - unet_start) / self.num_inference_steps, + "sec", + ) + print("VAE time: ", pipe_end - vae_start, "sec") + print( + f"\nTotal time (txt2img, batch #{str(i+1)}): ", + (encode_prompts_end - encode_prompts_start) + (pipe_end - unet_start), + "sec\n", + ) + end = time.time() + print("Total CLIP time:", encode_prompts_end - encode_prompts_start, "sec") + print("Total tokenize time:", encode_prompts_start - tokenize_start, "sec") + if batch_count > 1: + print( + f"Total inference time ({batch_count} batch(es)):", + end - encode_prompts_start, + 
"sec", + ) + timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") + images = [] + for idx, image in enumerate(numpy_images): + if image.ndim == 4: + image = image[0] + image = torch.from_numpy(image).cpu().permute(1, 2, 0).float().numpy() + image = (image * 255).round().astype("uint8") + out_image = Image.fromarray(image) + images.extend([[out_image]]) + if return_imgs: + return images + for idx_batch, image_batch in enumerate(images): + for idx, image in enumerate(image_batch): + img_path = ( + "sd3_output_" + + timestamp + + "_" + + str(idx_batch) + + "_" + + str(idx) + + ".png" + ) + image.save(img_path) + print(img_path, "saved") + return + + +def run_diffusers_cpu( + hf_model_name, + prompt, + negative_prompt, + guidance_scale, + seed, + height, + width, + num_inference_steps, +): + from diffusers import StableDiffusion3Pipeline + + pipe = StableDiffusion3Pipeline.from_pretrained( + hf_model_name, torch_dtype=torch.float32 + ) + pipe = pipe.to("cpu") + generator = torch.Generator().manual_seed(int(seed)) + + image = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + height=height, + width=width, + generator=generator, + ).images[0] + timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") + image.save(f"diffusers_reference_output_{timestamp}.png") + + +if __name__ == "__main__": + from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args + + if args.compare_vs_torch: + run_diffusers_cpu( + args.hf_model_name, + args.prompt, + args.negative_prompt, + args.guidance_scale, + args.seed, + args.height, + args.width, + args.num_inference_steps, + ) + exit() + map = empty_pipe_dict + mlirs = copy.deepcopy(map) + vmfbs = copy.deepcopy(map) + weights = copy.deepcopy(map) + + if any(x for x in [args.clip_device, args.mmdit_device, args.vae_device]): + assert all( + x for x in [args.clip_device, args.mmdit_device, args.vae_device] + ), "Please specify device for all submodels or 
pass --device for all submodels." + assert all( + x for x in [args.clip_target, args.mmdit_target, args.vae_target] + ), "Please specify target triple for all submodels or pass --iree_target_triple for all submodels." + args.device = "hybrid" + args.iree_target_triple = "_".join( + [args.clip_target, args.mmdit_target, args.vae_target] + ) + else: + args.clip_device = args.device + args.mmdit_device = args.device + args.vae_device = args.device + args.clip_target = args.iree_target_triple + args.mmdit_target = args.iree_target_triple + args.vae_target = args.iree_target_triple + + devices = { + "clip": args.clip_device, + "mmdit": args.mmdit_device, + "vae": args.vae_device, + } + targets = { + "clip": args.clip_target, + "mmdit": args.mmdit_target, + "vae": args.vae_target, + } + ireec_flags = { + "clip": args.ireec_flags + args.clip_flags, + "mmdit": args.ireec_flags + args.unet_flags, + "vae": args.ireec_flags + args.vae_flags, + "pipeline": args.ireec_flags, + "scheduler": args.ireec_flags, + } + if not args.pipeline_dir: + pipe_id_list = [ + args.hf_model_name.split("/")[-1], + str(args.height), + str(args.width), + str(args.max_length), + args.precision, + args.device, + args.iree_target_triple, + ] + if args.decomp_attn: + pipe_id_list.append("decomp") + args.pipeline_dir = os.path.join( + ".", + "_".join(pipe_id_list), + ) + if args.input_mlir: + user_mlir_list = args.input_mlir.split(",") + else: + user_mlir_list = [] + for submodel_id, mlir_path in zip(mlirs.keys(), user_mlir_list): + if submodel_id in mlir_path: + mlirs[submodel_id] = mlir_path + if not args.external_weights_dir and args.external_weights: + args.external_weights_dir = args.pipeline_dir + sd3_pipe = SharkSD3Pipeline( + args.hf_model_name, + args.height, + args.width, + args.precision, + args.max_length, + args.batch_size, + args.num_inference_steps, + devices, + targets, + ireec_flags, + args.attn_spec, + args.decomp_attn, + args.pipeline_dir, + args.external_weights_dir, + 
external_weights=args.external_weights, + vae_decomp_attn=args.vae_decomp_attn, + cpu_scheduling=args.cpu_scheduling, + vae_precision=args.vae_precision, + ) + if args.cpu_scheduling: + vmfbs.pop("scheduler") + weights.pop("scheduler") + vmfbs, weights = sd3_pipe.check_prepared(mlirs, vmfbs, weights) + if args.npu_delegate_path: + extra_device_args = {"npu_delegate_path": args.npu_delegate_path} + else: + extra_device_args = {} + sd3_pipe.load_pipeline( + vmfbs, + weights, + args.compiled_pipeline, + args.split_scheduler, + extra_device_args=extra_device_args, + ) + sd3_pipe.generate_images( + args.prompt, + args.negative_prompt, + args.batch_count, + args.guidance_scale, + args.seed, + False, + ) + print("Image generation complete.") diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py b/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py new file mode 100644 index 000000000..2c1d04cf1 --- /dev/null +++ b/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py @@ -0,0 +1,395 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import os
import inspect
from typing import List

import torch
from typing import Any, Callable, Dict, List, Optional, Union
from shark_turbine.aot import *
import shark_turbine.ops.iree as ops
from shark_turbine.transforms.general.add_metadata import AddMetadataPass
from iree.compiler.ir import Context
import iree.runtime as ireert
import numpy as np

from diffusers import (
    FlowMatchEulerDiscreteScheduler,
)

from turbine_models.turbine_tank import turbine_tank
from turbine_models.custom_models.sd_inference import utils
from turbine_models.model_runner import vmfbRunner


class SharkSchedulerWrapper:
    """Runtime handle for a compiled scheduler module.

    Loads a scheduler .vmfb through ``vmfbRunner`` and forwards each call to
    the matching entry point (``run_init``, ``run_prep``, ``run_step``) of the
    ``compiled_scheduler`` module.
    """

    def __init__(self, rt_device, vmfb):
        """Create the runner for ``vmfb`` on ``rt_device``.

        The third ``vmfbRunner`` argument is passed as ``None``.
        NOTE(review): presumably that slot is the external weights path and
        the scheduler has no weights to load -- confirm against vmfbRunner.
        """
        self.runner = vmfbRunner(rt_device, vmfb, None)

    def initialize(self, sample):
        """Invoke ``run_init`` and return ``(sample, steps, timesteps)``.

        ``steps`` is transferred with ``to_host()`` before being returned;
        ``sample`` and ``timesteps`` are returned exactly as the module
        produced them.
        """
        sample, steps, timesteps = self.runner.ctx.modules.compiled_scheduler[
            "run_init"
        ](sample)
        return sample, steps.to_host(), timesteps

    def prep(self, sample, t, timesteps):
        """Invoke ``run_prep`` with ``(sample, t, timesteps)`` and return its result."""
        return self.runner.ctx.modules.compiled_scheduler["run_prep"](
            sample, t, timesteps
        )

    def step(self, noise_pred, t, sample, guidance_scale, step_index):
        """Invoke ``run_step`` with the five arguments, in order, and return its result."""
        return self.runner.ctx.modules.compiled_scheduler["run_step"](
            noise_pred, t, sample, guidance_scale, step_index
        )


class FlowSchedulingModel(torch.nn.Module):
    """``torch.nn.Module`` facade over diffusers' ``FlowMatchEulerDiscreteScheduler``.

    Exposes the scheduler as three methods (``initialize``,
    ``prepare_model_input``, ``step``). Classifier-free guidance is always
    enabled (``do_classifier_free_guidance`` is hard-coded to ``True``), and
    the timestep schedule is fixed at construction time.
    """

    def __init__(
        self,
        hf_model_name,
        num_inference_steps,
        dtype,
    ):
        """Load the scheduler config and precompute the timestep schedule.

        Args:
            hf_model_name: Model id/path handed to ``from_pretrained``; the
                scheduler is read from its ``scheduler`` subfolder.
            num_inference_steps: Step count passed to ``set_timesteps``.
            dtype: Dtype that ``prepare_model_input`` and ``step`` cast their
                outputs to.
        """
        super().__init__()
        # For now, assumes SDXL implementation. May not need parametrization for other models,
        # but keeping hf_model_name in case.
        # NOTE(review): the comment above mentions SDXL, but the scheduler
        # constructed here is FlowMatchEulerDiscreteScheduler -- the wording
        # looks carried over from another file; confirm and reword.
        self.model = FlowMatchEulerDiscreteScheduler.from_pretrained(
            hf_model_name, subfolder="scheduler"
        )
        self.do_classifier_free_guidance = True
        self.model.set_timesteps(num_inference_steps)
        # Same tensor object as self.model.timesteps at this point.
        self.timesteps = self.model.timesteps
        self.dtype = dtype

    # TODO: Make steps dynamic here
    def initialize(self, sample):
        """Return ``(sample, step_count, timesteps)`` for the start of a run.

        ``sample`` is passed through unchanged, ``step_count`` is a 0-d tensor
        holding ``len(self.timesteps)``, and ``timesteps`` is the scheduler's
        schedule cast to float32.
        """
        step_count = torch.tensor(len(self.timesteps))
        timesteps = self.model.timesteps
        # ops.trace_tensor("sample", sample[:,:,0,0])
        return (
            sample,
            step_count,
            timesteps.type(torch.float32),
        )

    def prepare_model_input(self, sample, t, timesteps):
        """Build the denoiser inputs for step ``t``.

        ``t`` is used as an index into ``timesteps``. With classifier-free
        guidance on, ``sample`` is concatenated with itself (``torch.cat``
        default dim 0), and the selected timestep is expanded to the size of
        that leading dimension. Both results are cast to ``self.dtype``.

        Returns:
            ``(latent_model_input, t)``.
        """
        t = timesteps[t]

        if self.do_classifier_free_guidance:
            latent_model_input = torch.cat([sample] * 2)
        else:
            latent_model_input = sample
        t = t.expand(latent_model_input.shape[0])
        return latent_model_input.type(self.dtype), t.type(self.dtype)

    def step(self, noise_pred, t, sample, guidance_scale, i):
        """Apply guidance to ``noise_pred`` and advance ``sample`` one step.

        Returns the first element of the wrapped scheduler's ``step`` result,
        cast to ``self.dtype``.
        """
        # Overwrites the scheduler's private step counter with the caller's
        # index before stepping.
        # NOTE(review): presumably so the scheduler does not have to derive
        # the index from `t` itself -- confirm against diffusers.
        self.model._step_index = i

        if self.do_classifier_free_guidance:
            # Split into two halves and extrapolate from the first half
            # towards the second by guidance_scale.
            noise_preds = noise_pred.chunk(2)
            noise_pred = noise_preds[0] + guidance_scale * (
                noise_preds[1] - noise_preds[0]
            )
        sample = self.model.step(noise_pred, t, sample, return_dict=False)[0]
        return sample.type(self.dtype)


# Wraps a diffusers scheduler running on native pytorch+cpu.
# This allows us to use it interchangeably with compiled schedulers in our pipeline(s).
class TorchCPUFlowSchedulerCompat:
    """Run a diffusers flow scheduler with torch on the host.

    The scheduler maths is done with torch on the host. Values that arrive as
    IREE ``DeviceArray`` objects are copied to the host first, and the model
    inputs produced by ``scale_model_input`` are pushed back to ``dest_device``
    as device arrays. Classifier-free guidance is always enabled.
    """

    @torch.no_grad()
    def __init__(
        self, scheduler, batch_size, num_inference_steps, dest_device, latents_dtype
    ):
        """Store configuration and fix the timestep schedule.

        Args:
            scheduler: Scheduler object; ``set_timesteps``, ``timesteps`` and
                ``step`` are used.
            batch_size: Stored on the instance; not otherwise used here.
            num_inference_steps: Step count passed to ``set_timesteps``.
            dest_device: Device handle given to ``ireert.asdevicearray``.
            latents_dtype: Dtype given to ``ireert.asdevicearray``; the string
                ``"float32"`` selects ``torch.float32`` for ``torch_dtype``,
                anything else selects ``torch.float16``.
        """
        self.do_classifier_free_guidance = True
        self.module = scheduler
        self.dest = dest_device
        self.dtype = latents_dtype
        self.batch_size = batch_size
        self.module.set_timesteps(num_inference_steps)
        self.timesteps = self.module.timesteps
        self.torch_dtype = (
            torch.float32 if latents_dtype == "float32" else torch.float16
        )

    @staticmethod
    def _to_torch(value):
        """Return ``value`` as a torch tensor if it is a DeviceArray, else unchanged."""
        if isinstance(value, ireert.DeviceArray):
            return torch.tensor(value.to_host())
        return value

    def initialize(self, sample):
        """Return ``(sample, step_count, timesteps)`` for the start of a run.

        A ``DeviceArray`` sample is copied to the host as a float32 tensor;
        any other sample is returned unchanged. ``step_count`` is a 0-d tensor
        holding the number of timesteps.
        """
        if isinstance(sample, ireert.DeviceArray):
            sample = torch.tensor(sample.to_host(), dtype=torch.float32)
        step_indexes = torch.tensor(len(self.module.timesteps))
        timesteps = self.timesteps
        return sample, step_indexes, timesteps

    def scale_model_input(self, sample, t, timesteps):
        """Build the denoiser inputs for step ``t`` as device arrays.

        ``t`` indexes into ``timesteps``. With classifier-free guidance on,
        ``sample`` is concatenated with itself and the selected timestep is
        expanded to the size of that leading dimension.

        Returns:
            ``(sample, t)``, both created with ``ireert.asdevicearray`` on
            ``self.dest`` using ``self.dtype``.
        """
        if self.do_classifier_free_guidance:
            sample = torch.cat([sample] * 2)
        t = timesteps[t]
        t = t.expand(sample.shape[0])
        # NOTE(review): `t` is wrapped in a list here, which looks like it
        # adds a leading dimension to the device array -- confirm that this
        # is what the compiled model expects.
        t = ireert.asdevicearray(self.dest, [t], self.dtype)
        sample = ireert.asdevicearray(self.dest, sample, self.dtype)
        return sample, t

    def step(self, noise_pred, t, latents, guidance_scale, i):
        """Apply guidance to ``noise_pred`` and advance ``latents`` one step.

        Every tensor-like argument may be a torch tensor or a ``DeviceArray``;
        device arrays are copied to the host first. ``i`` is accepted for
        signature parity with the compiled scheduler wrapper and is not used.

        Returns:
            The first element of the wrapped scheduler's ``step`` result.
        """
        # Convert uniformly: previously noise_pred was assumed to be a
        # DeviceArray (unconditional .to_host()) and latents was never
        # converted, so host tensors / device latents raised.
        t = self._to_torch(t)
        guidance_scale = self._to_torch(guidance_scale)
        noise_pred = self._to_torch(noise_pred)
        latents = self._to_torch(latents)
        if self.do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
        return self.module.step(
            noise_pred,
            t,
            latents,
            return_dict=False,
        )[0]


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Only used for cpu scheduling.
+def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@torch.no_grad() +def export_scheduler_model( + hf_model_name: str, + batch_size: int = 1, + height: int = 512, + width: int = 512, + shift: int = 1.0, + num_inference_steps: int = 30, + precision: str = "fp16", + compile_to: str = "torch", + device: str = None, + target_triple: str = None, + ireec_flags: str = None, + exit_on_vmfb: bool = False, + pipeline_dir: str = None, + input_mlir: str = None, + upload_ir=False, +): + dtype = torch.float16 if precision == "fp16" else torch.float32 + np_dtype = "float16" if precision == "fp16" else "float32" + scheduler_module = FlowSchedulingModel(hf_model_name, num_inference_steps, dtype) + vmfb_names = [ + "EulerFlowScheduler", + f"bs{batch_size}_{height}x{width}", + precision, + str(num_inference_steps), + target_triple, + ] + vmfb_name = "_".join(vmfb_names) + safe_name = utils.create_safe_name(hf_model_name, "_" + vmfb_name) + if pipeline_dir: + safe_name = os.path.join(pipeline_dir, safe_name) + if input_mlir: + vmfb_path = utils.compile_to_vmfb( + input_mlir, + device, + target_triple, + ireec_flags, + safe_name + "_" + target_triple, + mlir_source="file", + return_path=not exit_on_vmfb, + ) + return vmfb_path + + do_classifier_free_guidance = True + if do_classifier_free_guidance: + init_batch_dim = 2 + else: + init_batch_dim = 1 + + sample = ( + batch_size, + 16, + height // 8, + width // 8, + ) + noise_pred_shape = ( + batch_size * init_batch_dim, + 16, + height // 8, + width // 8, + ) + example_init_args = [torch.empty(sample, dtype=dtype)] + example_prep_args = ( + torch.empty(sample, dtype=dtype), + torch.empty(1, dtype=torch.int64), + torch.empty([19], dtype=torch.float32), + ) + timesteps = 
torch.export.Dim("timesteps") + prep_dynamic_args = { + "sample": {}, + "t": {}, + "timesteps": {0: timesteps}, + } + example_step_args = [ + torch.empty(noise_pred_shape, dtype=dtype), + torch.empty(1, dtype=dtype), + torch.empty(sample, dtype=dtype), + torch.empty(1, dtype=dtype), + torch.empty(1, dtype=torch.int64), + ] + + fxb = FxProgramsBuilder(scheduler_module) + + @fxb.export_program( + args=(example_init_args,), + ) + def _initialize(module, sample): + return module.initialize(*sample) + + @fxb.export_program( + args=example_prep_args, + dynamic_shapes=prep_dynamic_args, + ) + def _prep(module, sample, t, timesteps): + return module.prepare_model_input(sample, t, timesteps) + + @fxb.export_program( + args=(example_step_args,), + ) + def _step(module, inputs): + return module.step(*inputs) + + decomp_list = [] + # if decomp_attn == True: + # decomp_list.extend( + # [ + # torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + # torch.ops.aten._scaled_dot_product_flash_attention.default, + # ] + # ) + with decompositions.extend_aot_decompositions( + from_current=True, + add_ops=decomp_list, + ): + + class CompiledScheduler(CompiledModule): + run_init = _initialize + run_prep = _prep + run_step = _step + + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" + inst = CompiledScheduler(context=Context(), import_to=import_to) + + module = CompiledModule.get_mlir_module(inst) + + model_metadata_run_init = { + "model_name": "sd3_scheduler_FlowEulerDiscrete", + "input_shapes": [sample], + "input_dtypes": [np_dtype], + "output_shapes": [sample, "?", "?"], + "output_dtypes": [np_dtype, "int32", "float32"], + } + model_metadata_run_prep = { + "model_name": "sd3_scheduler_FlowEulerDiscrete", + "input_shapes": [sample, 1, [19]], + "input_dtypes": [np_dtype, "float32", "float32"], + "output_shapes": [noise_pred_shape, noise_pred_shape[0]], + "output_dtypes": [np_dtype, "float32"], + } + model_metadata_run_step = { + "model_name": 
"sd3_scheduler_FlowEulerDiscrete", + "input_shapes": [noise_pred_shape, 1, sample, 1, 1], + "input_dtypes": [np_dtype, np_dtype, np_dtype, np_dtype, "int64"], + "output_shapes": [sample], + "output_dtypes": [np_dtype], + } + module = AddMetadataPass(module, model_metadata_run_init, "run_init").run() + module = AddMetadataPass(module, model_metadata_run_prep, "run_prep").run() + module = AddMetadataPass(module, model_metadata_run_step, "run_step").run() + + module_str = str(module) + if compile_to != "vmfb": + return module_str + elif compile_to == "vmfb": + vmfb = utils.compile_to_vmfb( + module_str, + device, + target_triple, + ireec_flags, + safe_name + "_" + target_triple, + return_path=True, + ) + if exit_on_vmfb: + exit() + return vmfb + + +if __name__ == "__main__": + from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args + + mod_str = export_scheduler_model( + args.hf_model_name, + args.batch_size, + args.height, + args.width, + args.shift, + args.num_inference_steps, + args.precision, + args.compile_to, + args.device, + args.iree_target_triple, + args.ireec_flags, + exit_on_vmfb=False, + input_mlir=args.input_mlir, + ) + vmfb_names = [ + "EulerFlowScheduler", + f"bs{args.batch_size}_{args.height}x{args.width}", + args.precision, + str(args.num_inference_steps), + args.iree_target_triple, + ] + safe_name = "_".join(vmfb_names) + if args.compile_to != "vmfb": + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py b/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py new file mode 100644 index 000000000..d3e4ecb54 --- /dev/null +++ b/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py @@ -0,0 +1,251 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import sys + +import safetensors +from iree import runtime as ireert +import iree.compiler as ireec +from iree.compiler.ir import Context +import numpy as np +from shark_turbine.aot import * +from shark_turbine.transforms.general.add_metadata import AddMetadataPass +from turbine_models.custom_models.sd_inference import utils +import torch +from turbine_models.custom_models.sd3_inference.text_encoder_impls import ( + SDClipModel, + SDXLClipG, + T5XXLModel, + load_into, +) +from huggingface_hub import hf_hub_download +from safetensors import safe_open + +CLIPG_CONFIG = { + "hidden_act": "gelu", + "hidden_size": 1280, + "intermediate_size": 5120, + "num_attention_heads": 20, + "num_hidden_layers": 32, +} + +CLIPL_CONFIG = { + "hidden_act": "quick_gelu", + "hidden_size": 768, + "intermediate_size": 3072, + "num_attention_heads": 12, + "num_hidden_layers": 12, +} + +T5_CONFIG = { + "d_ff": 10240, + "d_model": 4096, + "num_heads": 64, + "num_layers": 24, + "vocab_size": 32128, +} + + +class TextEncoderModule(torch.nn.Module): + @torch.no_grad() + def __init__( + self, + batch_size=1, + ): + super().__init__() + self.dtype = torch.float16 + self.clip_l = SDClipModel( + layer="hidden", + layer_idx=-2, + device="cpu", + dtype=self.dtype, + layer_norm_hidden_state=False, + return_projected_pooled=False, + textmodel_json_config=CLIPL_CONFIG, + ).half() + clip_l_weights = hf_hub_download( + repo_id="stabilityai/stable-diffusion-3-medium", + filename="text_encoders/clip_l.safetensors", + ) + with safe_open(clip_l_weights, framework="pt", device="cpu") as f: + load_into(f, self.clip_l.transformer, "", "cpu", self.dtype) + self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=self.dtype).half() + clip_g_weights = hf_hub_download( + repo_id="stabilityai/stable-diffusion-3-medium", + filename="text_encoders/clip_g.safetensors", + ) + with safe_open(clip_g_weights, framework="pt", device="cpu") as f: + 
load_into(f, self.clip_g.transformer, "", "cpu", self.dtype) + self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=self.dtype).half() + t5_weights = hf_hub_download( + repo_id="stabilityai/stable-diffusion-3-medium", + filename="text_encoders/t5xxl_fp16.safetensors", + ) + with safe_open(t5_weights, framework="pt", device="cpu") as f: + load_into(f, self.t5xxl.transformer, "", "cpu", self.dtype) + + self.do_classifier_free_guidance = True + self.batch_size = batch_size + + def get_cond(self, tokens_l, tokens_g, tokens_t5xxl): + l_out, l_pooled = self.clip_l.forward(tokens_l) + g_out, g_pooled = self.clip_g.forward(tokens_g) + t5_out, _ = self.t5xxl.forward(tokens_t5xxl) + lg_out = torch.cat([l_out, g_out], dim=-1) + lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) + return torch.cat([lg_out, t5_out], dim=-2), torch.cat( + (l_pooled, g_pooled), dim=-1 + ) + + def forward(self, tokens_g, tokens_l, tokens_t5xxl, neg_g, neg_l, neg_t5): + conditioning, cond_pool = self.get_cond(tokens_l, tokens_g, tokens_t5xxl) + neg_cond, neg_cond_pool = self.get_cond(neg_l, neg_g, neg_t5) + + prompt_embeds = torch.cat([neg_cond, conditioning], dim=0) + pooled_prompt_embeds = torch.cat([neg_cond_pool, cond_pool], dim=0) + + return prompt_embeds, pooled_prompt_embeds + + +@torch.no_grad() +def export_text_encoders( + hf_model_name, + max_length=64, + batch_size=1, + precision="fp16", + compile_to="torch", + external_weights=None, + external_weight_path=None, + device=None, + target_triple=None, + ireec_flags=None, + exit_on_vmfb=False, + pipeline_dir=None, + input_mlir=None, + attn_spec=None, + decomp_attn=True, +): + + safe_name = utils.create_safe_name( + hf_model_name, + f"_bs{batch_size}_{str(max_length)}_{precision}_text_encoders", + ) + if pipeline_dir: + safe_name = os.path.join(pipeline_dir, safe_name) + + if input_mlir: + vmfb_path = utils.compile_to_vmfb( + input_mlir, + device, + target_triple, + ireec_flags, + safe_name, + mlir_source="file", + 
return_path=not exit_on_vmfb, + const_expr_hoisting=True, + attn_spec=attn_spec, + ) + return vmfb_path + model = TextEncoderModule( + batch_size=batch_size, + ) + mapper = {} + + assert ( + ".safetensors" not in external_weight_path + ), "Original parameters format incompatible with IREE safetensors parser. Use '.irpa' instead." + + input_args = [torch.empty([batch_size, 77, 2], dtype=torch.int64) for x in range(6)] + + decomp_list = [] + if decomp_attn == True: + decomp_list = [ + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten.scaled_dot_product_attention, + ] + with decompositions.extend_aot_decompositions( + from_current=True, + add_ops=decomp_list, + ): + fxb = FxProgramsBuilder(model) + + @fxb.export_program( + args=(input_args,), + ) + def _forward( + module, + inputs, + ): + return module.forward(*inputs) + + class CompiledTextEncoder(CompiledModule): + encode_tokens = _forward + + if external_weights: + externalize_module_parameters(model) + save_module_parameters(external_weight_path, model) + + inst = CompiledTextEncoder(context=Context(), import_to="IMPORT") + + module = CompiledModule.get_mlir_module(inst) + + model_metadata_forward = { + "model_name": "sd3_clip_t5xxl_text_encoders", + "input_shapes": [(batch_size, max_length, 2) for x in range(6)], + "input_dtypes": ["int64" for x in range(6)], + "output_shapes": [ + (2 * batch_size, max_length * 2, 4096), + (2 * batch_size, 2048), + ], + "output_dtypes": ["float32"], + } + module = AddMetadataPass(module, model_metadata_forward, "forward").run() + module_str = str(module) + if compile_to != "vmfb": + return module_str + else: + vmfb_path = utils.compile_to_vmfb( + module_str, + device, + target_triple, + ireec_flags, + safe_name, + return_path=not exit_on_vmfb, + const_expr_hoisting=True, + attn_spec=attn_spec, + ) + return vmfb_path + + +if __name__ == "__main__": + from 
turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args + + mod_str, _ = export_text_encoders( + args.hf_model_name, + args.max_length, + args.batch_size, + args.precision, + args.compile_to, + args.external_weights, + args.external_weight_path, + args.device, + args.iree_target_triple, + args.ireec_flags + args.clip_flags, + exit_on_vmfb=True, + pipeline_dir=args.pipeline_dir, + input_mlir=args.input_mlir, + attn_spec=args.attn_spec, + ) + if args.input_mlir or args.weights_only or args.compile_to == "vmfb": + exit() + safe_name = utils.create_safe_name( + args.hf_model_name, f"_{str(args.max_length)}_{args.precision}_text_encoders" + ) + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders_runner.py b/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders_runner.py new file mode 100644 index 000000000..ec54227ab --- /dev/null +++ b/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders_runner.py @@ -0,0 +1,119 @@ +from turbine_models.model_runner import vmfbRunner +from turbine_models.custom_models.sd3_inference.text_encoder_impls import ( + SD3Tokenizer, + T5XXLTokenizer, + SDXLClipGTokenizer, +) +from iree import runtime as ireert +import torch +import numpy as np + + +def run_prompt_encoder( + vmfb_path, + device, + external_weight_path, + input_ids, + uncond_input_ids, +): + prompt_encoder_runner = vmfbRunner(device, vmfb_path, external_weight_path) + # np.save("input0.npy", input_ids[0].numpy()) + # np.save("input1.npy", input_ids[1].numpy()) + # np.save("input2.npy", input_ids[2].numpy()) + # np.save("input3.npy", uncond_input_ids[0].numpy()) + # np.save("input4.npy", uncond_input_ids[1].numpy()) + # np.save("input5.npy", uncond_input_ids[2].numpy()) + prompt_encoder_inputs = [ + ireert.asdevicearray(prompt_encoder_runner.config.device, input_ids[0]), + 
ireert.asdevicearray(prompt_encoder_runner.config.device, input_ids[1]), + ireert.asdevicearray(prompt_encoder_runner.config.device, input_ids[2]), + ireert.asdevicearray(prompt_encoder_runner.config.device, uncond_input_ids[0]), + ireert.asdevicearray(prompt_encoder_runner.config.device, uncond_input_ids[1]), + ireert.asdevicearray(prompt_encoder_runner.config.device, uncond_input_ids[2]), + ] + encoded_outputs = prompt_encoder_runner.ctx.modules.compiled_text_encoder[ + "encode_tokens" + ](*prompt_encoder_inputs) + for i in encoded_outputs: + i = i.to_host() + del prompt_encoder_inputs + return encoded_outputs + + +def run_tokenize( + tokenizer, + prompt, + negative_prompt, +): + prompt_tokens_dict = tokenizer.tokenize_with_weights(prompt) + neg_prompt_tokens_dict = tokenizer.tokenize_with_weights(negative_prompt) + text_input_ids_list = list(prompt_tokens_dict.values()) + uncond_input_ids_list = list(neg_prompt_tokens_dict.values()) + return text_input_ids_list, uncond_input_ids_list + + +if __name__ == "__main__": + from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args + + tokenizer = SD3Tokenizer() + + text_input_ids_list, uncond_input_ids_list = run_tokenize( + tokenizer, + args.prompt, + args.negative_prompt, + ) + turbine_output1, turbine_output2 = run_prompt_encoder( + args.vmfb_path, + args.rt_device, + args.external_weight_path, + text_input_ids_list, + uncond_input_ids_list, + ) + print( + "TURBINE OUTPUT 1:", + turbine_output1.to_host(), + turbine_output1.shape, + turbine_output1.dtype, + ) + + print( + "TURBINE OUTPUT 2:", + turbine_output2.to_host(), + turbine_output2.shape, + turbine_output2.dtype, + ) + + if args.compare_vs_torch: + print("generating torch output: ") + from turbine_models.custom_models.sd_inference import utils + from turbine_models.custom_models.sd3_inference.sd3_text_encoders import ( + TextEncoderModule, + ) + + torch_encoder_model = TextEncoderModule( + args.batch_size, + ) + torch_output1, torch_output2 = 
torch_encoder_model.forward( + *text_input_ids_list, *uncond_input_ids_list + ) + np.save("torch_output1.npy", torch_output1) + np.save("torch_output2.npy", torch_output2) + print( + "TORCH OUTPUT 1:", torch_output1, torch_output1.shape, torch_output1.dtype + ) + + print( + "TORCH OUTPUT 2:", torch_output2, torch_output2.shape, torch_output2.dtype + ) + rtol = 4e-2 + atol = 4e-2 + + np.testing.assert_allclose( + torch_output1, turbine_output1.to_host(), rtol, atol, verbose=True + ) + np.testing.assert_allclose( + torch_output2, turbine_output2.to_host(), rtol, atol, verbose=True + ) + print("Passed!") + # TODO: Figure out why we occasionally segfault without unlinking output variables + turbine_output1, turbine_output2 = (None, None) diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_vae.py b/models/turbine_models/custom_models/sd3_inference/sd3_vae.py new file mode 100644 index 000000000..ff24864a6 --- /dev/null +++ b/models/turbine_models/custom_models/sd3_inference/sd3_vae.py @@ -0,0 +1,200 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import copy +import os +import sys + +from iree import runtime as ireert +from iree.compiler.ir import Context +import numpy as np +from shark_turbine.aot import * +from shark_turbine.dynamo.passes import ( + DEFAULT_DECOMPOSITIONS, +) +from turbine_models.custom_models.sd_inference import utils +import torch +import torch._dynamo as dynamo +from diffusers import AutoencoderKL + + +class VaeModel(torch.nn.Module): + def __init__( + self, + hf_model_name, + ): + super().__init__() + self.vae = AutoencoderKL.from_pretrained( + hf_model_name, + subfolder="vae", + ) + + def decode(self, inp): + inp = (inp / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(inp, return_dict=False)[0] + image = image.float() + image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] + return image + + def encode(self, inp): + image_np = inp / 255.0 + image_np = np.moveaxis(image_np, 2, 0) + batch_images = np.expand_dims(image_np, axis=0).repeat(1, axis=0) + image_torch = torch.from_numpy(batch_images) + image_torch = 2.0 * image_torch - 1.0 + image_torch = image_torch + latent = self.vae.encode(image_torch) + return latent + + +def export_vae_model( + vae_model, + hf_model_name, + batch_size, + height, + width, + precision, + compile_to="torch", + external_weights=None, + external_weight_path=None, + device=None, + target_triple=None, + ireec_flags=None, + decomp_attn=False, + exit_on_vmfb=False, + pipeline_dir=None, + attn_spec=None, + input_mlir=None, + weights_only=False, +): + dtype = torch.float16 if precision == "fp16" else torch.float32 + safe_name = utils.create_safe_name( + hf_model_name, + f"_bs{batch_size}_{height}x{width}_{precision}_vae", + ) + if pipeline_dir: + safe_name = os.path.join(pipeline_dir, safe_name) + + if input_mlir: + vmfb_path = utils.compile_to_vmfb( + input_mlir, + device, + target_triple, + ireec_flags, + safe_name, + mlir_source="file", + 
return_path=not exit_on_vmfb, + attn_spec=attn_spec, + ) + return vmfb_path + + if device == "cpu": + decomp_attn = True + + if dtype == torch.float16: + vae_model = vae_model.half() + mapper = {} + utils.save_external_weights( + mapper, vae_model, external_weights, external_weight_path + ) + if weights_only: + return external_weight_path + + input_image_shape = (height, width, 3) + input_latents_shape = (batch_size, 16, height // 8, width // 8) + encode_args = [ + torch.empty( + input_image_shape, + dtype=torch.float32, + ) + ] + decode_args = [ + torch.empty( + input_latents_shape, + dtype=dtype, + ) + ] + decomp_list = [] + if decomp_attn == True: + decomp_list = [ + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten.scaled_dot_product_attention, + ] + with decompositions.extend_aot_decompositions( + from_current=True, + add_ops=decomp_list, + ): + fxb = FxProgramsBuilder(vae_model) + + # @fxb.export_program(args=(encode_args,)) + # def _encode(module, inputs,): + # return module.encode(*inputs) + + @fxb.export_program(args=(decode_args,)) + def _decode(module, inputs): + return module.decode(*inputs) + + class CompiledVae(CompiledModule): + decode = _decode + + if external_weights: + externalize_module_parameters(vae_model) + + inst = CompiledVae(context=Context(), import_to="IMPORT") + + module_str = str(CompiledModule.get_mlir_module(inst)) + + if compile_to != "vmfb": + return module_str + else: + vmfb_path = utils.compile_to_vmfb( + module_str, + device, + target_triple, + ireec_flags, + safe_name, + return_path=not exit_on_vmfb, + attn_spec=attn_spec, + ) + return vmfb_path + + +if __name__ == "__main__": + from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args + + if args.input_mlir: + vae_model = None + else: + vae_model = VaeModel( + args.hf_model_name, + ) + mod_str = export_vae_model( + vae_model, + args.hf_model_name, + args.batch_size, + 
height=args.height, + width=args.width, + precision=args.precision, + compile_to=args.compile_to, + external_weights=args.external_weights, + external_weight_path=args.external_weight_path, + device=args.device, + target_triple=args.iree_target_triple, + ireec_flags=args.ireec_flags + args.attn_flags + args.vae_flags, + decomp_attn=args.decomp_attn, + attn_spec=args.attn_spec, + input_mlir=args.input_mlir, + ) + if args.input_mlir or (args.compile_to == "vmfb"): + exit() + safe_name = utils.create_safe_name( + args.hf_model_name, + f"_bs{str(args.batch_size)}_{args.height}x{args.width}_{args.precision}_vae", + ) + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_vae_runner.py b/models/turbine_models/custom_models/sd3_inference/sd3_vae_runner.py new file mode 100644 index 000000000..521f90bb9 --- /dev/null +++ b/models/turbine_models/custom_models/sd3_inference/sd3_vae_runner.py @@ -0,0 +1,94 @@ +import argparse +from turbine_models.model_runner import vmfbRunner +from iree import runtime as ireert +import torch + +torch.random.manual_seed(0) + + +def run_vae( + device, + example_input, + vmfb_path, + hf_model_name, + external_weight_path, +): + runner = vmfbRunner(device, vmfb_path, external_weight_path) + inputs = [ireert.asdevicearray(runner.config.device, example_input)] + results = runner.ctx.modules.compiled_vae["decode"](*inputs).to_host() + results = imagearray_from_vae_out(results) + return results + + +def run_torch_vae(hf_model_name, variant, example_input): + from turbine_models.custom_models.sd_inference.vae import SD3VaeModel + + vae_model = SD3VaeModel( + hf_model_name, + ) + + if variant == "decode": + results = vae_model.decode(example_input) + elif variant == "encode": + results = vae_model.encode(example_input) + np_torch_output = results.detach().cpu().numpy() + np_torch_output = imagearray_from_vae_out(np_torch_output) + 
return np_torch_output + + +def imagearray_from_vae_out(image): + if image.ndim == 4: + image = image[0] + image = torch.from_numpy(image).cpu().permute(1, 2, 0).float().numpy() + image = (image * 255).round().astype("uint8") + return image + + +if __name__ == "__main__": + from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args + import numpy as np + from PIL import Image + + dtype = torch.float16 if args.precision == "fp16" else torch.float32 + if args.vae_variant == "decode": + example_input = torch.rand( + args.batch_size, 16, args.height // 8, args.width // 8, dtype=dtype + ) + if args.vae_input_path: + example_input = np.load(args.vae_input_path) + if example_input.shape[0] == 2: + example_input = np.split(example_input, 2)[0] + elif args.vae_variant == "encode": + example_input = torch.rand( + args.batch_size, 3, args.height, args.width, dtype=dtype + ) + print("generating turbine output:") + turbine_results = run_vae( + args.device, + example_input, + args.vmfb_path, + args.hf_model_name, + args.external_weight_path, + ) + print( + "TURBINE OUTPUT:", + turbine_results, + turbine_results.shape, + turbine_results.dtype, + ) + if args.compare_vs_torch: + print("generating torch output: ") + from turbine_models.custom_models.sd_inference import utils + + torch_output = run_torch_vae( + args.hf_model_name, args.vae_variant, torch.tensor(example_input).float() + ) + print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) + if args.vae_input_path: + out_image_torch = Image.fromarray(torch_output) + out_image_torch.save("vae_test_output_torch.png") + out_image_turbine = Image.fromarray(turbine_results) + out_image_turbine.save("vae_test_output_turbine.png") + # Allow a small amount of wiggle room for rounding errors (1) + + np.testing.assert_allclose(turbine_results, torch_output, rtol=1, atol=1) diff --git a/models/turbine_models/custom_models/sd3_inference/text_encoder_impls.py 
b/models/turbine_models/custom_models/sd3_inference/text_encoder_impls.py new file mode 100644 index 000000000..747b60d9b --- /dev/null +++ b/models/turbine_models/custom_models/sd3_inference/text_encoder_impls.py @@ -0,0 +1,790 @@ +### This file contains impls for underlying related models (CLIP, T5, etc) + +import torch, math +from torch import nn +from transformers import CLIPTokenizer, T5TokenizerFast +from shark_turbine import ops + +################################################################################################# +### Core/Utility +################################################################################################# + + +def attention(q, k, v, heads, mask=None): + """Convenience wrapper around a basic attention operation""" + b, _, dim_head = q.shape + # ops.iree.trace_tensor("attention_q", q[0,0,:5]) + # ops.iree.trace_tensor("attention_k", k[0,0,:5]) + # ops.iree.trace_tensor("attention_v", v[0,0,:5]) + dim_head //= heads + q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v)) + out = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False + ) + # ops.iree.trace_tensor("attention_out", out[0,0,:5]) + return out.transpose(1, 2).reshape(b, -1, heads * dim_head) + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + bias=True, + dtype=None, + device=None, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.fc1 = nn.Linear( + in_features, hidden_features, bias=bias, dtype=dtype, device=device + ) + self.act = act_layer + self.fc2 = nn.Linear( + hidden_features, out_features, bias=bias, dtype=dtype, device=device + ) + + def forward(self, x): + x = self.fc1(x) + # ops.iree.trace_tensor("mlpfx", x[0,0,:5]) + x = self.act(x) + 
def load_into(f, model, prefix, device, dtype=None):
    """Apply the weights in a safetensors file to the matching pytorch objects.

    Every key in ``f`` that starts with ``prefix`` (and does not start with
    ``"loss."``) is split on ``"."`` after the prefix and resolved against
    ``model``: list/tuple containers are indexed by integer, everything else
    is looked up with ``getattr``. Keys that cannot be resolved are reported
    and skipped.

    Args:
        f: object exposing ``keys()`` and ``get_tensor(key)``.
        model: root object the dotted key paths are resolved against.
        prefix: key prefix selecting (and stripped from) the keys to load.
        device: device each tensor is moved to before being assigned.
        dtype: optional dtype each tensor is cast to.

    Raises:
        Exception: whatever the tensor fetch/assignment raised, re-raised
            after the failing key has been printed.
    """
    for key in f.keys():
        if not key.startswith(prefix) or key.startswith("loss."):
            continue
        obj = model
        for p in key[len(prefix) :].split("."):
            if isinstance(obj, (list, tuple)):
                # Plain Python containers have no attribute per index, so
                # they must be indexed; a bad index is treated like a
                # missing attribute and the key is skipped.
                try:
                    obj = obj[int(p)]
                except (ValueError, IndexError):
                    obj = None
            else:
                obj = getattr(obj, p, None)
            if obj is None:
                print(
                    f"Skipping key '{key}' in safetensors file as '{p}' does not exist in python model"
                )
                break
        if obj is None:
            continue
        try:
            tensor = f.get_tensor(key).to(device=device)
            if dtype is not None:
                tensor = tensor.to(dtype=dtype)
            obj.requires_grad_(False)
            obj.set_(tensor)
        except Exception as e:
            print(f"Failed to load key '{key}' in safetensors file: {e}")
            raise
class CLIPEmbeddings(torch.nn.Module):
    """Sum of token embeddings and learned position embeddings.

    Args:
        embed_dim: width of both embedding tables.
        vocab_size: number of rows in the token table.
        num_positions: number of rows in the position table.
        dtype, device: forwarded to both ``torch.nn.Embedding`` tables.
    """

    def __init__(
        self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None
    ):
        super().__init__()
        self.token_embedding = torch.nn.Embedding(
            vocab_size, embed_dim, dtype=dtype, device=device
        )
        self.position_embedding = torch.nn.Embedding(
            num_positions, embed_dim, dtype=dtype, device=device
        )

    def forward(self, input_tokens):
        """Embed ``input_tokens`` (last axis = sequence) and add positions.

        Only the first ``seq_len`` rows of the position table are added, so
        sequences shorter than ``num_positions`` broadcast correctly; for a
        full-length sequence the slice is the whole table.
        """
        seq_len = input_tokens.shape[-1]
        positions = self.position_embedding.weight[:seq_len]
        return self.token_embedding(input_tokens) + positions
class CLIPTextModel(torch.nn.Module):
    """CLIP text transformer followed by a bias-free text projection.

    ``config_dict`` must provide ``"num_hidden_layers"`` and ``"hidden_size"``
    (plus whatever ``CLIPTextModel_`` reads from it).
    """

    def __init__(self, config_dict, dtype, device):
        super().__init__()
        self.num_layers = config_dict["num_hidden_layers"]
        self.text_model = CLIPTextModel_(config_dict, dtype, device)
        embed_dim = config_dict["hidden_size"]
        self.text_projection = nn.Linear(
            embed_dim, embed_dim, bias=False, dtype=dtype, device=device
        )
        # Start the projection as the identity. The in-place copy targets a
        # leaf parameter that requires grad, which autograd rejects unless
        # grad mode is off, so disable it locally instead of relying on the
        # caller to construct the model under torch.no_grad().
        with torch.no_grad():
            self.text_projection.weight.copy_(torch.eye(embed_dim))
        self.dtype = dtype

    def get_input_embeddings(self):
        """Return the token embedding table of the wrapped text model."""
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, embeddings):
        """Replace the token embedding table of the wrapped text model."""
        self.text_model.embeddings.token_embedding = embeddings

    def forward(self, *args, **kwargs):
        """Run the text model and project its pooled output.

        Returns:
            ``(last_hidden, intermediate, projected_pooled, pooled)`` where
            the first, second and fourth items are the text model's outputs
            0, 1 and 2 respectively.
        """
        x = self.text_model(*args, **kwargs)
        out = self.text_projection(x[2])
        return (x[0], x[1], out, x[2])
class SDClipModel(torch.nn.Module):
    """Uses the CLIP transformer encoder for text (from huggingface)

    Wraps ``model_class(textmodel_json_config, dtype, device)`` in eval mode
    and selects which of its outputs is returned by :meth:`forward`.
    """

    LAYERS = ["last", "pooled", "hidden"]

    def __init__(
        self,
        device="cpu",
        max_length=77,
        layer="last",
        layer_idx=None,
        textmodel_json_config=None,
        dtype=None,
        model_class=CLIPTextModel,
        special_tokens={"start": 49406, "end": 49407, "pad": 49407},
        layer_norm_hidden_state=True,
        return_projected_pooled=True,
    ):
        super().__init__()
        assert layer in self.LAYERS
        self.transformer = model_class(textmodel_json_config, dtype, device)
        self.num_layers = self.transformer.num_layers
        self.max_length = max_length
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False
        self.layer = layer
        self.layer_idx = None
        # Copy so instances never share (or mutate) the default dict object.
        self.special_tokens = dict(special_tokens)
        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
        self.layer_norm_hidden_state = layer_norm_hidden_state
        self.return_projected_pooled = return_projected_pooled
        if layer == "hidden":
            assert layer_idx is not None
            assert abs(layer_idx) < self.num_layers
            self.set_clip_options({"layer": layer_idx})
        self.options_default = (
            self.layer,
            self.layer_idx,
            self.return_projected_pooled,
        )

    def encode_token_weights(self, token_weight_pairs):
        # NOTE(review): left as a no-op placeholder; the encoding logic lives
        # in forward(). Confirm no caller expects a result from this method.
        pass

    def set_clip_options(self, options):
        """Select the output layer and pooled-output flavour.

        ``options["layer"]`` picks a hidden layer by index (negative values
        count from the end); a missing or out-of-range index falls back to
        the last layer. ``options["projected_pooled"]`` overrides
        ``return_projected_pooled``.
        """
        layer_idx = options.get("layer", self.layer_idx)
        self.return_projected_pooled = options.get(
            "projected_pooled", self.return_projected_pooled
        )
        # layer_idx == num_layers addresses no layer, so the encoder would
        # return no intermediate and forward() would fail on None.
        if layer_idx is None or not (-self.num_layers <= layer_idx < self.num_layers):
            self.layer = "last"
        else:
            self.layer = "hidden"
            self.layer_idx = layer_idx

    def forward(self, token_weight_pairs):
        """Encode a batch of (token, weight) pairs.

        Args:
            token_weight_pairs: rank-3 tensor; index 0 of the last axis holds
                the token ids, the remaining entries are ignored here.

        Returns:
            ``(hidden, pooled)``: the selected hidden states and the pooled
            output, both restricted to the first batch item, cast to float
            and moved to CPU. ``pooled`` is ``None`` when the transformer
            provides no pooled output.
        """
        tokens = token_weight_pairs[:, :, 0]
        outputs = self.transformer(
            tokens,
            intermediate_output=self.layer_idx,
            final_layer_norm_intermediate=self.layer_norm_hidden_state,
        )
        if self.layer == "last":
            z = outputs[0]
        else:
            z = outputs[1]
        pooled_output = None
        if len(outputs) >= 3:
            if (
                not self.return_projected_pooled
                and len(outputs) >= 4
                and outputs[3] is not None
            ):
                # Unprojected pooled output requested and available.
                pooled_output = outputs[3].float()
            elif outputs[2] is not None:
                pooled_output = outputs[2].float()
        out, pooled = z.float(), pooled_output
        if pooled is not None:
            first_pooled = pooled[0:1].cpu()
        else:
            first_pooled = pooled
        output = [out[0:1]]
        return torch.cat(output, dim=-2).cpu(), first_pooled
class T5LayerNorm(torch.nn.Module):
    """Scale-only normalisation: divide by the root mean square, no mean
    subtraction and no bias, then multiply by a learned per-feature weight."""

    def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None):
        super().__init__()
        initial_gain = torch.ones(hidden_size, dtype=dtype, device=device)
        self.weight = torch.nn.Parameter(initial_gain)
        self.variance_epsilon = eps

    def forward(self, x):
        """Normalise ``x`` over its last axis and apply the learned weight."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normed = x * inv_rms
        # The weight follows the input's device and dtype.
        gain = self.weight.to(device=x.device, dtype=x.dtype)
        return gain * normed
class T5Attention(torch.nn.Module):
    """Bias-free T5 self-attention with optional relative position bias.

    Args:
        model_dim: input/output feature width.
        inner_dim: total width of the q/k/v projections (all heads).
        num_heads: number of attention heads.
        relative_attention_bias: when truthy, this layer owns the bucketed
            relative-position embedding and computes the bias itself;
            otherwise it uses whatever ``past_bias`` it is given.
        dtype, device: forwarded to the linear projections.
    """

    def __init__(
        self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device
    ):
        super().__init__()
        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
        self.num_heads = num_heads
        self.relative_attention_bias = None
        if relative_attention_bias:
            self.relative_attention_num_buckets = 32
            self.relative_attention_max_distance = 128
            self.relative_attention_bias = torch.nn.Embedding(
                self.relative_attention_num_buckets, self.num_heads, device=device
            )

    @staticmethod
    def _relative_position_bucket(
        relative_position, bidirectional=True, num_buckets=32, max_distance=128
    ):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            # Half of the buckets encode positive offsets, half non-positive.
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(
                relative_position, torch.zeros_like(relative_position)
            )
        # now relative_position is in the range [0, inf)
        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1),
        )
        relative_buckets += torch.where(
            is_small, relative_position, relative_position_if_large
        )
        return relative_buckets

    def compute_bias(self, query_length, key_length, device):
        """Compute binned relative position bias"""
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[
            :, None
        ]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[
            None, :
        ]
        relative_position = (
            memory_position - context_position
        )  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=True,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(
            relative_position_bucket
        )  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(
            0
        )  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(self, x, past_bias=None):
        """Self-attend over ``x`` (axis 1 is the sequence).

        Args:
            x: input tensor; ``x.shape[1]`` is used as both query and key
                length when the bias is computed here.
            past_bias: bias from an earlier layer, used as the additive
                attention mask when this layer has no bias table of its own.

        Returns:
            ``(output, bias)`` where ``bias`` is the bias that was applied
            (``None`` if there was none) so later layers can reuse it.
        """
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        if self.relative_attention_bias is not None:
            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
        # Always bind the mask; None simply means "no additive bias".
        mask = past_bias
        # Pre-multiplying k by sqrt(head_dim) cancels the 1/sqrt(head_dim)
        # scaling that scaled_dot_product_attention applies by default.
        out = attention(
            q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask
        )
        return self.o(out), past_bias
class T5Stack(torch.nn.Module):
    """Token embedding, a stack of ``T5Block`` layers and a final norm.

    Only the first block is built with ``relative_attention_bias=True``; the
    bias it returns is threaded through the remaining blocks.
    """

    def __init__(
        self,
        num_layers,
        model_dim,
        inner_dim,
        ff_dim,
        num_heads,
        vocab_size,
        dtype,
        device,
    ):
        super().__init__()
        # NOTE(review): the embedding is created without ``dtype`` unlike the
        # other layers — confirm this is intentional.
        self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device)
        self.block = torch.nn.ModuleList(
            [
                T5Block(
                    model_dim,
                    inner_dim,
                    ff_dim,
                    num_heads,
                    relative_attention_bias=(i == 0),
                    dtype=dtype,
                    device=device,
                )
                for i in range(num_layers)
            ]
        )
        self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)

    def forward(
        self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True
    ):
        """Encode ``input_ids``.

        Args:
            input_ids: token ids fed to the embedding table.
            intermediate_output: index of the block whose output should also
                be returned; negative values count from the end. ``None``
                (or an index matching no block) yields no intermediate.
            final_layer_norm_intermediate: whether the final layer norm is
                also applied to the intermediate output.

        Returns:
            ``(hidden, intermediate)``; ``intermediate`` is ``None`` when no
            block was selected.
        """
        if intermediate_output is not None and intermediate_output < 0:
            # Same convention as CLIPEncoder: -1 is the last block.
            intermediate_output = len(self.block) + intermediate_output
        intermediate = None
        x = self.embed_tokens(input_ids)
        past_bias = None
        for i, block in enumerate(self.block):
            x, past_bias = block(x, past_bias)
            if i == intermediate_output:
                # Clone: later blocks update x in place.
                intermediate = x.clone()
        x = self.final_layer_norm(x)
        if intermediate is not None and final_layer_norm_intermediate:
            intermediate = self.final_layer_norm(intermediate)
        return x, intermediate
b/models/turbine_models/custom_models/sd_inference/clip.py @@ -9,54 +9,52 @@ from iree.compiler.ir import Context from shark_turbine.aot import * +from shark_turbine.transforms.general.add_metadata import AddMetadataPass from turbine_models.custom_models.sd_inference import utils import torch from transformers import CLIPTextModel, CLIPTokenizer, CLIPProcessor from turbine_models.turbine_tank import turbine_tank -import argparse - -parser = argparse.ArgumentParser() -parser.add_argument( - "--hf_auth_token", type=str, help="The Hugging Face auth token, required" -) -parser.add_argument( - "--hf_model_name", - type=str, - help="HF model name", - default="CompVis/stable-diffusion-v1-4", -) -parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb") -parser.add_argument("--external_weight_path", type=str, default="") -parser.add_argument( - "--external_weights", - type=str, - default=None, - help="saves ir/vmfb without global weights for size and readability, options [safetensors]", -) -parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm") -# TODO: Bring in detection for target triple -parser.add_argument( - "--iree_target_triple", - type=str, - default="", - help="Specify vulkan target triple or rocm/cuda target device.", -) -parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") - +@torch.no_grad() def export_clip_model( hf_model_name, - hf_auth_token=None, - compile_to="torch", - external_weights=None, - external_weight_path=None, - device=None, - target_triple=None, - max_alloc=None, - upload_ir=False, + batch_size: int = 1, + max_length: int = 64, + precision: str = "fp16", + compile_to: str = "torch", + external_weights: str = None, + external_weight_path: str = None, + device: str = "llvm-cpu", + target: str = "x86_64-linux-gnu", + ireec_flags: str = None, + exit_on_vmfb: bool = False, + pipeline_dir: str = None, + input_mlir: str = None, + attn_spec: str = None, + weights_only: bool = 
False, + upload_ir: bool = False, + decomp_attn: bool = False, ): - input_len = 77 + input_len = max_length + safe_name = utils.create_safe_name( + hf_model_name, f"_bs{batch_size}_{str(max_length)}-{precision}-clip" + ) + if pipeline_dir not in [None, ""]: + safe_name = os.path.join(pipeline_dir, safe_name) + if input_mlir: + vmfb_path = utils.compile_to_vmfb( + input_mlir, + device, + target, + ireec_flags, + safe_name, + mlir_source="file", + return_path=not exit_on_vmfb, + const_expr_hoisting=True, + attn_spec=attn_spec, + ) + return vmfb_path if "google/t5" in hf_model_name: from transformers import T5Tokenizer, T5Model @@ -75,24 +73,26 @@ def export_clip_model( tokenizer = CLIPTokenizer.from_pretrained( hf_model_name, subfolder="tokenizer", - token=hf_auth_token, ) hf_subfolder = "text_encoder" text_encoder_model = CLIPTextModel.from_pretrained( hf_model_name, subfolder=hf_subfolder, - token=hf_auth_token, ) - + if precision == "fp16": + text_encoder_model = text_encoder_model.half() mapper = {} utils.save_external_weights( mapper, text_encoder_model, external_weights, external_weight_path ) + if weights_only: + return external_weight_path if "google/t5" in hf_model_name: + input_shapes = [(batch_size, input_len), (batch_size, input_len)] - class CompiledClip(CompiledModule): + class CompiledTextEncoder(CompiledModule): if external_weights: params = export_parameters( text_encoder_model, @@ -103,7 +103,7 @@ class CompiledClip(CompiledModule): else: params = export_parameters(text_encoder_model) - def main( + def encode_tokens( self, inp=AbstractTensor(1, input_len, dtype=torch.int64), decoder_input_ids=AbstractTensor(1, input_len, dtype=torch.int64), @@ -113,8 +113,9 @@ def main( ) else: + input_shapes = [str((batch_size, input_len)), str((batch_size, input_len))] - class CompiledClip(CompiledModule): + class CompiledTextEncoder(CompiledModule): if external_weights: params = export_parameters( text_encoder_model, @@ -125,45 +126,84 @@ class 
CompiledClip(CompiledModule): else: params = export_parameters(text_encoder_model) - def main(self, inp=AbstractTensor(1, input_len, dtype=torch.int64)): + def encode_tokens_attn_mask( + self, + inp=AbstractTensor(1, input_len, dtype=torch.int64), + attn_mask=AbstractTensor(1, input_len, dtype=torch.int64), + ): + return jittable(text_encoder_model.forward)( + input_ids=inp, attention_mask=attn_mask + ) + + def encode_tokens( + self, + inp=AbstractTensor(1, input_len, dtype=torch.int64), + ): return jittable(text_encoder_model.forward)(input_ids=inp) import_to = "INPUT" if compile_to == "linalg" else "IMPORT" - inst = CompiledClip(context=Context(), import_to=import_to) - - module_str = str(CompiledModule.get_mlir_module(inst)) - safe_name = utils.create_safe_name(hf_model_name, "-clip") - if upload_ir: - with open(f"{safe_name}.mlir", "w+") as f: - f.write(module_str) - model_name_upload = hf_model_name.replace("/", "_") - model_name_upload += "-clip" - blob_name = turbine_tank.uploadToBlobStorage( - str(os.path.abspath(f"{safe_name}.mlir")), - f"{model_name_upload}/{model_name_upload}.mlir", - ) + inst = CompiledTextEncoder(context=Context(), import_to=import_to) + module = CompiledModule.get_mlir_module(inst) + + model_metadata_attn_mask = { + "model_name": hf_model_name + "_text_encoder", + "input_shapes": input_shapes, + "input_dtypes": ["int64", "int64"], + "use_attention_mask": True, + } + model_metadata_encode = { + "model_name": hf_model_name + "_text_encoder", + "input_shapes": input_shapes[0], + "input_dtypes": ["int64"], + "use_attention_mask": False, + } + module = AddMetadataPass( + module, model_metadata_attn_mask, "encode_tokens_attn_mask" + ).run() + module = AddMetadataPass(module, model_metadata_encode, "encode_tokens").run() + + module_str = str(module) if compile_to != "vmfb": - return module_str, tokenizer + return module_str else: - utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name) - if upload_ir: - return 
blob_name + vmfb_path = utils.compile_to_vmfb( + module_str, + device, + target, + ireec_flags, + safe_name, + return_path=not exit_on_vmfb, + const_expr_hoisting=True, + attn_spec=attn_spec, + ) + return vmfb_path if __name__ == "__main__": - args = parser.parse_args() + from turbine_models.custom_models.sd_inference.sd_cmd_opts import args + mod_str, _ = export_clip_model( args.hf_model_name, - args.hf_auth_token, + args.max_length, + args.precision, args.compile_to, args.external_weights, args.external_weight_path, args.device, args.iree_target_triple, - args.vulkan_max_allocation, + args.ireec_flags + args.clip_flags, + exit_on_vmfb=True, + pipeline_dir=args.pipeline_dir, + input_mlir=args.input_mlir, + attn_spec=args.attn_spec, + weights_only=False, + upload_ir=False, + ) + if args.input_mlir: + exit() + safe_name = utils.create_safe_name( + args.hf_model_name, f"{str(args.max_length)}_{args.precision}_clip" ) - safe_name = args.hf_model_name.split("/")[-1].strip() - safe_name = re.sub("-", "_", safe_name) with open(f"{safe_name}.mlir", "w+") as f: f.write(mod_str) print("Saved to", safe_name + ".mlir") diff --git a/models/turbine_models/custom_models/sd_inference/clip_runner.py b/models/turbine_models/custom_models/sd_inference/clip_runner.py index a4cf677cb..da0908fad 100644 --- a/models/turbine_models/custom_models/sd_inference/clip_runner.py +++ b/models/turbine_models/custom_models/sd_inference/clip_runner.py @@ -4,48 +4,6 @@ from iree import runtime as ireert import torch -parser = argparse.ArgumentParser() - -# TODO move common runner flags to generic flag file -parser.add_argument( - "--vmfb_path", type=str, default="", help="path to vmfb containing compiled module" -) -parser.add_argument( - "--external_weight_path", - type=str, - default="", - help="path to external weight parameters if model compiled without them", -) -parser.add_argument( - "--compare_vs_torch", - action="store_true", - help="Runs both turbine vmfb and a torch model to compare 
results", -) -parser.add_argument( - "--hf_model_name", - type=str, - help="HF model name", - default="CompVis/stable-diffusion-v1-4", -) -parser.add_argument( - "--hf_auth_token", - type=str, - help="The Hugging face auth token, required for some models", -) -parser.add_argument( - "--device", - type=str, - default="local-task", - help="local-sync, local-task, cuda, vulkan, rocm", -) - -parser.add_argument( - "--prompt", - type=str, - default="a photograph of an astronaut riding a horse", - help="prompt for clip model", -) - def run_clip( device, prompt, vmfb_path, hf_model_name, hf_auth_token, external_weight_path @@ -97,7 +55,7 @@ def run_clip( if "google/t5" in hf_model_name: inp += [ireert.asdevicearray(runner.config.device, example_input)] - results = runner.ctx.modules.compiled_clip["main"](*inp) + results = runner.ctx.modules.compiled_text_encoder["encode_tokens"](*inp) return results @@ -168,7 +126,8 @@ def run_torch_clip(hf_model_name, hf_auth_token, prompt): if __name__ == "__main__": - args = parser.parse_args() + from turbine_models.custom_models.sd_inference.sd_cmd_opts import args + turbine_output = run_clip( args.device, args.prompt, diff --git a/models/turbine_models/custom_models/sdxl_inference/default_mfma_attn_spec.mlir b/models/turbine_models/custom_models/sd_inference/default_mfma_attn_spec.mlir similarity index 54% rename from models/turbine_models/custom_models/sdxl_inference/default_mfma_attn_spec.mlir rename to models/turbine_models/custom_models/sd_inference/default_mfma_attn_spec.mlir index 794c83d99..e7e1d8bf5 100644 --- a/models/turbine_models/custom_models/sdxl_inference/default_mfma_attn_spec.mlir +++ b/models/turbine_models/custom_models/sd_inference/default_mfma_attn_spec.mlir @@ -5,8 +5,8 @@ // TODO: Figure out how to parameterize the tile sizes without duplicating // the attention function. 
-#layout_16 = #iree_gpu.mfma_layout -#layout = #iree_gpu.mfma_layout +#layout_16 = #iree_gpu.mma_layout +#layout = #iree_gpu.mma_layout module attributes { transform.with_named_sequence } { //===----------------------------------------------------------------------===// @@ -27,7 +27,7 @@ module attributes { transform.with_named_sequence } { } // Script for FA2 transform pipeline when head_dim % 64 = 0. - transform.named_sequence @__attention_main(%variant_op: !transform.any_op {transform.consumed}) { + transform.named_sequence @__attention_main(%variant_op: !transform.any_op {transform.readonly}) { // Get attention op // ========================================== %attention = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op @@ -42,7 +42,7 @@ module attributes { transform.with_named_sequence } { // Tile batch dimensions of attention // ========================================== %attention2 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op - %batch_tiled_attn, %loop = transform.structured.tile_using_for %attention2 [1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %batch_tiled_attn, %loop = transform.structured.tile_using_for %attention2 tile_sizes [1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op transform.apply_patterns to %top_level_func { transform.apply_patterns.canonicalization @@ -150,32 +150,30 @@ module attributes { transform.with_named_sequence } { transform.apply_patterns.scf.for_loop_canonicalization } : !transform.any_op transform.apply_cse to %func_3 : !transform.any_op - transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> () + transform.iree.eliminate_empty_tensors %func_3 : (!transform.any_op) -> () transform.apply_patterns to 
%func_3 { transform.apply_patterns.linalg.erase_unnecessary_inputs } : !transform.any_op - %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op : (!transform.any_op) -> (!transform.any_op) + %memref_func = transform.iree.bufferize { target_gpu } %func_3 : (!transform.any_op) -> (!transform.any_op) // Step 5. Pre-process the contract and transfer ops to put it in the right form. // =========================================================================== - %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func_2 { - transform.apply_patterns.iree.fold_arith_ext_into_contraction + transform.apply_patterns to %memref_func { + transform.apply_patterns.vector.fold_arith_extension } : !transform.any_op // Step 6. Post-bufferization vector distribution // =========================================================================== - %func_7 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op - transform.iree.forall_to_workgroup %func_7 : (!transform.any_op) -> () - transform.iree.map_nested_forall_to_gpu_threads %func_7 workgroup_dims = [64, 4, 1] subgroup_size = 64 : (!transform.any_op) -> () + transform.iree.forall_to_workgroup %memref_func : (!transform.any_op) -> () + transform.iree.map_nested_forall_to_gpu_threads %memref_func workgroup_dims = [64, 4, 1] subgroup_size = 64 : (!transform.any_op) -> () - transform.apply_patterns to %func_7 { + transform.apply_patterns to %memref_func { transform.apply_patterns.memref.fold_memref_alias_ops } : !transform.any_op - transform.iree.apply_licm %func_7 : !transform.any_op - transform.apply_patterns to %func_7 { + transform.iree.apply_licm %memref_func : !transform.any_op + transform.apply_patterns to %memref_func { transform.apply_patterns.canonicalization } : !transform.any_op - transform.apply_cse to %func_7 : !transform.any_op - %func_8 = 
transform.structured.hoist_redundant_vector_transfers %func_7 + transform.apply_cse to %memref_func : !transform.any_op + %func_8 = transform.structured.hoist_redundant_vector_transfers %memref_func : (!transform.any_op) -> !transform.any_op transform.apply_patterns to %func_8 { transform.apply_patterns.canonicalization @@ -187,17 +185,15 @@ module attributes { transform.with_named_sequence } { transform.apply_registered_pass "iree-amdgpu-prepare-chained-matmul" to %func_8 : (!transform.any_op) -> (!transform.any_op) // Get the vector.contract ops. - %contracts = transform.structured.match ops{["vector.contract"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op + %contracts = transform.structured.match ops{["vector.contract"]} in %variant_op : (!transform.any_op) -> !transform.any_op %contract1, %contract2 = transform.split_handle %contracts : (!transform.any_op) -> (!transform.any_op, !transform.any_op) %layout16x16x16 = transform.param.constant #layout -> !transform.any_param transform.iree.set_contraction_layout_attributes %contract1, %layout16x16x16 { read_layout_indices = array } : !transform.any_op, !transform.any_param transform.iree.set_contraction_layout_attributes %contract2, %layout16x16x16 : !transform.any_op, !transform.any_param - %distribute_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op - transform.iree.amdgpu_distribute_vectors %distribute_func : !transform.any_op - - %distribute_func_2 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op + %distribute_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op + %distribute_func_2 = transform.iree.amdgpu_distribute_vectors %distribute_func : (!transform.any_op) -> !transform.any_op transform.apply_patterns to %distribute_func_2 { transform.apply_patterns.canonicalization @@ -206,34 +202,32 @@ module attributes { 
transform.with_named_sequence } { // Distribute shared memory copies // ========================================== - %func_10 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op - transform.iree.gpu_distribute_shared_memory_copy %func_10 : (!transform.any_op) -> () - transform.apply_patterns to %func_10 { + transform.iree.gpu_distribute_shared_memory_copy %distribute_func_2 : (!transform.any_op) -> () + transform.apply_patterns to %distribute_func_2 { transform.apply_patterns.memref.fold_memref_alias_ops transform.apply_patterns.canonicalization transform.apply_patterns.linalg.tiling_canonicalization } : !transform.any_op - transform.apply_cse to %func_10 : !transform.any_op + transform.apply_cse to %distribute_func_2 : !transform.any_op - %forop = transform.structured.match ops{["scf.for"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op + %forop = transform.structured.match ops{["scf.for"]} in %distribute_func_2 : (!transform.any_op) -> !transform.any_op %prefetched_forop = transform.iree.prefetch_shared_memory_copies %forop : (!transform.any_op) -> (!transform.any_op) - transform.apply_patterns to %func_10 { + transform.apply_patterns to %distribute_func_2 { transform.apply_patterns.memref.fold_memref_alias_ops transform.apply_patterns.canonicalization transform.apply_patterns.linalg.tiling_canonicalization } : !transform.any_op - transform.apply_cse to %func_10 : !transform.any_op + transform.apply_cse to %distribute_func_2 : !transform.any_op - %func_11 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op - transform.amdgpu.optimize_shared_memory_reads_and_writes %func_11 : (!transform.any_op) -> () + transform.iree.reduce_shared_memory_bank_conflicts %distribute_func_2 : (!transform.any_op) -> () transform.yield } // Script for FA2 transform pipeline for head_dim = 512. 
// For head_dim = 512, since the matmul is so big, and just try to do a single wave big load + big mfma. - transform.named_sequence @__attention_main_len_512(%variant_op: !transform.any_op {transform.consumed}) { + transform.named_sequence @__attention_main_len_512(%variant_op: !transform.any_op {transform.readonly}) { // Get attention op // ========================================== %attention = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op @@ -248,7 +242,7 @@ module attributes { transform.with_named_sequence } { // Tile batch dimensions of attention // ========================================== %attention2 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op - %batch_tiled_attn, %loop = transform.structured.tile_using_for %attention2 [1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %batch_tiled_attn, %loop = transform.structured.tile_using_for %attention2 tile_sizes [1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op transform.apply_patterns to %top_level_func { transform.apply_patterns.canonicalization @@ -356,32 +350,30 @@ module attributes { transform.with_named_sequence } { transform.apply_patterns.scf.for_loop_canonicalization } : !transform.any_op transform.apply_cse to %func_3 : !transform.any_op - transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> () + transform.iree.eliminate_empty_tensors %func_3 : (!transform.any_op) -> () transform.apply_patterns to %func_3 { transform.apply_patterns.linalg.erase_unnecessary_inputs } : !transform.any_op - %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op : (!transform.any_op) -> (!transform.any_op) + %memref_func = transform.iree.bufferize { target_gpu } %func_3 : 
(!transform.any_op) -> (!transform.any_op) // Step 5. Pre-process the contract and transfer ops to put it in the right form. // =========================================================================== - %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func_2 { - transform.apply_patterns.iree.fold_arith_ext_into_contraction + transform.apply_patterns to %memref_func { + transform.apply_patterns.vector.fold_arith_extension } : !transform.any_op // Step 6. Post-bufferization vector distribution // =========================================================================== - %func_7 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op - transform.iree.forall_to_workgroup %func_7 : (!transform.any_op) -> () - transform.iree.map_nested_forall_to_gpu_threads %func_7 workgroup_dims = [64, 4, 1] subgroup_size = 64 : (!transform.any_op) -> () + transform.iree.forall_to_workgroup %memref_func : (!transform.any_op) -> () + transform.iree.map_nested_forall_to_gpu_threads %memref_func workgroup_dims = [64, 4, 1] subgroup_size = 64 : (!transform.any_op) -> () - transform.apply_patterns to %func_7 { + transform.apply_patterns to %memref_func { transform.apply_patterns.memref.fold_memref_alias_ops } : !transform.any_op - transform.iree.apply_licm %func_7 : !transform.any_op - transform.apply_patterns to %func_7 { + transform.iree.apply_licm %memref_func : !transform.any_op + transform.apply_patterns to %memref_func { transform.apply_patterns.canonicalization } : !transform.any_op - transform.apply_cse to %func_7 : !transform.any_op - %func_8 = transform.structured.hoist_redundant_vector_transfers %func_7 + transform.apply_cse to %memref_func : !transform.any_op + %func_8 = transform.structured.hoist_redundant_vector_transfers %memref_func : (!transform.any_op) -> !transform.any_op transform.apply_patterns to %func_8 { 
transform.apply_patterns.canonicalization @@ -392,20 +384,18 @@ module attributes { transform.with_named_sequence } { // Apply chained matmul optimization. transform.apply_registered_pass "iree-amdgpu-prepare-chained-matmul" to %func_8 : (!transform.any_op) -> (!transform.any_op) - // transform.print %variant_op_3 : !transform.any_op + // transform.print %memref_func : !transform.any_op // Get the vector.contract ops. - %contracts = transform.structured.match ops{["vector.contract"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op + %contracts = transform.structured.match ops{["vector.contract"]} in %variant_op : (!transform.any_op) -> !transform.any_op %contract1, %contract2 = transform.split_handle %contracts : (!transform.any_op) -> (!transform.any_op, !transform.any_op) %layout16x16x16 = transform.param.constant #layout_16 -> !transform.any_param transform.iree.set_contraction_layout_attributes %contract1, %layout16x16x16 { read_layout_indices = array } : !transform.any_op, !transform.any_param transform.iree.set_contraction_layout_attributes %contract2, %layout16x16x16 : !transform.any_op, !transform.any_param - %distribute_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op - transform.iree.amdgpu_distribute_vectors %distribute_func : !transform.any_op - - %distribute_func_2 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op + %distribute_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op + %distribute_func_2 = transform.iree.amdgpu_distribute_vectors %distribute_func : (!transform.any_op) -> !transform.any_op transform.apply_patterns to %distribute_func_2 { transform.apply_patterns.canonicalization @@ -414,7 +404,7 @@ module attributes { transform.with_named_sequence } { // Distribute shared memory copies // ========================================== - %func_10 = 
transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op + %func_10 = transform.structured.match ops{["func.func"]} in %distribute_func_2 : (!transform.any_op) -> !transform.any_op transform.iree.gpu_distribute_shared_memory_copy %func_10 : (!transform.any_op) -> () transform.apply_patterns to %func_10 { transform.apply_patterns.memref.fold_memref_alias_ops @@ -423,7 +413,7 @@ module attributes { transform.with_named_sequence } { } : !transform.any_op transform.apply_cse to %func_10 : !transform.any_op - %forop = transform.structured.match ops{["scf.for"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op + %forop = transform.structured.match ops{["scf.for"]} in %distribute_func_2 : (!transform.any_op) -> !transform.any_op %prefetched_forop = transform.iree.prefetch_shared_memory_copies %forop : (!transform.any_op) -> (!transform.any_op) transform.apply_patterns to %func_10 { @@ -433,18 +423,17 @@ module attributes { transform.with_named_sequence } { } : !transform.any_op transform.apply_cse to %func_10 : !transform.any_op - %func_11 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op - transform.amdgpu.optimize_shared_memory_reads_and_writes %func_11 : (!transform.any_op) -> () + %func_11 = transform.structured.match ops{["func.func"]} in %distribute_func_2 : (!transform.any_op) -> !transform.any_op + transform.iree.reduce_shared_memory_bank_conflicts %func_11 : (!transform.any_op) -> () transform.yield } // Send it down a custom transform dialect pipeline. 
transform.named_sequence @custom_attention_len_512(%attention: !transform.any_op {transform.readonly}) { - %variant_op = transform.get_parent_op %attention {op_name = "hal.executable.variant"} : (!transform.any_op) -> !transform.any_op - %exports = transform.structured.match ops{["hal.executable.export"]} in %variant_op : (!transform.any_op) -> !transform.any_op + %func = transform.get_parent_op %attention {op_name = "func.func"} : (!transform.any_op) -> !transform.any_op %attn = transform.param.constant #iree_codegen.translation_info -> !transform.any_param - transform.annotate %exports "translation_info" = %attn : !transform.any_op, !transform.any_param + transform.annotate %func "translation_info" = %attn : !transform.any_op, !transform.any_param transform.yield } @@ -457,10 +446,9 @@ module attributes { transform.with_named_sequence } { // Send it down a custom transform dialect pipeline. transform.named_sequence @custom_attention(%attention: !transform.any_op {transform.readonly}) { - %variant_op = transform.get_parent_op %attention {op_name = "hal.executable.variant"} : (!transform.any_op) -> !transform.any_op - %exports = transform.structured.match ops{["hal.executable.export"]} in %variant_op : (!transform.any_op) -> !transform.any_op + %func = transform.get_parent_op %attention {op_name = "func.func"} : (!transform.any_op) -> !transform.any_op %attn = transform.param.constant #iree_codegen.translation_info -> !transform.any_param - transform.annotate %exports "translation_info" = %attn : !transform.any_op, !transform.any_param + transform.annotate %func "translation_info" = %attn : !transform.any_op, !transform.any_param transform.yield } @@ -529,36 +517,14 @@ module attributes { transform.with_named_sequence } { transform.iree.match.cast_compatible_type %lhs = tensor<2048x1280xf16> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<10240x1280xf16> : !transform.any_value %config = transform.param.constant 
#iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = 2, subgroup_n_count = 2, - subgroup_m_tile_count = 4, - subgroup_n_tile_count = 4, - subgroup_k_tile_count = 2>, no_reorder_workgroups}>, - workgroup_size = [128, 2, 1], subgroup_size = 64 - > -> !transform.any_param - transform.yield %matmul, %config : !transform.any_op, !transform.any_param - } - - transform.named_sequence @match_mmt_2048x1280x1280(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) { - %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op - %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value - %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value - transform.iree.match.cast_compatible_type %lhs = tensor<2048x1280xf16> : !transform.any_value - transform.iree.match.cast_compatible_type %rhs = tensor<1280x1280xf16> : !transform.any_value - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, + translation_info = #iree_codegen.translation_info, - subgroup_m_count = 2, subgroup_n_count = 1, - subgroup_m_tile_count = 2, - subgroup_n_tile_count = 5, - subgroup_k_tile_count = 4>}>, - workgroup_size = [64, 2, 1], subgroup_size = 64 - > -> !transform.any_param + intrinsic = #iree_gpu.mma_layout, + subgroup_m_count = 2, subgroup_n_count = 2> + , no_reorder_workgroups}> + > -> !transform.any_param transform.yield %matmul, %config : !transform.any_op, !transform.any_param } @@ -569,56 +535,33 @@ module attributes { transform.with_named_sequence } { transform.iree.match.cast_compatible_type %lhs = tensor<2048x5120xf16> : !transform.any_value transform.iree.match.cast_compatible_type %rhs = tensor<1280x5120xf16> : 
!transform.any_value %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, + translation_info = #iree_codegen.translation_info, - subgroup_m_count = 2, subgroup_n_count = 2, - subgroup_m_tile_count = 2, - subgroup_n_tile_count = 5, - subgroup_k_tile_count = 4>}>, - workgroup_size = [128, 2, 1], subgroup_size = 64 - > -> !transform.any_param + intrinsic = #iree_gpu.mma_layout, + subgroup_m_count = 4, subgroup_n_count = 1> + , no_reorder_workgroups, llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}}> + > -> !transform.any_param transform.yield %matmul, %config : !transform.any_op, !transform.any_param } - transform.named_sequence @match_mmt_128x1280x2048(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) { - %mmt = transform.include @match_mmt_f16_f16_f16 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op - %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value - %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value - transform.iree.match.cast_compatible_type %lhs = tensor<128x2048xf16> : !transform.any_value - transform.iree.match.cast_compatible_type %rhs = tensor<1280x2048xf16> : !transform.any_value - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = 2, subgroup_n_count = 2, - subgroup_m_tile_count = 1, - subgroup_n_tile_count = 1, - subgroup_k_tile_count = 16>}>, - workgroup_size = [128, 2, 1], subgroup_size = 64 - > -> !transform.any_param - transform.yield %matmul, %config : !transform.any_op, !transform.any_param - } - transform.named_sequence @match_mmt_8192x640x2560(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) { + 
transform.named_sequence @match_mmt_2048x1280x1280(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) { %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value - transform.iree.match.cast_compatible_type %lhs = tensor<8192x2560xf16> : !transform.any_value - transform.iree.match.cast_compatible_type %rhs = tensor<640x2560xf16> : !transform.any_value + transform.iree.match.cast_compatible_type %lhs = tensor<2048x1280xf16> : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<1280x1280xf16> : !transform.any_value %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, + translation_info = #iree_codegen.translation_info, - subgroup_m_count = 2, subgroup_n_count = 2, - subgroup_m_tile_count = 2, - subgroup_n_tile_count = 5, - subgroup_k_tile_count = 4>}>, - workgroup_size = [128, 2, 1], subgroup_size = 64 - > -> !transform.any_param + intrinsic = #iree_gpu.mma_layout, + subgroup_m_count = 4, subgroup_n_count = 1> + }> + > -> !transform.any_param transform.yield %matmul, %config : !transform.any_op, !transform.any_param } @@ -630,374 +573,115 @@ module attributes { transform.with_named_sequence } { transform.iree.match.cast_compatible_type %rhs = tensor<5120x640xf16> : !transform.any_value %config = transform.param.constant #iree_codegen.compilation_info< lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = 2, subgroup_n_count = 2, - subgroup_m_tile_count = 4, - subgroup_n_tile_count = 4, - subgroup_k_tile_count = 2>}>, - workgroup_size = [128, 2, 1], subgroup_size = 64 - > -> 
!transform.any_param + intrinsic = #iree_gpu.mma_layout, + subgroup_m_count = 2, subgroup_n_count = 2> + , no_reorder_workgroups}> + > -> !transform.any_param transform.yield %matmul, %config : !transform.any_op, !transform.any_param } - transform.named_sequence @match_mmt_128x640x2048(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) { + transform.named_sequence @match_mmt_8192x640x2560(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) { %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value - transform.iree.match.cast_compatible_type %lhs = tensor<128x2048xf16> : !transform.any_value - transform.iree.match.cast_compatible_type %rhs = tensor<640x2048xf16> : !transform.any_value + transform.iree.match.cast_compatible_type %lhs = tensor<8192x2560xf16> : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<640x2560xf16> : !transform.any_value %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, + translation_info = #iree_codegen.translation_info, - subgroup_m_count = 2, subgroup_n_count = 1, - subgroup_m_tile_count = 1, - subgroup_n_tile_count = 1, - subgroup_k_tile_count = 32>}>, - workgroup_size = [64, 2, 1], subgroup_size = 64 - > -> !transform.any_param + intrinsic = #iree_gpu.mma_layout, + subgroup_m_count = 2, subgroup_n_count = 2> + , llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}}> + > -> !transform.any_param transform.yield %matmul, %config : !transform.any_op, !transform.any_param } -//===----------------------------------------------------------------------===// -// Convolution tuning 
-//===----------------------------------------------------------------------===// - - transform.named_sequence @match_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280(%conv: !transform.any_op {transform.readonly}) - -> (!transform.any_op, !transform.any_param) { - %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv { - ^bb0(%lhs: tensor<2x?x?x1280xf16>, %rhs: tensor<3x3x1280x1280xf16>, %out: tensor<2x32x32x1280xf32>): - %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> } - ins(%lhs, %rhs : tensor<2x?x?x1280xf16>, tensor<3x3x1280x1280xf16>) - outs(%out : tensor<2x32x32x1280xf32>) -> tensor<2x32x32x1280xf32> - } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = 2, subgroup_n_count = 5, - subgroup_m_tile_count = 1, - subgroup_n_tile_count = 1, - subgroup_k_tile_count = 8>}>, - workgroup_size = [640, 1, 1], subgroup_size = 64> -> !transform.any_param - transform.yield %conv, %config : !transform.any_op, !transform.any_param - } - - transform.named_sequence @match_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1920(%conv: !transform.any_op {transform.readonly}) - -> (!transform.any_op, !transform.any_param) { - %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv { - ^bb0(%lhs: tensor<2x?x?x1920xf16>, %rhs: tensor<3x3x1920x1280xf16>, %out: tensor<2x32x32x1280xf32>): - %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> } - ins(%lhs, %rhs : tensor<2x?x?x1920xf16>, tensor<3x3x1920x1280xf16>) - outs(%out : tensor<2x32x32x1280xf32>) -> tensor<2x32x32x1280xf32> - } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = 
#iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = 2, subgroup_n_count = 5, - subgroup_m_tile_count = 1, - subgroup_n_tile_count = 1, - subgroup_k_tile_count = 8>}>, - workgroup_size = [640, 1, 1], subgroup_size = 64> -> !transform.any_param - transform.yield %conv, %config : !transform.any_op, !transform.any_param - } - - transform.named_sequence @match_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x2560(%conv: !transform.any_op {transform.readonly}) - -> (!transform.any_op, !transform.any_param) { - %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv { - ^bb0(%lhs: tensor<2x?x?x2560xf16>, %rhs: tensor<3x3x2560x1280xf16>, %out: tensor<2x32x32x1280xf32>): - %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> } - ins(%lhs, %rhs : tensor<2x?x?x2560xf16>, tensor<3x3x2560x1280xf16>) - outs(%out : tensor<2x32x32x1280xf32>) -> tensor<2x32x32x1280xf32> - } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = 2, subgroup_n_count = 5, - subgroup_m_tile_count = 1, - subgroup_n_tile_count = 1, - subgroup_k_tile_count = 8>}>, - workgroup_size = [640, 1, 1], subgroup_size = 64> -> !transform.any_param - transform.yield %conv, %config : !transform.any_op, !transform.any_param - } - - transform.named_sequence @match_conv_2d_nhwc_hwcf_2x64x64x640x3x3x640(%conv: !transform.any_op {transform.readonly}) - -> (!transform.any_op, !transform.any_param) { - %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv { - ^bb0(%lhs: tensor<2x?x?x640xf16>, %rhs: tensor<3x3x640x640xf16>, %out: tensor<2x64x64x640xf32>): - %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> } - ins(%lhs, %rhs : 
tensor<2x?x?x640xf16>, tensor<3x3x640x640xf16>) - outs(%out : tensor<2x64x64x640xf32>) -> tensor<2x64x64x640xf32> - } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = 2, subgroup_n_count = 5, - subgroup_m_tile_count = 2, - subgroup_n_tile_count = 1, - subgroup_k_tile_count = 4>}>, - workgroup_size = [640, 1, 1], subgroup_size = 64> -> !transform.any_param - transform.yield %conv, %config : !transform.any_op, !transform.any_param - } - - transform.named_sequence @match_conv_2d_nhwc_hwcf_2x64x64x640x3x3x1280(%conv: !transform.any_op {transform.readonly}) - -> (!transform.any_op, !transform.any_param) { - %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv { - ^bb0(%lhs: tensor<2x?x?x1280xf16>, %rhs: tensor<3x3x1280x640xf16>, %out: tensor<2x64x64x640xf32>): - %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> } - ins(%lhs, %rhs : tensor<2x?x?x1280xf16>, tensor<3x3x1280x640xf16>) - outs(%out : tensor<2x64x64x640xf32>) -> tensor<2x64x64x640xf32> - } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = 1, subgroup_n_count = 10, - subgroup_m_tile_count = 2, - subgroup_n_tile_count = 1, - subgroup_k_tile_count = 10>}>, - workgroup_size = [640, 1, 1], subgroup_size = 64> -> !transform.any_param - transform.yield %conv, %config : !transform.any_op, !transform.any_param - } - - transform.named_sequence @match_conv_2d_nhwc_hwcf_2x64x64x640x3x3x1920(%conv: !transform.any_op {transform.readonly}) - -> (!transform.any_op, !transform.any_param) { - %ins, %outs = 
transform.iree.match.cast_compatible_dag_from_root %conv { - ^bb0(%lhs: tensor<2x?x?x1920xf16>, %rhs: tensor<3x3x1920x640xf16>, %out: tensor<2x64x64x640xf32>): - %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> } - ins(%lhs, %rhs : tensor<2x?x?x1920xf16>, tensor<3x3x1920x640xf16>) - outs(%out : tensor<2x64x64x640xf32>) -> tensor<2x64x64x640xf32> - } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = 1, subgroup_n_count = 10, - subgroup_m_tile_count = 2, - subgroup_n_tile_count = 1, - subgroup_k_tile_count = 10>}>, - workgroup_size = [640, 1, 1], subgroup_size = 64> -> !transform.any_param - transform.yield %conv, %config : !transform.any_op, !transform.any_param - } - - transform.named_sequence @match_conv_2d_nhwc_hwcf_2x64x64x1280x3x3x1280(%conv: !transform.any_op {transform.readonly}) - -> (!transform.any_op, !transform.any_param) { - transform.match.operation_name %conv ["linalg.conv_2d_nhwc_hwcf"] : !transform.any_op - %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv { - ^bb0(%lhs: tensor<2x66x66x1280xf16>, %rhs: tensor<3x3x1280x1280xf16>, %out: tensor<2x64x64x1280xf32>): - %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> } - ins(%lhs, %rhs : tensor<2x66x66x1280xf16>, tensor<3x3x1280x1280xf16>) - outs(%out : tensor<2x64x64x1280xf32>) -> tensor<2x64x64x1280xf32> - } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = 1, subgroup_n_count = 10, - subgroup_m_tile_count = 2, - subgroup_n_tile_count = 1, - 
subgroup_k_tile_count = 10>}>, - workgroup_size = [640, 1, 1], subgroup_size = 64> -> !transform.any_param - transform.yield %conv, %config : !transform.any_op, !transform.any_param - } - - transform.named_sequence @match_conv_2d_nhwc_hwcf_2x128x128x320x3x3x320(%conv: !transform.any_op {transform.readonly}) - -> (!transform.any_op, !transform.any_param) { - %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv { - ^bb0(%lhs: tensor<2x?x?x320xf16>, %rhs: tensor<3x3x320x320xf16>, %out: tensor<2x128x128x320xf32>): - %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> } - ins(%lhs, %rhs : tensor<2x?x?x320xf16>, tensor<3x3x320x320xf16>) - outs(%out : tensor<2x128x128x320xf32>) -> tensor<2x128x128x320xf32> - } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = 1, subgroup_n_count = 5, - subgroup_m_tile_count = 4, - subgroup_n_tile_count = 2, - subgroup_k_tile_count = 5>}>, - workgroup_size = [320, 1, 1], subgroup_size = 64> -> !transform.any_param - transform.yield %conv, %config : !transform.any_op, !transform.any_param - } - - transform.named_sequence @match_conv_2d_nhwc_hwcf_2x128x128x320x3x3x640(%conv: !transform.any_op {transform.readonly}) - -> (!transform.any_op, !transform.any_param) { - %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv { - ^bb0(%lhs: tensor<2x?x?x640xf16>, %rhs: tensor<3x3x640x320xf16>, %out: tensor<2x128x128x320xf32>): - %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> } - ins(%lhs, %rhs : tensor<2x?x?x640xf16>, tensor<3x3x640x320xf16>) - outs(%out : tensor<2x128x128x320xf32>) -> tensor<2x128x128x320xf32> - } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) - %config = 
transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = 1, subgroup_n_count = 5, - subgroup_m_tile_count = 4, - subgroup_n_tile_count = 2, - subgroup_k_tile_count = 5>}>, - workgroup_size = [320, 1, 1], subgroup_size = 64> -> !transform.any_param - transform.yield %conv, %config : !transform.any_op, !transform.any_param - } - - transform.named_sequence @match_conv_2d_nhwc_hwcf_2x128x128x320x3x3x960(%conv: !transform.any_op {transform.readonly}) - -> (!transform.any_op, !transform.any_param) { - %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv { - ^bb0(%lhs: tensor<2x?x?x960xf16>, %rhs: tensor<3x3x960x320xf16>, %out: tensor<2x128x128x320xf32>): - %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> } - ins(%lhs, %rhs : tensor<2x?x?x960xf16>, tensor<3x3x960x320xf16>) - outs(%out : tensor<2x128x128x320xf32>) -> tensor<2x128x128x320xf32> - } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = 1, subgroup_n_count = 5, - subgroup_m_tile_count = 4, - subgroup_n_tile_count = 2, - subgroup_k_tile_count = 5>}>, - workgroup_size = [320, 1, 1], subgroup_size = 64> -> !transform.any_param - transform.yield %conv, %config : !transform.any_op, !transform.any_param - } - - transform.named_sequence @match_conv_2d_nhwc_hwcf_2x128x128x640x3x3x640(%conv: !transform.any_op {transform.readonly}) - -> (!transform.any_op, !transform.any_param) { - %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv { - ^bb0(%lhs: tensor<2x?x?x640xf16>, %rhs: tensor<3x3x640x640xf16>, %out: tensor<2x128x128x640xf32>): - %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : 
vector<2xi64>, strides = dense<1> : vector<2xi64> } - ins(%lhs, %rhs : tensor<2x?x?x640xf16>, tensor<3x3x640x640xf16>) - outs(%out : tensor<2x128x128x640xf32>) -> tensor<2x128x128x640xf32> - } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + transform.named_sequence @match_mmt_8192x640x640(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) { + %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op + %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %lhs = tensor<8192x640xf16> : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<640x640xf16> : !transform.any_value %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = 1, subgroup_n_count = 4, - subgroup_m_tile_count = 4, - subgroup_n_tile_count = 1, - subgroup_k_tile_count = 4>}>, - workgroup_size = [256, 1, 1], subgroup_size = 64> -> !transform.any_param - transform.yield %conv, %config : !transform.any_op, !transform.any_param + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = 4, subgroup_n_count = 1> + , no_reorder_workgroups, llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}}> + > -> !transform.any_param + transform.yield %matmul, %config : !transform.any_op, !transform.any_param } //===----------------------------------------------------------------------===// // Contraction tuning //===----------------------------------------------------------------------===// - transform.named_sequence @match_contract_2x1024x1280x20x64(%contract: !transform.any_op {transform.readonly}) + 
transform.named_sequence @match_contract_3x2x20x1024x64x1280(%contract: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) { %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %contract { - ^bb0(%lhs: tensor<2x20x1024x64xf16>, %rhs: tensor<1280x20x64xf16>, %out: tensor<2x1024x1280xf32>): + ^bb0(%lhs: tensor<2x1024x1280xf16>, %rhs: tensor<3x20x64x1280xf16>, %out: tensor<3x2x20x1024x64xf32>): %20 = linalg.generic { - indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d1, d4)>, - affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>, - affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>], - iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"] - } ins(%lhs, %rhs : tensor<2x20x1024x64xf16>, tensor<1280x20x64xf16>) - outs(%out : tensor<2x1024x1280xf32>) { - ^bb0(%in: f16, %in_0: f16, %acc: f32): - %22 = arith.extf %in : f16 to f32 - %23 = arith.extf %in_0 : f16 to f32 - %24 = arith.mulf %22, %23 : f32 - %25 = arith.addf %acc, %24 : f32 - linalg.yield %25 : f32 - } -> tensor<2x1024x1280xf32> - } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = 2, subgroup_n_count = 2, - subgroup_m_tile_count = 2, - subgroup_n_tile_count = 5, - subgroup_k_tile_count = 4>}>, - workgroup_size = [128, 2, 1], subgroup_size = 64> -> !transform.any_param - // transform.print %contract {name = "Contract"} : !transform.any_op - transform.yield %contract, %config : !transform.any_op, !transform.any_param - } - - transform.named_sequence @match_contract_2x2x20x64x64x2048(%contract: !transform.any_op {transform.readonly}) - -> (!transform.any_op, !transform.any_param) { - %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %contract { - ^bb0(%lhs: tensor<2x64x2048xf16>, %rhs: tensor<2x20x64x2048xf16>, 
%out: tensor<2x2x20x64x64xf32>): - %10 = linalg.generic { - indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d3, d5)>, - affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d5)>, - affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4)>], - iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction"] - } ins(%lhs, %rhs : tensor<2x64x2048xf16>, tensor<2x20x64x2048xf16>) - outs(%out : tensor<2x2x20x64x64xf32>) { + indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d3, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction"] + } ins(%lhs, %rhs : tensor<2x1024x1280xf16>, tensor<3x20x64x1280xf16>) + outs(%out : tensor<3x2x20x1024x64xf32>) { ^bb0(%in: f16, %in_0: f16, %acc: f32): - %12 = arith.extf %in : f16 to f32 - %13 = arith.extf %in_0 : f16 to f32 - %14 = arith.mulf %12, %13 : f32 - %15 = arith.addf %acc, %14 : f32 - linalg.yield %15 : f32 - } -> tensor<2x2x20x64x64xf32> + %22 = arith.extf %in : f16 to f32 + %23 = arith.extf %in_0 : f16 to f32 + %24 = arith.mulf %22, %23 : f32 + %25 = arith.addf %acc, %24 : f32 + linalg.yield %25 : f32 + } -> tensor<3x2x20x1024x64xf32> } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = 2, subgroup_n_count = 1, - subgroup_m_tile_count = 1, - subgroup_n_tile_count = 1, - subgroup_k_tile_count = 8>}>, - workgroup_size = [64, 2, 1], subgroup_size = 64> -> !transform.any_param - // transform.print %contract {name = "Contract"} : !transform.any_op + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = 4, subgroup_n_count = 1> + , 
llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}}> + > -> !transform.any_param transform.yield %contract, %config : !transform.any_op, !transform.any_param } - transform.named_sequence @match_contract_3x2x20x64x64x1280(%contract: !transform.any_op {transform.readonly}) + transform.named_sequence @match_contract_2x20x64x64x2048(%contract: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) { %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %contract { - ^bb0(%lhs: tensor<2x1024x1280xf16>, %rhs: tensor<3x20x64x1280xf16>, %out: tensor<3x2x20x1024x64xf32>): - %14 = linalg.generic { - indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d3, d5)>, - affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d5)>, - affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4)>], - iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction"] - } ins(%lhs, %rhs : tensor<2x1024x1280xf16>, tensor<3x20x64x1280xf16>) - outs(%out : tensor<3x2x20x1024x64xf32>) { - ^bb0(%in: f16, %in_0: f16, %acc: f32): - %16 = arith.extf %in : f16 to f32 - %17 = arith.extf %in_0 : f16 to f32 - %18 = arith.mulf %16, %17 : f32 - %19 = arith.addf %acc, %18 : f32 - linalg.yield %19 : f32 - } -> tensor<3x2x20x1024x64xf32> + ^bb0(%lhs: tensor<2x64x2048xf16>, %rhs: tensor<20x64x2048xf16>, %out: tensor<2x20x64x64xf32>): + %14 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"] + } ins(%lhs, %rhs : tensor<2x64x2048xf16>, tensor<20x64x2048xf16>) + outs(%out : tensor<2x20x64x64xf32>) { + ^bb0(%in: f16, %in_0: f16, %acc: f32): + %16 = arith.extf %in : f16 to f32 + %17 = arith.extf %in_0 : f16 to f32 + %18 = arith.mulf %16, %17 : f32 + %19 = arith.addf %acc, %18 : f32 + linalg.yield %19 : f32 + } -> 
tensor<2x20x64x64xf32> } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = 2, subgroup_n_count = 2, - subgroup_m_tile_count = 2, - subgroup_n_tile_count = 1, - subgroup_k_tile_count = 8>}>, - workgroup_size = [128, 2, 1], subgroup_size = 64> -> !transform.any_param - // transform.print %contract {name = "Contract"} : !transform.any_op + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = 2, subgroup_n_count = 2> + }> + > -> !transform.any_param transform.yield %contract, %config : !transform.any_op, !transform.any_param } @@ -1009,31 +693,19 @@ module attributes { transform.with_named_sequence } { transform.foreach_match in %variant_op // Attention. @match_attention_len_512 -> @custom_attention_len_512, - @match_attention -> @custom_attention, - // Matmul tuning. - @match_mmt_2048x10240x1280 -> @apply_op_config, - @match_mmt_2048x1280x1280 -> @apply_op_config, - @match_mmt_2048x1280x5120 -> @apply_op_config, - @match_mmt_128x1280x2048 -> @apply_op_config, - @match_mmt_128x640x2048 -> @apply_op_config, - @match_mmt_8192x640x2560 -> @apply_op_config, - @match_mmt_8192x5120x640 -> @apply_op_config, - // Convolution tuning. 
- @match_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280 -> @apply_op_config, - @match_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1920 -> @apply_op_config, - @match_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x2560 -> @apply_op_config, - @match_conv_2d_nhwc_hwcf_2x64x64x640x3x3x640 -> @apply_op_config, - @match_conv_2d_nhwc_hwcf_2x64x64x640x3x3x1280 -> @apply_op_config, - @match_conv_2d_nhwc_hwcf_2x64x64x640x3x3x1920 -> @apply_op_config, - @match_conv_2d_nhwc_hwcf_2x64x64x1280x3x3x1280 -> @apply_op_config, - @match_conv_2d_nhwc_hwcf_2x128x128x320x3x3x320 -> @apply_op_config, - @match_conv_2d_nhwc_hwcf_2x128x128x320x3x3x640 -> @apply_op_config, - @match_conv_2d_nhwc_hwcf_2x128x128x320x3x3x960 -> @apply_op_config, - @match_conv_2d_nhwc_hwcf_2x128x128x640x3x3x640 -> @apply_op_config, - // Contract tuning. - @match_contract_2x1024x1280x20x64 -> @apply_op_config, - @match_contract_2x2x20x64x64x2048 -> @apply_op_config, - @match_contract_3x2x20x64x64x1280 -> @apply_op_config + @match_attention -> @custom_attention + + // Matmul. + , @match_mmt_2048x10240x1280 -> @apply_op_config + , @match_mmt_2048x1280x5120 -> @apply_op_config + , @match_mmt_2048x1280x1280 -> @apply_op_config + , @match_mmt_8192x5120x640 -> @apply_op_config + , @match_mmt_8192x640x2560 -> @apply_op_config + , @match_mmt_8192x640x640 -> @apply_op_config + + // Contraction. 
+ , @match_contract_3x2x20x1024x64x1280 -> @apply_op_config + , @match_contract_2x20x64x64x2048 -> @apply_op_config : (!transform.any_op) -> (!transform.any_op) transform.yield } diff --git a/models/turbine_models/custom_models/sd_inference/schedulers.py b/models/turbine_models/custom_models/sd_inference/schedulers.py index c7af11bc5..0a6e36cc1 100644 --- a/models/turbine_models/custom_models/sd_inference/schedulers.py +++ b/models/turbine_models/custom_models/sd_inference/schedulers.py @@ -5,188 +5,446 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import os -import sys +from typing import List import torch -from torch.fx.experimental.proxy_tensor import make_fx from shark_turbine.aot import * -from iree import runtime as ireert -import iree.compiler as ireec +import shark_turbine.ops.iree as ops +from shark_turbine.transforms.general.add_metadata import AddMetadataPass from iree.compiler.ir import Context +import iree.runtime as ireert import numpy as np -from turbine_models.custom_models.sd_inference import utils from diffusers import ( - UNet2DConditionModel, + LCMScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + DDPMScheduler, + DPMSolverSDEScheduler, + DDIMScheduler, + DPMSolverMultistepScheduler, + KDPM2DiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DEISMultistepScheduler, + DPMSolverSinglestepScheduler, + KDPM2AncestralDiscreteScheduler, + HeunDiscreteScheduler, ) -import safetensors -import argparse - from turbine_models.turbine_tank import turbine_tank +from turbine_models.custom_models.sd_inference import utils +from turbine_models.model_runner import vmfbRunner -parser = argparse.ArgumentParser() -parser.add_argument( - "--hf_auth_token", type=str, help="The Hugging Face auth token, required" -) -parser.add_argument( - "--hf_model_name", - type=str, - help="HF model name", - default="CompVis/stable-diffusion-v1-4", -) -parser.add_argument( - "--scheduler_id", - type=str, - help="Scheduler ID", - 
default="PNDM", -) -parser.add_argument( - "--num_inference_steps", type=int, default=50, help="Number of inference steps" -) -parser.add_argument( - "--batch_size", type=int, default=1, help="Batch size for inference" -) -parser.add_argument( - "--height", type=int, default=512, help="Height of Stable Diffusion" -) -parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion") -parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb") -parser.add_argument("--external_weight_path", type=str, default="") -parser.add_argument( - "--external_weights", - type=str, - default=None, - help="saves ir/vmfb without global weights for size and readability, options [safetensors]", -) -parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm") -# TODO: Bring in detection for target triple -parser.add_argument( - "--iree_target_triple", - type=str, - default="", - help="Specify vulkan target triple or rocm/cuda target device.", -) -parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") + +class SharkSchedulerWrapper: + def __init__(self, rt_device, vmfb): + self.runner = vmfbRunner(rt_device, vmfb, None) + + def initialize(self, sample): + sample, time_ids, steps, timesteps = self.runner.ctx.modules.compiled_scheduler[ + "run_initialize" + ](sample) + return sample, time_ids, steps.to_host(), timesteps + + def scale_model_input(self, sample, t, timesteps): + return self.runner.ctx.modules.compiled_scheduler["run_scale"]( + sample, t, timesteps + ) + + def step(self, noise_pred, t, sample, step_index): + return self.runner.ctx.modules.compiled_scheduler["run_step"]( + noise_pred, t, sample, guidance_scale, step_index + ) -class Scheduler(torch.nn.Module): - def __init__(self, hf_model_name, num_inference_steps, scheduler): +class SchedulingModel(torch.nn.Module): + def __init__( + self, + hf_model_name, + scheduler, + height, + width, + batch_size, + num_inference_steps, + dtype, + ): 
super().__init__() - self.scheduler = scheduler - self.scheduler.set_timesteps(num_inference_steps) - self.unet = UNet2DConditionModel.from_pretrained( - hf_model_name, - subfolder="unet", + # For now, assumes SDXL implementation. May not need parametrization for other models, + # but keeping hf_model_name in case. + self.model = scheduler + self.height = height + self.width = width + self.is_sd3 = False + if "stable-diffusion-3" in hf_model_name: + self.is_sd3 = True + self.batch_size = batch_size + # Whether this will be used with CFG-enabled pipeline. + self.do_classifier_free_guidance = True + + self.model.set_timesteps(num_inference_steps) + self.timesteps = self.model.timesteps + self.model.is_scale_input_called = True + self.dtype = dtype + + # TODO: Make steps dynamic here + def initialize(self, sample): + height = self.height + width = self.width + original_size = (height, width) + target_size = (height, width) + crops_coords_top_left = (0, 0) + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids], dtype=self.dtype) + if self.do_classifier_free_guidance: + add_time_ids = torch.cat([add_time_ids] * 2, dim=0) + add_time_ids = add_time_ids.repeat(self.batch_size, 1).type(self.dtype) + step_count = torch.tensor(len(self.timesteps)) + timesteps = self.model.timesteps + # ops.trace_tensor("timesteps", self.timesteps) + sample = sample * self.model.init_noise_sigma + return ( + sample.type(self.dtype), + add_time_ids, + step_count, + timesteps.type(torch.float32), + ) + + def prepare_model_input(self, sample, i, timesteps): + t = timesteps[i] + + latent_model_input = sample + return self.model.scale_model_input(latent_model_input, t).type( + self.dtype + ), t.type(self.dtype) + + def step(self, noise_pred, t, sample): + self.model._step_index = self.model.index_for_timestep(t) + + sample = self.model.step(noise_pred, t, sample, return_dict=False)[0] + return sample.type(self.dtype) + + +class 
SharkSchedulerCPUWrapper: + @torch.no_grad() + def __init__( + self, + scheduler, + batch_size, + dest_device, + latents_dtype, + conditional_timesteps=False, + ): + self.module = scheduler + self.dest = dest_device + self.batch_size = batch_size + self.timesteps = None + self.do_classifier_free_guidance = True + self.do_guidance = True + self.repeat_sample = True + + # Enable this on init for models that use a pair of timestep values per unet step. + # this includes sd3 and some others we don't support yet. + # It allows passage of 'uncond_t' to the scale_model_input function and repeats the + # default timestep value if no 'uncond_t' is passed. + self.conditional_timesteps = conditional_timesteps + + self.dtype = latents_dtype + self.torch_dtype = ( + torch.float32 if latents_dtype == "float32" else torch.float16 ) - self.guidance_scale = 7.5 - - def forward(self, latents, encoder_hidden_states) -> torch.FloatTensor: - latents = latents * self.scheduler.init_noise_sigma - for t in self.scheduler.timesteps: - latent_model_input = torch.cat([latents] * 2) - t = t.unsqueeze(0) - latent_model_input = self.scheduler.scale_model_input( - latent_model_input, timestep=t + + def initialize_sdxl(self, sample, num_inference_steps): + if isinstance(sample, ireert.DeviceArray): + sample = torch.tensor(sample.to_host(), dtype=torch.float32) + + self.module.set_timesteps(num_inference_steps) + self.timesteps = self.module.timesteps + height = sample.shape[2] * 8 + width = sample.shape[3] * 8 + original_size = (height, width) + target_size = (height, width) + crops_coords_top_left = (0, 0) + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids], dtype=self.torch_dtype) + if self.do_classifier_free_guidance: + add_time_ids = torch.cat([add_time_ids] * 2, dim=0) + add_time_ids = add_time_ids.repeat(self.batch_size, 1).type( + self.torch_dtype ) - unet_out = self.unet.forward( - latent_model_input, t, 
encoder_hidden_states, return_dict=False - )[0] - noise_pred_uncond, noise_pred_text = unet_out.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * ( + step_indexes = torch.tensor(len(self.timesteps)) + timesteps = self.timesteps + sample = sample * self.module.init_noise_sigma + return sample, add_time_ids, step_indexes, timesteps + + def initialize_sd(self, sample, num_inference_steps): + if isinstance(sample, ireert.DeviceArray): + sample = torch.tensor(sample.to_host(), dtype=torch.float32) + self.module.set_timesteps(num_inference_steps) + timesteps = self.module.timesteps + sample = sample * self.module.init_noise_sigma + return sample, timesteps + + def scale_model_input(self, sample, t, t_uncond=None): + if self.repeat_sample: + sample = torch.cat([sample] * 2) + if self.conditional_timesteps: + if t_uncond: + t = torch.tensor([t, t_uncond]) + else: + t = torch.tensor([t, t]) + else: + t = torch.tensor([t]) + scaled = self.module.scale_model_input(sample, t) + return scaled, t + + def step(self, noise_pred, t, latents, guidance_scale=None): + if isinstance(t, ireert.DeviceArray): + t = torch.tensor(t.to_host()) + if isinstance(noise_pred, ireert.DeviceArray): + noise_pred = torch.tensor(noise_pred.to_host()) + elif isinstance(noise_pred, np.ndarray): + noise_pred = torch.tensor(noise_pred) + if isinstance(guidance_scale, ireert.DeviceArray): + guidance_scale = torch.tensor(guidance_scale.to_host()) + if self.do_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_text - noise_pred_uncond ) - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - return latents - - -def export_scheduler( - scheduler, - hf_model_name, - batch_size, - height, - width, - hf_auth_token=None, - compile_to="torch", - external_weights=None, - external_weight_path=None, - device=None, - target_triple=None, - max_alloc=None, + return self.module.step( + noise_pred, + 
t, + latents, + ).prev_sample + + +@torch.no_grad() +def export_scheduler_model( + hf_model_name: str, + scheduler_id: str, + batch_size: int = 1, + height: int = 512, + width: int = 512, + num_inference_steps: int = 30, + precision: str = "fp16", + compile_to: str = "torch", + device: str = None, + target: str = None, + ireec_flags: str = None, + exit_on_vmfb: bool = False, + pipeline_dir: str = None, + input_mlir: str = None, + attn_spec: str = None, + external_weights: str = None, + external_weight_path: str = None, upload_ir=False, ): - mapper = {} - utils.save_external_weights( - mapper, scheduler, external_weights, external_weight_path + dtype = torch.float16 if precision == "fp16" else torch.float32 + iree_dtype = "float16" if precision == "fp16" else "float32" + scheduler = get_scheduler(hf_model_name, scheduler_id) + scheduler_module = SchedulingModel( + hf_model_name, scheduler, height, width, batch_size, num_inference_steps, dtype ) - encoder_hidden_states_sizes = (2, 77, 768) - if hf_model_name == "stabilityai/stable-diffusion-2-1-base": - encoder_hidden_states_sizes = (2, 77, 1024) + vmfb_names = [ + scheduler_id + "Scheduler", + f"bs{batch_size}", + f"{height}x{width}", + precision, + str(num_inference_steps), + ] + vmfb_name = "_".join(vmfb_names) + safe_name = utils.create_safe_name(hf_model_name, "_" + vmfb_name) + if pipeline_dir: + safe_name = os.path.join(pipeline_dir, safe_name) - sample = (batch_size, 4, height // 8, width // 8) + if input_mlir: + vmfb_path = utils.compile_to_vmfb( + input_mlir, + device, + target, + ireec_flags, + safe_name, + mlir_source="file", + return_path=not exit_on_vmfb, + ) + return vmfb_path - class CompiledScheduler(CompiledModule): - if external_weights: - params = export_parameters( - scheduler, external=True, external_scope="", name_mapper=mapper.get - ) - else: - params = export_parameters(scheduler) + sample = ( + batch_size, + 4, + height // 8, + width // 8, + ) + noise_pred_shape = ( + batch_size, + 4, + 
height // 8, + width // 8, + ) + example_init_args = [torch.empty(sample, dtype=dtype)] + example_prep_args = ( + torch.empty(sample, dtype=dtype), + torch.empty(1, dtype=torch.int64), + torch.empty([19], dtype=torch.float32), + ) + timesteps = torch.export.Dim("timesteps") + prep_dynamic_args = { + "sample": {}, + "t": {}, + "timesteps": {0: timesteps}, + } + example_step_args = [ + torch.empty(noise_pred_shape, dtype=dtype), + torch.empty(1, dtype=dtype), + torch.empty(sample, dtype=dtype), + ] + + fxb = FxProgramsBuilder(scheduler_module) + + @fxb.export_program( + args=(example_init_args,), + ) + def _initialize(module, sample): + return module.initialize(*sample) + + @fxb.export_program( + args=example_prep_args, + dynamic_shapes=prep_dynamic_args, + ) + def _scale(module, sample, t, timesteps): + return module.prepare_model_input(sample, t, timesteps) + + @fxb.export_program( + args=(example_step_args,), + ) + def _step(module, inputs): + return module.step(*inputs) + + decomp_list = [] + # if decomp_attn == True: + # decomp_list.extend( + # [ + # torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + # torch.ops.aten._scaled_dot_product_flash_attention.default, + # ] + # ) + with decompositions.extend_aot_decompositions( + from_current=True, + add_ops=decomp_list, + ): - def main( - self, - sample=AbstractTensor(*sample, dtype=torch.float32), - encoder_hidden_states=AbstractTensor( - *encoder_hidden_states_sizes, dtype=torch.float32 - ), - ): - return jittable(scheduler.forward)(sample, encoder_hidden_states) + class CompiledScheduler(CompiledModule): + run_initialize = _initialize + run_scale = _scale + run_step = _step import_to = "INPUT" if compile_to == "linalg" else "IMPORT" inst = CompiledScheduler(context=Context(), import_to=import_to) - module_str = str(CompiledModule.get_mlir_module(inst)) - safe_name = utils.create_safe_name(hf_model_name, "-scheduler") - if upload_ir: - with open(f"{safe_name}.mlir", "w+") as f: - f.write(module_str) - 
model_name_upload = hf_model_name.replace("/", "-") - model_name_upload = model_name_upload + "_scheduler" - blob_name = turbine_tank.uploadToBlobStorage( - str(os.path.abspath(f"{safe_name}.mlir")), - f"{model_name_upload}/{model_name_upload}.mlir", - ) + module = CompiledModule.get_mlir_module(inst) + metadata_modelname = "_".join( + [hf_model_name, scheduler_id, "scheduler", str(num_inference_steps)] + ) + model_metadata_init = { + "model_name": metadata_modelname, + "input_shapes": [sample], + "input_dtypes": [iree_dtype], + } + model_metadata_prep = { + "model_name": metadata_modelname, + "input_shapes": [sample, (1,), ("?",)], + "input_dtypes": [iree_dtype, "int64", "float32"], + } + model_metadata_step = { + "model_name": metadata_modelname, + "input_shapes": [noise_pred_shape, (1,), sample], + "input_dtypes": [iree_dtype, iree_dtype, iree_dtype], + } + module = AddMetadataPass(module, model_metadata_init, "run_initialize").run() + module = AddMetadataPass(module, model_metadata_prep, "run_scale").run() + module = AddMetadataPass(module, model_metadata_step, "run_step").run() + module_str = str(module) if compile_to != "vmfb": return module_str + elif compile_to == "vmfb": + vmfb = utils.compile_to_vmfb( + module_str, + device, + target, + ireec_flags, + safe_name, + return_path=True, + ) + return vmfb + + +def get_scheduler(model_id, scheduler_id): + # TODO: switch over to turbine and run all on GPU + print(f"\n[LOG] Initializing schedulers from model id: {model_id}") + if scheduler_id in SCHEDULER_MAP.keys(): + scheduler = SCHEDULER_MAP[scheduler_id].from_pretrained( + model_id, subfolder="scheduler" + ) + elif all(x in scheduler_id for x in ["DPMSolverMultistep", "++"]): + scheduler = DPMSolverMultistepScheduler.from_pretrained( + model_id, subfolder="scheduler", algorithm_type="dpmsolver++" + ) else: - utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name) - if upload_ir: - return blob_name + raise ValueError(f"Scheduler 
{scheduler_id} not found.") + if "Karras" in scheduler_id: + scheduler.config.use_karras_sigmas = True + + return scheduler + +SCHEDULER_MAP = { + "PNDM": PNDMScheduler, + "DDPM": DDPMScheduler, + "KDPM2Discrete": KDPM2DiscreteScheduler, + "LMSDiscrete": LMSDiscreteScheduler, + "DDIM": DDIMScheduler, + "LCMScheduler": LCMScheduler, + "EulerDiscrete": EulerDiscreteScheduler, + "EulerAncestralDiscrete": EulerAncestralDiscreteScheduler, + "DEISMultistep": DEISMultistepScheduler, + "DPMSolverSinglestep": DPMSolverSinglestepScheduler, + "KDPM2AncestralDiscrete": KDPM2AncestralDiscreteScheduler, + "HeunDiscrete": HeunDiscreteScheduler, + "DPMSolverMultistepKarras": DPMSolverMultistepScheduler, + "DPMSolverMultistep": DPMSolverMultistepScheduler, + "DPMSolverSDE": DPMSolverSDEScheduler, + "DPMSolverSDEKarras": DPMSolverSDEScheduler, +} if __name__ == "__main__": - args = parser.parse_args() - schedulers = utils.get_schedulers(args.hf_model_name) - scheduler = schedulers[args.scheduler_id] - scheduler_module = Scheduler( - args.hf_model_name, args.num_inference_steps, scheduler - ) - mod_str = export_scheduler( - scheduler_module, + from turbine_models.custom_models.sd_inference.sd_cmd_opts import args + + mod_str = export_scheduler_model( args.hf_model_name, + args.scheduler_id, args.batch_size, args.height, args.width, - args.hf_auth_token, + args.num_inference_steps, + args.precision, args.compile_to, - args.external_weights, - args.external_weight_path, args.device, args.iree_target_triple, - args.vulkan_max_allocation, + args.ireec_flags, + exit_on_vmfb=False, + input_mlir=args.input_mlir, ) - safe_name = utils.create_safe_name(args.hf_model_name, "-scheduler") - with open(f"{safe_name}.mlir", "w+") as f: - f.write(mod_str) - print("Saved to", safe_name + ".mlir") + vmfb_names = [ + args.scheduler_id + "Scheduler", + f"_bs{args.batch_size}_{args.height}x{args.width}", + args.precision, + str(args.num_inference_steps), + args.iree_target_triple, + ] + safe_name = 
"_".join(vmfb_names) + if args.compile_to != "vmfb": + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") diff --git a/models/turbine_models/custom_models/sd_inference/schedulers_runner.py b/models/turbine_models/custom_models/sd_inference/schedulers_runner.py index 45663c0a6..54b9c47f1 100644 --- a/models/turbine_models/custom_models/sd_inference/schedulers_runner.py +++ b/models/turbine_models/custom_models/sd_inference/schedulers_runner.py @@ -4,66 +4,13 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -import argparse from turbine_models.model_runner import vmfbRunner from iree import runtime as ireert import torch from diffusers import ( - PNDMScheduler, UNet2DConditionModel, ) -parser = argparse.ArgumentParser() - -# TODO move common runner flags to generic flag file -parser.add_argument( - "--scheduler_id", - type=str, - help="Scheduler ID", - default="PNDM", -) -parser.add_argument( - "--num_inference_steps", type=int, default=50, help="Number of inference steps" -) -parser.add_argument( - "--vmfb_path", type=str, default="", help="path to vmfb containing compiled module" -) -parser.add_argument( - "--external_weight_path", - type=str, - default="", - help="path to external weight parameters if model compiled without them", -) -parser.add_argument( - "--compare_vs_torch", - action="store_true", - help="Runs both turbine vmfb and a torch model to compare results", -) -parser.add_argument( - "--hf_model_name", - type=str, - help="HF model name", - default="stabilityai/stable-diffusion-xl-base-1.0", -) -parser.add_argument( - "--hf_auth_token", - type=str, - help="The Hugging face auth token, required for some models", -) -parser.add_argument( - "--device", - type=str, - default="local-task", - help="local-sync, local-task, cuda, vulkan, rocm", -) -parser.add_argument( - "--batch_size", type=int, default=1, help="Batch size for inference" 
-) -parser.add_argument( - "--height", type=int, default=1024, help="Height of Stable Diffusion" -) -parser.add_argument("--width", type=int, default=1024, help="Width of Stable Diffusion") - def run_scheduler( device, @@ -197,7 +144,8 @@ def forward(self, sample, prompt_embeds, text_embeds, time_ids): if __name__ == "__main__": - args = parser.parse_args() + from turbine_models.custom_models.sd_inference.sd_cmd_opts import args + sample = torch.rand( args.batch_size, 4, args.height // 8, args.width // 8, dtype=torch.float32 ) diff --git a/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py new file mode 100644 index 000000000..aa5fa4a15 --- /dev/null +++ b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py @@ -0,0 +1,307 @@ +import argparse +import os +from pathlib import Path + + +def path_expand(s): + return Path(s).expanduser().resolve() + + +def is_valid_file(arg): + if not os.path.exists(arg): + return None + else: + return arg + + +# Note: this is where command-line options for the scripts in this directory +# are defined along with their defaults. Thus, they should not be referenced +# within modelling or inference code, only at the entry point to the script. + +# We should consider separating out the options that are "model configs" from +# the options that control the compiler, runtime, and script behavior, +# when applicable, as the former would best be kept in a separate +# config or imported from huggingface. 
+ +p = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter +) + +############################################################################## +# SDXL Huggingface Options +############################################################################## + +p.add_argument( + "--hf_auth_token", + type=str, + help="The Hugging Face auth token, if required", + default=None, +) +p.add_argument( + "--hf_model_name", + type=str, + help="HF model name", + default="stabilityai/stable-diffusion-2-1", +) +p.add_argument( + "--scheduler_id", + type=str, + help="Scheduler ID", + default="EulerDiscrete", +) + +############################################################################## +# SDXL Inference Options +# These options are used to control runtime parameters for SDXL inference. +############################################################################## + +p.add_argument( + "--prompt", + type=str, + default=" a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + help="Prompt input to stable diffusion.", +) + +p.add_argument( + "--negative_prompt", + type=str, + default="Watermark, blurry, oversaturated, low resolution, pollution", + help="Negative prompt input to stable diffusion.", +) + +p.add_argument( + "--num_inference_steps", type=int, default=30, help="Number of UNet inference steps" +) + +p.add_argument( + "--batch_count", + type=int, + default=1, + help="Number of batches to run for a single prompt", +) + +p.add_argument( + "--guidance_scale", + type=float, + default=7.5, + help="Scale by which to adjust prompt guidance to the unconditional noise prediction output of UNet after each iteration.", +) + +p.add_argument( + "--seed", type=float, default=0, help="Seed for random number/latents generation." +) + +p.add_argument( + "--external_weight_path", + type=str, + default="", + help="Path to external weights file, for jobs with one weights filepath. 
When importing, this is used to specify where to save the model weights, and at runtime, this is used to specify where to load the model weights from.", +) + +p.add_argument( + "--external_weights_dir", + type=str, + default="./weights", + help="Directory containing external weights for a job that requires more than one weights file. When importing, this is used to specify where to save the model weights, and at runtime, this is used to specify where to load the model weights from. Files will then be saved according to the parameters that make them unique, i.e. ___.", +) + +p.add_argument( + "--vmfb_path", type=str, default="", help="path to vmfb containing compiled module" +) + +p.add_argument( + "--pipeline_vmfb_path", + type=str, + default="", + help="path to vmfb containing compiled meta-module", +) + +p.add_argument( + "--external_weight_file", + type=str, + default=None, + help="Path to external weights, used in benchmark scripts.", +) + +p.add_argument( + "--pipeline_dir", + type=str, + default="./vmfbs", + help="Directory to save pipeline artifacts", +) + +p.add_argument( + "--compiled_pipeline", + default=False, + action="store_true", + help="Do one-shot inference from tokens to image in a shrink-wrapped pipeline binary.", +) + +p.add_argument( + "--cpu_scheduling", + default=False, + action="store_true", + help="Run scheduling on native pytorch CPU backend.", +) + +############################################################################## +# SDXL Modelling Options +# These options are used to control model defining parameters for SDXL. +# These are MLIR - changing variables! If you change them, you will need +# to import/download and recompile the model. +############################################################################## + +p.add_argument("--batch_size", type=int, default=1, help="Batch size for inference") +p.add_argument( + "--height", type=int, default=512, help="Height of Stable Diffusion output image." 
+) +p.add_argument( + "--width", type=int, default=512, help="Width of Stable Diffusion output image" +) +p.add_argument( + "--precision", + type=str, + default="fp16", + help="Precision of Stable Diffusion weights and graph.", +) +p.add_argument( + "--max_length", type=int, default=64, help="Sequence Length of Stable Diffusion" +) +p.add_argument("--vae_variant", type=str, default="decode", help="encode, decode") +p.add_argument( + "--return_index", + action="store_true", + help="Make scheduled unet compiled module return the step index.", +) + +p.add_argument( + "--vae_decomp_attn", + action="store_true", + help="Decompose attention for VAE decode only at fx graph level", +) + +p.add_argument( + "--unet_decomp_attn", + action="store_true", + help="Decompose attention for unet only at fx graph level", +) + +p.add_argument( + "--use_i8_punet", + action="store_true", + help="Use i8 quantized Partitioned UNet for inference", +) + +############################################################################## +# SDXL script general options. +############################################################################## + +p.add_argument("--compile_to", type=str, default="mlir", help="torch, linalg, vmfb") + +p.add_argument( + "--external_weights", + type=str, + default=None, + choices=["safetensors", "irpa", "gguf", None], + help="Externalizes model weights from the torch dialect IR and its successors", +) + +# See --external_weight_path and --external_weight_dir to specify where to save the model weights. + +p.add_argument( + "--compare_vs_torch", + action="store_true", + help="Runs both turbine vmfb and a torch model to compare results", +) +p.add_argument( + "--decomp_attn", + default=False, + action="store_true", + help="Decompose attention at fx graph level", +) +p.add_argument( + "--exit_on_vmfb", + default=True, + action="store_false", + help="Exit program on vmfb compilation completion. 
Most scripts will also save .mlir if this is disabled.", +) +p.add_argument( + "--input_mlir", + type=str, + default=None, + help="Path to input mlir file to compile. Comma-separate paths to provide more than one input to pipelines.", +) +p.add_argument( + "--download_mlir", + default=False, + action="store_true", + help="Download missing mlir files from Azure storage.", +) +p.add_argument( + "--container_name", + type=str, + default=None, + help="Azure storage container name to download mlir files from.", +) + + +############################################################################## +# IREE Compiler Options +############################################################################## + +p.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm") + +p.add_argument( + "--rt_device", + type=str, + default="local-task", + help="local-task, local-sync, vulkan://0, rocm://0, cuda://0, etc.", +) + +# TODO: Bring in detection for target triple +p.add_argument( + "--iree_target_triple", + type=str, + default="x86_64-linux-gnu", + help="Specify vulkan target triple or rocm/cuda target device.", +) + +p.add_argument("--ireec_flags", type=str, default="", help="extra iree-compile options") + +p.add_argument( + "--attn_flags", + type=str, + default="", + help="extra iree-compile options for models with iree_linalg_ext.attention ops.", +) + +p.add_argument( + "--attn_spec", + type=str, + default=None, + help="extra iree-compile options for models with iree_linalg_ext.attention ops. Set this to 'default' if you are using mfma-capable hardware with ROCM.", +) + +p.add_argument( + "--clip_flags", + type=str, + default="", + help="extra iree-compile options to send for compiling CLIP/prompt_encoder. Only use this for testing bleeding edge flags! Any default options should be added to sd_inference/utils.py", +) + +p.add_argument( + "--vae_flags", + type=str, + default="", + help="extra iree-compile options to send for compiling VAE. 
Only use this for testing bleeding edge flags! Any default options should be added to sd_inference/utils.py", +) + +p.add_argument( + "--unet_flags", + type=str, + default="", + help="extra iree-compile options to send for compiling unet. Only use this for testing bleeding edge flags! Any default options should be added to sd_inference/utils.py", +) + + +args, unknown = p.parse_known_args() diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py new file mode 100644 index 000000000..e1d3ae940 --- /dev/null +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -0,0 +1,748 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import copy +import torch +import iree.runtime as ireert +from random import randint +from tqdm.auto import tqdm +from turbine_models.custom_models.sd_inference import ( + clip, + unet, + vae, + schedulers, + utils, +) +from turbine_models.custom_models.sdxl_inference import ( + sdxl_prompt_encoder as sdxl_clip, + unet as sdxl_unet, +) +from turbine_models.custom_models.sd3_inference import ( + sd3_text_encoders, + sd3_mmdit, +) +from turbine_models.custom_models.pipeline_base import ( + TurbinePipelineBase, + merge_arg_into_map, +) +from turbine_models.custom_models.sd_inference.tokenization import encode_prompt +from turbine_models.model_runner import vmfbRunner +from transformers import CLIPTokenizer +from pathlib import Path + +from PIL import Image +import os +import numpy as np +import time +from datetime import datetime as dt + +# These are arguments common among submodel exports. +# They are expected to be populated in two steps: +# First, by the child class, +# and second by the base class for inference task-agnostic args. 
+ +sd1_sd2_model_map = { + "text_encoder": { + "module_name": "compiled_text_encoder", + "keywords": ["clip"], + "dest_type": "torch", + "export_fn": clip.export_clip_model, + "export_args": { + "batch_size": 1, + "max_length": 64, + }, + }, + "unet": { + "module_name": "compiled_unet", + "keywords": ["unet"], + "export_fn": unet.export_unet_model, + "export_args": { + "batch_size": 1, + "height": 512, + "width": 512, + "max_length": 64, + "decomp_attn": None, + }, + }, + "vae": { + "module_name": "compiled_vae", + "keywords": ["vae"], + "dest_type": "numpy", + "export_fn": vae.export_vae_model, + "export_args": { + "batch_size": 1, + "height": 512, + "width": 512, + "num_channels": 4, + "decomp_attn": None, + }, + }, +} +sdxl_model_map = { + "text_encoder": { + "module_name": "compiled_clip", + "keywords": ["prompt_encoder"], + "dest_type": "torch", + "export_fn": sdxl_clip.export_prompt_encoder, + "export_args": { + "batch_size": 1, + "max_length": 64, + }, + }, + "unet": { + "module_name": "compiled_unet", + "keywords": ["unet", "!loop"], + "export_fn": sdxl_unet.export_unet_model, + "export_args": { + "batch_size": 1, + "height": 1024, + "width": 1024, + "max_length": 64, + "decomp_attn": None, + }, + }, + "vae": { + "module_name": "compiled_vae", + "keywords": ["vae"], + "dest_type": "numpy", + "export_fn": vae.export_vae_model, + "export_args": { + "batch_size": 1, + "height": 1024, + "width": 1024, + "num_channels": 4, + "decomp_attn": None, + }, + }, + "unetloop": { + "module_name": "sdxl_compiled_pipeline", + "load": False, + "keywords": ["unetloop"], + "wraps": ["unet", "scheduler"], + "export_args": { + "batch_size": 1, + "height": 1024, + "width": 1024, + "max_length": 64, + }, + }, + "fullpipeline": { + "module_name": "sdxl_compiled_pipeline", + "load": False, + "keywords": ["fullpipeline"], + "wraps": ["text_encoder", "unet", "scheduler", "vae"], + "export_args": { + "batch_size": 1, + "height": 1024, + "width": 1024, + "max_length": 64, + }, + }, +} 
+sd3_model_map = { + "text_encoder": { + "module_name": "compiled_text_encoder", + "keywords": ["text_encoder"], + "export_fn": sd3_text_encoders.export_text_encoders, + "export_args": { + "batch_size": 1, + "max_length": 64, + }, + }, + "mmdit": { + "module_name": "compiled_mmdit", + "keywords": ["mmdit"], + "export_fn": sd3_mmdit.export_mmdit_model, + "export_args": { + "batch_size": 1, + "height": 1024, + "width": 1024, + "max_length": 64, + "decomp_attn": None, + }, + }, + "vae": { + "module_name": "compiled_vae", + "keywords": ["vae"], + "dest_type": "numpy", + "export_fn": vae.export_vae_model, + "export_args": { + "batch_size": 1, + "height": 1024, + "width": 1024, + "num_channels": 16, + "decomp_attn": None, + }, + }, +} + + +def get_sd_model_map(hf_model_name): + if isinstance(hf_model_name, dict): + name = hf_model_name["text_encoder"] + else: + name = hf_model_name + if name in [ + "stabilityai/sdxl-turbo", + "stabilityai/stable-diffusion-xl-base-1.0", + "/models/SDXL/official_pytorch/fp16/stable_diffusion_fp16//checkpoint_pipe", + ]: + return sdxl_model_map + elif "stabilityai/stable-diffusion-3" in name: + return sd3_model_map + else: + return sd1_sd2_model_map + + +torch_dtypes = { + "fp32": torch.float32, + "fp16": torch.float16, + "float32": torch.float32, + "float16": torch.float16, + "int8": torch.int8, + "i8": torch.int8, +} + + +class SharkSDPipeline(TurbinePipelineBase): + def __init__( + self, + hf_model_name: str | dict[str], + height: int, + width: int, + batch_size: int, + max_length: int | dict[int], + precision: str | dict[str], + device: str | dict[str], + target: str | dict[str], + ireec_flags: str | dict[str] = None, + attn_spec: str | dict[str] = None, + decomp_attn: bool | dict[bool] = False, + pipeline_dir: str = "./shark_vmfbs", + external_weights_dir: str = "./shark_weights", + external_weights: str | dict[str] = "safetensors", + num_inference_steps: int = 30, + cpu_scheduling: bool = True, + scheduler_id: str = None, # 
compatibility only + shift: float = 1.0, # compatibility only + use_i8_punet: bool = False, + ): + common_export_args = { + "hf_model_name": None, + "precision": None, + "compile_to": "vmfb", + "device": None, + "target": None, + "exit_on_vmfb": False, + "pipeline_dir": pipeline_dir, + "input_mlir": None, + "attn_spec": None, + "external_weights": None, + "external_weight_path": None, + } + sd_model_map = get_sd_model_map(hf_model_name) + for submodel in sd_model_map: + if "load" not in sd_model_map[submodel]: + sd_model_map[submodel]["load"] = True + sd_model_map[submodel]["export_args"]["batch_size"] = batch_size + if "max_length" in sd_model_map[submodel]["export_args"]: + max_length_sub = ( + max_length if isinstance(max_length, int) else max_length[submodel] + ) + sd_model_map[submodel]["export_args"]["max_length"] = max_length_sub + if "height" in sd_model_map[submodel]["export_args"]: + sd_model_map[submodel]["export_args"]["height"] = height + sd_model_map[submodel]["export_args"]["width"] = width + if "decomp_attn" in sd_model_map[submodel]["export_args"]: + if isinstance(decomp_attn, bool): + sd_model_map[submodel]["export_args"]["decomp_attn"] = decomp_attn + else: + sd_model_map[submodel]["export_args"]["decomp_attn"] = ( + decomp_attn.get(submodel, False) + ) + super().__init__( + sd_model_map, + device, + target, + ireec_flags, + precision, + attn_spec, + decomp_attn, + external_weights, + pipeline_dir, + external_weights_dir, + hf_model_name, + common_export_args, + ) + for submodel in sd_model_map: + if self.map[submodel].get("external_weights"): + weights_filename = utils.create_safe_name( + self.map[submodel]["export_args"]["hf_model_name"], + f"_{submodel}_{self.map[submodel]['precision']}", + ) + weights_filename += ( + "." 
+ self.map[submodel]["export_args"]["external_weights"] + ) + self.map[submodel]["export_args"][ + "external_weight_path" + ] = weights_filename + + self.batch_size = batch_size + self.model_max_length = max_length + self.height = height + self.width = width + self.cpu_scheduling = cpu_scheduling + self.scheduler_id = scheduler_id + self.num_inference_steps = num_inference_steps + + self.text_encoder = None + self.unet = None + self.mmdit = None + self.vae = None + self.scheduler = None + + self.split_scheduler = True + + self.base_model_name = ( + hf_model_name + if isinstance(hf_model_name, str) + else hf_model_name.get("unet", hf_model_name.get("mmdit")) + ) + self.is_img2img = False + self.is_sdxl = "xl" in self.base_model_name.lower() + self.is_sd3 = "stable-diffusion-3" in self.base_model_name + if self.is_sdxl: + if self.split_scheduler: + if self.map.get("unetloop"): + self.map.pop("unetloop") + if self.map.get("fullpipeline"): + self.map.pop("fullpipeline") + self.tokenizers = [ + CLIPTokenizer.from_pretrained( + self.base_model_name, subfolder="tokenizer" + ), + CLIPTokenizer.from_pretrained( + self.base_model_name, subfolder="tokenizer_2" + ), + ] + self.latents_precision = self.map["unet"]["precision"] + self.scheduler_device = self.map["unet"]["device"] + self.scheduler_driver = self.map["unet"]["driver"] + self.scheduler_target = self.map["unet"]["target"] + elif not self.is_sd3: + self.tokenizer = CLIPTokenizer.from_pretrained( + self.base_model_name, subfolder="tokenizer" + ) + self.latents_precision = self.map["unet"]["precision"] + self.scheduler_device = self.map["unet"]["device"] + self.scheduler_driver = self.map["unet"]["driver"] + self.scheduler_target = self.map["unet"]["target"] + # TODO: Add SD3 init + + self.latents_dtype = torch_dtypes[self.latents_precision] + self.use_i8_punet = self.use_punet = use_i8_punet + if self.use_i8_punet: + self.map["unet"]["export_args"]["precision"] = "i8" + self.map["unet"]["export_args"]["use_punet"] = 
True + self.map["unet"]["use_weights_for_export"] = True + self.map["unet"]["keywords"].append("punet") + self.map["unet"]["module_name"] = "compiled_punet" + self.map["unet"]["function_name"] = "main" + self.map["unet"]["export_args"]["external_weight_path"] = ( + utils.create_safe_name(self.base_model_name) + "_punet_dataset_i8.irpa" + ) + for idx, word in enumerate(self.map["unet"]["keywords"]): + if word in ["fp32", "fp16"]: + self.map["unet"]["keywords"][idx] = "i8" + break + else: + self.map["unet"]["keywords"].append("!punet") + self.map["unet"]["function_name"] = "run_forward" + + # LOAD + + def load_scheduler( + self, + scheduler_id: str, + steps: int = 30, + ): + if self.is_sd3: + scheduler_device = self.mmdit.device + else: + scheduler_device = self.unet.device + if not self.cpu_scheduling: + self.map["scheduler"] = { + "module_name": "compiled_scheduler", + "export_fn": schedulers.export_scheduler_model, + "driver": self.scheduler_driver, + "export_args": { + "hf_model_name": self.base_model_name, + "scheduler_id": scheduler_id, + "batch_size": self.batch_size, + "height": self.height, + "width": self.width, + "num_inference_steps": steps, + "precision": self.latents_precision, + "compile_to": "vmfb", + "device": self.scheduler_device, + "target": self.scheduler_target, + "pipeline_dir": self.pipeline_dir, + }, + } + self.scheduler = None + self.num_inference_steps = steps + self.scheduler_id = scheduler_id + scheduler_uid = "_".join( + [ + f"{scheduler_id}Scheduler", + f"bs{self.batch_size}", + "x".join([str(self.width), str(self.height)]), + self.latents_precision, + str(self.num_inference_steps), + self.scheduler_target, + ] + ) + scheduler_path = os.path.join( + self.pipeline_dir, + utils.create_safe_name(self.base_model_name, scheduler_uid), + ) + if not os.path.exists(scheduler_path): + self.export_submodel("scheduler") + else: + self.map["scheduler"]["vmfb"] = scheduler_path + try: + self.load_submodel("scheduler") + except: + print("JIT export 
of scheduler failed. Loading CPU scheduler.") + self.cpu_scheduling = True + if self.cpu_scheduling: + scheduler = schedulers.get_scheduler(self.base_model_name, scheduler_id) + self.scheduler = schedulers.SharkSchedulerCPUWrapper( + scheduler, + self.batch_size, + scheduler_device, + latents_dtype=self.latents_dtype, + ) + if self.use_punet: + self.scheduler.use_punet = True + + # RUN + + def encode_prompts_sdxl(self, prompt, negative_prompt): + # Tokenize prompt and negative prompt. + text_input_ids_list = [] + uncond_input_ids_list = [] + + for tokenizer in self.tokenizers: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.model_max_length, + truncation=True, + return_tensors="pt", + ) + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=self.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids_list += text_inputs.input_ids.unsqueeze(0) + uncond_input_ids_list += uncond_input.input_ids.unsqueeze(0) + + if self.compiled_pipeline: + return text_input_ids_list, uncond_input_ids_list + else: + prompt_embeds, add_text_embeds = self.text_encoder( + "encode_prompts", [*text_input_ids_list, *uncond_input_ids_list] + ) + return prompt_embeds, add_text_embeds + + def prepare_latents( + self, + noise, + num_inference_steps, + image=None, + strength=None, + ): + if self.is_img2img: + raise NotImplementedError("Image-to-image not supported yet.") + elif self.is_sdxl and self.cpu_scheduling: + self.scheduler.do_guidance = False + self.scheduler.repeat_sample = False + sample, add_time_ids, step_indexes, timesteps = ( + self.scheduler.initialize_sdxl(noise, num_inference_steps) + ) + return sample, add_time_ids, step_indexes, timesteps + elif self.is_sdxl: + return self.scheduler("run_initialize", noise) + elif self.is_sd3: + raise NotImplementedError("Stable Diffusion 3 not supported yet.") + else: + sample, timesteps = self.scheduler.initialize_sd(noise, num_inference_steps) + return 
sample, timesteps + + def get_rand_latents(self, seed, batch_count): + samples = [] + uint32_info = np.iinfo(np.uint32) + uint32_min, uint32_max = uint32_info.min, uint32_info.max + if seed < uint32_min or seed >= uint32_max: + seed = randint(uint32_min, uint32_max) + for i in range(batch_count): + generator = torch.manual_seed(seed + i) + rand_sample = torch.randn( + ( + self.batch_size, + 4, + self.height // 8, + self.width // 8, + ), + generator=generator, + dtype=self.latents_dtype, + ) + samples.append(rand_sample) + return samples + + def _produce_latents_sd( + self, + sample, + prompt_embeds, + negative_prompt_embeds, + steps, + guidance_scale, + ): + image = None + strength = 0 + sample, timesteps = self.prepare_latents( + sample, self.num_inference_steps, image, strength + ) + text_embeddings = torch.cat((negative_prompt_embeds, prompt_embeds), dim=0) + self.scheduler.do_guidance = False + for i, t in tqdm(enumerate(timesteps)): + latent_model_input, _ = self.scheduler.scale_model_input(sample, t) + timestep = torch.tensor([t]) + unet_inputs = [ + latent_model_input, + timestep, + ] + unet_inputs.extend([text_embeddings, [guidance_scale]]) + latents = self.unet(self.map["unet"]["function_name"], unet_inputs) + sample = self.scheduler.step( + torch.tensor( + latents, dtype=torch_dtypes[self.map["unet"]["precision"]] + ), + t, + sample, + ) + return sample + + def _produce_latents_sdxl( + self, + sample, + prompt_embeds, + add_text_embeds, + steps, + guidance_scale, + ): + image = None + strength = 0 + latents, add_time_ids, step_indexes, timesteps = self.prepare_latents( + sample, self.num_inference_steps, image, strength + ) + guidance_scale = ireert.asdevicearray( + self.unet.device, + [guidance_scale], + dtype=self.map["unet"]["np_dtype"], + ) + for i, t in tqdm(enumerate(timesteps)): + if self.cpu_scheduling: + latent_model_input, t = self.scheduler.scale_model_input( + latents, + t, + ) + t = t.type(self.map["unet"]["torch_dtype"]) + else: + step = 
torch.tensor([i], dtype=torch.float32) + latent_model_input, t = self.scheduler( + "run_scale", [latents, step, timesteps] + ) + + unet_inputs = [ + latent_model_input, + t, + prompt_embeds, + add_text_embeds, + add_time_ids, + guidance_scale, + ] + if self.use_punet: + for inp_idx, inp in enumerate(unet_inputs): + if not isinstance(inp, ireert.DeviceArray): + unet_inputs[inp_idx] = ireert.asdevicearray( + self.unet.device, inp, dtype=self.map["unet"]["np_dtype"] + ) + noise_pred = self.unet( + self.map["unet"]["function_name"], + unet_inputs, + ) + if self.cpu_scheduling: + latents = self.scheduler.step( + noise_pred, + t, + latents, + ) + else: + latents = self.scheduler("run_step", [noise_pred, t, latents]) + return latents + + def generate_images( + self, + prompt: str, + negative_prompt: str = "", + steps: int = 30, + batch_count: int = 1, + guidance_scale: float = 7.5, + seed: float = -1, + cpu_scheduling: bool = True, + scheduler_id: str = "EulerDiscrete", + return_imgs: bool = False, + ): + needs_new_scheduler = ( + (steps and steps != self.num_inference_steps) + or (cpu_scheduling != self.cpu_scheduling) + and self.split_scheduler + ) + if not self.scheduler and not self.compiled_pipeline: + needs_new_scheduler = True + + if guidance_scale == 0: + negative_prompt = prompt + prompt = "" + + self.cpu_scheduling = cpu_scheduling + if steps and needs_new_scheduler: + self.num_inference_steps = steps + self.load_scheduler(scheduler_id, steps) + + pipe_start = time.time() + numpy_images = [] + + samples = self.get_rand_latents(seed, batch_count) + + # Tokenize prompt and negative prompt. 
+ if self.is_sdxl: + prompt_embeds, negative_embeds = self.encode_prompts_sdxl( + prompt, negative_prompt + ) + else: + prompt_embeds, negative_embeds = encode_prompt( + self, prompt, negative_prompt + ) + + for i in range(batch_count): + produce_latents_input = [ + samples[i], + prompt_embeds, + negative_embeds, + steps, + guidance_scale, + ] + if self.is_sdxl: + latents = self._produce_latents_sdxl(*produce_latents_input) + else: + latents = self._produce_latents_sd(*produce_latents_input) + image = self.vae("decode", [latents]) + numpy_images.append(image) + pipe_end = time.time() + + logging.info(f"Total inference time: {pipe_end - pipe_start:.2f}s") + timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") + images = [] + for idx, image in enumerate(numpy_images): + image = torch.from_numpy(image).cpu().permute(0, 2, 3, 1).float().numpy() + image = numpy_to_pil_image(image) + images.append(image[0]) + if return_imgs: + return images + for idx, image in enumerate(images): + img_path = "sdxl_output_" + timestamp + "_" + str(idx) + ".png" + image.save(img_path) + print(img_path, "saved") + return + + +def numpy_to_pil_image(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] 
+ images = (images * 255).round().astype("uint8") + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + else: + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + +if __name__ == "__main__": + from turbine_models.custom_models.sd_inference.sd_cmd_opts import args + + ireec_flags = { + "clip": args.ireec_flags + args.clip_flags, + "scheduler": args.ireec_flags, + "unet": args.ireec_flags + args.unet_flags, + "vae_decode": args.ireec_flags + args.vae_flags, + } + if not args.pipeline_dir: + args.pipeline_dir = utils.create_safe_name(args.hf_model_name, "") + if any(x for x in [args.vae_decomp_attn, args.unet_decomp_attn]): + args.decomp_attn = { + "text_encoder": args.decomp_attn, + "unet": ( + args.unet_decomp_attn if args.unet_decomp_attn else args.decomp_attn + ), + "vae": args.vae_decomp_attn if args.vae_decomp_attn else args.decomp_attn, + } + sd_pipe = SharkSDPipeline( + args.hf_model_name, + args.height, + args.width, + args.batch_size, + args.max_length, + args.precision, + args.device, + args.iree_target_triple, + ireec_flags, + args.attn_spec, + args.decomp_attn, + args.pipeline_dir, + args.external_weights_dir, + args.external_weights, + args.num_inference_steps, + args.cpu_scheduling, + args.scheduler_id, + None, + args.use_i8_punet, + ) + sd_pipe.prepare_all() + sd_pipe.load_map() + sd_pipe.generate_images( + args.prompt, + args.negative_prompt, + args.num_inference_steps, + args.batch_count, + args.guidance_scale, + args.seed, + args.cpu_scheduling, + args.scheduler_id, + False, + ) + print("Image generation complete.") diff --git a/models/turbine_models/custom_models/sd_inference/tokenization.py b/models/turbine_models/custom_models/sd_inference/tokenization.py new file mode 100644 index 000000000..e35d37e06 --- /dev/null +++ b/models/turbine_models/custom_models/sd_inference/tokenization.py @@ -0,0 +1,177 @@ 
+from typing import List, Optional, Union +from iree import runtime as ireert +import re +import torch +import numpy as np +import warnings + + +# The following is copied from Diffusers' "encode_prompt" function in the StableDiffusion pipeline. +# It has been lightly augmented to work with the SHARK-Turbine pipeline. +def encode_prompt( + pipe, + prompt, + negative_prompt=None, + num_images_per_prompt=1, + do_classifier_free_guidance=True, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, +): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + # if lora_scale is not None and pipe.use_lora: + # pipe._lora_scale = lora_scale + + # # dynamically adjust the LoRA scale + # if not USE_PEFT_BACKEND: + # adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale) + # else: + # scale_lora_layers(pipe.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + # if pipe.use_textual_inversion: + # prompt = pipe.maybe_convert_prompt(prompt, pipe.tokenizer) + + text_inputs = pipe.tokenizer( + prompt, + padding="max_length", + max_length=pipe.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = pipe.tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = pipe.tokenizer.batch_decode( + untruncated_ids[:, pipe.model_max_length - 1 : -1] + ) + warnings.warn( + "The following text was removed due to truncation: " + removed_text + ) + if pipe.text_encoder.metadata.get("use_attention_mask"): + attention_mask = text_inputs.attention_mask + prompt_embeds = pipe.text_encoder( + "encode_tokens_attn_mask", [text_input_ids, attention_mask] + ) + else: + attention_mask = None + prompt_embeds = pipe.text_encoder("encode_tokens", [text_input_ids]) + prompt_embeds = prompt_embeds[0] + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = 
prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + # if pipe.use_textual_inversion: + # uncond_tokens = pipe.maybe_convert_prompt(uncond_tokens, pipe.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = pipe.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if pipe.text_encoder.metadata.get("use_attention_mask"): + attention_mask = uncond_input.attention_mask + negative_prompt_embeds = pipe.text_encoder( + "encode_tokens_attn_mask", + [ + uncond_input.input_ids, + attention_mask, + ], + ) + else: + attention_mask = None + negative_prompt_embeds = pipe.text_encoder( + "encode_tokens", + [ + uncond_input.input_ids, + ], + ) + + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = 
negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + # if pipe.use_lora: + # Retrieve the original scale by scaling back the LoRA layers + # unimplemented + # unscale_lora_layers(pipe.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds diff --git a/models/turbine_models/custom_models/sd_inference/unet.py b/models/turbine_models/custom_models/sd_inference/unet.py index 18657ae86..dac967b8a 100644 --- a/models/turbine_models/custom_models/sd_inference/unet.py +++ b/models/turbine_models/custom_models/sd_inference/unet.py @@ -6,6 +6,7 @@ import os import sys +import copy from iree import runtime as ireert from iree.compiler.ir import Context @@ -14,6 +15,7 @@ from shark_turbine.dynamo.passes import ( DEFAULT_DECOMPOSITIONS, ) +from shark_turbine.transforms.general.add_metadata import AddMetadataPass from turbine_models.custom_models.sd_inference import utils import torch import torch._dynamo as dynamo @@ -23,174 +25,200 @@ import argparse from turbine_models.turbine_tank import turbine_tank -parser = argparse.ArgumentParser() -parser.add_argument( - "--hf_auth_token", type=str, help="The Hugging Face auth token, required" -) -parser.add_argument( - "--hf_model_name", - type=str, - help="HF model name", - default="CompVis/stable-diffusion-v1-4", -) -parser.add_argument( - "--batch_size", type=int, default=1, help="Batch size for inference" -) -parser.add_argument( - "--height", type=int, default=512, help="Height of Stable Diffusion" -) -parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion") -parser.add_argument( - "--precision", type=str, default="fp16", help="Precision of Stable Diffusion" -) -parser.add_argument( - "--max_length", type=int, default=77, help="Sequence Length of Stable Diffusion" -) -parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb") 
-parser.add_argument("--external_weight_path", type=str, default="") -parser.add_argument( - "--external_weights", - type=str, - default=None, - help="saves ir/vmfb without global weights for size and readability, options [safetensors]", -) -parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm") -# TODO: Bring in detection for target triple -parser.add_argument( - "--iree_target_triple", - type=str, - default="", - help="Specify vulkan target triple or rocm/cuda target device.", -) -parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") - class UnetModel(torch.nn.Module): - def __init__(self, hf_model_name, hf_auth_token=None): + def __init__(self, hf_model_name): super().__init__() + self.do_classifier_free_guidance = True self.unet = UNet2DConditionModel.from_pretrained( hf_model_name, subfolder="unet", ) - def forward(self, sample, timestep, encoder_hidden_states, guidance_scale): - samples = torch.cat([sample] * 2) - unet_out = self.unet.forward( - samples, timestep, encoder_hidden_states, return_dict=False + def forward( + self, latent_model_input, timestep, encoder_hidden_states, guidance_scale + ): + noise_pred = self.unet.forward( + latent_model_input, timestep, encoder_hidden_states, return_dict=False )[0] - noise_pred_uncond, noise_pred_text = unet_out.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) + if self.do_classifier_free_guidance: + noise_preds = noise_pred.chunk(2) + noise_pred = noise_preds[0] + guidance_scale * ( + noise_preds[1] - noise_preds[0] + ) return noise_pred def export_unet_model( - unet_model, hf_model_name, batch_size, height, width, precision="fp32", max_length=77, - hf_auth_token=None, compile_to="torch", external_weights=None, external_weight_path=None, device=None, - target_triple=None, - max_alloc=None, + target=None, + ireec_flags=None, + decomp_attn=False, + exit_on_vmfb=False, + pipeline_dir=None, + 
attn_spec=None, + input_mlir=None, + weights_only=False, upload_ir=False, - decomp_attn=True, ): - mapper = {} - decomp_list = DEFAULT_DECOMPOSITIONS - if decomp_attn: - decomp_list.extend( - [ - torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, - torch.ops.aten._scaled_dot_product_flash_attention.default, - ] + if input_mlir: + unet_model = None + else: + unet_model = UnetModel( + hf_model_name, ) dtype = torch.float16 if precision == "fp16" else torch.float32 - unet_model = unet_model.to(dtype) + np_dtype = "float16" if precision == "fp16" else "float32" + safe_name = utils.create_safe_name( + hf_model_name, + f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_unet", + ) + if decomp_attn: + safe_name += "_decomp_attn" + if pipeline_dir: + safe_name = os.path.join(pipeline_dir, safe_name) + + if input_mlir: + vmfb_path = utils.compile_to_vmfb( + input_mlir, + device, + target, + ireec_flags, + safe_name, + mlir_source="file", + return_path=not exit_on_vmfb, + attn_spec=attn_spec, + ) + return vmfb_path + + mapper = {} + + if precision == "fp16": + unet_model = unet_model.half() + utils.save_external_weights( mapper, unet_model, external_weights, external_weight_path ) + + if weights_only: + return external_weight_path + + sample = ( + batch_size * 2, + unet_model.unet.config.in_channels, + height // 8, + width // 8, + ) encoder_hidden_states_sizes = ( unet_model.unet.config.layers_per_block, max_length, unet_model.unet.config.cross_attention_dim, ) + example_forward_args = [ + torch.empty(sample, dtype=dtype), + torch.empty(1, dtype=dtype), + torch.empty(encoder_hidden_states_sizes, dtype=dtype), + torch.empty(1, dtype=dtype), + ] + decomp_list = [] + if decomp_attn: + decomp_list = [ + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten.scaled_dot_product_attention, + ] + with decompositions.extend_aot_decompositions( + from_current=True, + 
add_ops=decomp_list, + ): + fxb = FxProgramsBuilder(unet_model) + + @fxb.export_program( + args=(example_forward_args,), + ) + def _forward( + module, + inputs, + ): + return module.forward(*inputs) - sample = (batch_size, unet_model.unet.config.in_channels, height // 8, width // 8) + class CompiledUnet(CompiledModule): + run_forward = _forward - class CompiledUnet(CompiledModule): if external_weights: - params = export_parameters( - unet_model, external=True, external_scope="", name_mapper=mapper.get - ) - else: - params = export_parameters(unet_model) - - def main( - self, - sample=AbstractTensor(*sample, dtype=dtype), - timestep=AbstractTensor(1, dtype=dtype), - encoder_hidden_states=AbstractTensor( - *encoder_hidden_states_sizes, dtype=dtype - ), - guidance_scale=AbstractTensor(1, dtype=dtype), - ): - return jittable(unet_model.forward, decompose_ops=decomp_list)( - sample, timestep, encoder_hidden_states, guidance_scale - ) - - import_to = "INPUT" if compile_to == "linalg" else "IMPORT" - inst = CompiledUnet(context=Context(), import_to=import_to) - - module_str = str(CompiledModule.get_mlir_module(inst)) - safe_name = utils.create_safe_name(hf_model_name, "-unet") - if upload_ir: - with open(f"{safe_name}.mlir", "w+") as f: - f.write(module_str) - model_name_upload = hf_model_name.replace("/", "-") - model_name_upload += "_unet" - blob_name = turbine_tank.uploadToBlobStorage( - str(os.path.abspath(f"{safe_name}.mlir")), - f"{model_name_upload}/{model_name_upload}.mlir", - ) + externalize_module_parameters(unet_model) + + inst = CompiledUnet(context=Context(), import_to="IMPORT") + + module = CompiledModule.get_mlir_module(inst) + + model_metadata_run_forward = { + "model_name": "sd_unet", + "input_shapes": [ + sample, + (1,), + encoder_hidden_states_sizes, + (1,), + ], + "input_dtypes": [np_dtype for x in range(4)], + "output_shapes": [sample], + "output_dtypes": [np_dtype], + } + + module = AddMetadataPass(module, model_metadata_run_forward, 
"run_forward").run() + module_str = str(module) if compile_to != "vmfb": return module_str else: - utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name) - if upload_ir: - return blob_name + vmfb_path = utils.compile_to_vmfb( + module_str, + device, + target, + ireec_flags, + safe_name, + return_path=True, + attn_spec=attn_spec, + ) + if exit_on_vmfb: + exit() + return vmfb_path if __name__ == "__main__": - args = parser.parse_args() - unet_model = UnetModel( - args.hf_model_name, - args.hf_auth_token, - ) + from turbine_models.custom_models.sd_inference.sd_cmd_opts import args + mod_str = export_unet_model( - unet_model, args.hf_model_name, args.batch_size, args.height, args.width, args.precision, args.max_length, - args.hf_auth_token, args.compile_to, args.external_weights, args.external_weight_path, args.device, args.iree_target_triple, - args.vulkan_max_allocation, + args.ireec_flags + args.attn_flags + args.unet_flags, + args.decomp_attn, + attn_spec=args.attn_spec, + input_mlir=args.input_mlir, + ) + if args.input_mlir: + exit() + safe_name = utils.create_safe_name( + args.hf_model_name, + f"_bs{args.batch_size}_{args.max_length}_{args.height}x{args.width}_{args.precision}_unet", ) - safe_name = utils.create_safe_name(args.hf_model_name, "-unet") with open(f"{safe_name}.mlir", "w+") as f: f.write(mod_str) print("Saved to", safe_name + ".mlir") diff --git a/models/turbine_models/custom_models/sd_inference/unet_runner.py b/models/turbine_models/custom_models/sd_inference/unet_runner.py index 1b8c5d101..12e420960 100644 --- a/models/turbine_models/custom_models/sd_inference/unet_runner.py +++ b/models/turbine_models/custom_models/sd_inference/unet_runner.py @@ -4,48 +4,6 @@ from iree import runtime as ireert import torch -parser = argparse.ArgumentParser() - -# TODO move common runner flags to generic flag file -parser.add_argument( - "--vmfb_path", type=str, default="", help="path to vmfb containing compiled module" -) 
-parser.add_argument( - "--external_weight_path", - type=str, - default="", - help="path to external weight parameters if model compiled without them", -) -parser.add_argument( - "--compare_vs_torch", - action="store_true", - help="Runs both turbine vmfb and a torch model to compare results", -) -parser.add_argument( - "--hf_model_name", - type=str, - help="HF model name", - default="CompVis/stable-diffusion-v1-4", -) -parser.add_argument( - "--hf_auth_token", - type=str, - help="The Hugging face auth token, required for some models", -) -parser.add_argument( - "--device", - type=str, - default="local-task", - help="local-sync, local-task, cuda, vulkan, rocm", -) -parser.add_argument( - "--batch_size", type=int, default=1, help="Batch size for inference" -) -parser.add_argument( - "--height", type=int, default=512, help="Height of Stable Diffusion" -) -parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion") - def run_unet( device, @@ -57,16 +15,18 @@ def run_unet( hf_model_name, hf_auth_token, external_weight_path, + iree_dtype, ): runner = vmfbRunner(device, vmfb_path, external_weight_path) - inputs = [ - ireert.asdevicearray(runner.config.device, sample), - ireert.asdevicearray(runner.config.device, timestep), - ireert.asdevicearray(runner.config.device, encoder_hidden_states), - ireert.asdevicearray(runner.config.device, guidance_scale), + ireert.asdevicearray(runner.config.device, sample, dtype=iree_dtype), + ireert.asdevicearray(runner.config.device, timestep, dtype=iree_dtype), + ireert.asdevicearray( + runner.config.device, encoder_hidden_states, dtype=iree_dtype + ), + ireert.asdevicearray(runner.config.device, guidance_scale, dtype=iree_dtype), ] - results = runner.ctx.modules.compiled_unet["main"](*inputs) + results = runner.ctx.modules.compiled_unet["run_forward"](*inputs) return results @@ -78,32 +38,10 @@ def run_torch_unet( encoder_hidden_states, guidance_scale, ): - from diffusers import UNet2DConditionModel - - class 
UnetModel(torch.nn.Module): - def __init__(self, hf_model_name, hf_auth_token): - super().__init__() - self.unet = UNet2DConditionModel.from_pretrained( - hf_model_name, - subfolder="unet", - token=hf_auth_token, - ) - self.guidance_scale = 7.5 - - def forward(self, sample, timestep, encoder_hidden_states, guidance_scale): - samples = torch.cat([sample] * 2) - unet_out = self.unet.forward( - samples, timestep, encoder_hidden_states, return_dict=False - )[0] - noise_pred_uncond, noise_pred_text = unet_out.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - return noise_pred + from turbine_models.custom_models.sd_inference.unet import UnetModel unet_model = UnetModel( hf_model_name, - hf_auth_token, ) results = unet_model.forward( sample, timestep, encoder_hidden_states, guidance_scale @@ -114,15 +52,21 @@ def forward(self, sample, timestep, encoder_hidden_states, guidance_scale): if __name__ == "__main__": args = parser.parse_args() + iree_dtypes = { + "fp16": "float16", + "fp32": "float32", + } sample = torch.rand( - args.batch_size, 4, args.height // 8, args.width // 8, dtype=torch.float32 + args.batch_size * 2, 4, args.height // 8, args.width // 8, dtype=torch.float32 ) timestep = torch.zeros(1, dtype=torch.float32) guidance_scale = torch.Tensor([7.5], dtype=torch.float32) if args.hf_model_name == "CompVis/stable-diffusion-v1-4": - encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32) + encoder_hidden_states = torch.rand(2, args.max_length, 768, dtype=torch.float32) elif args.hf_model_name == "stabilityai/stable-diffusion-2-1-base": - encoder_hidden_states = torch.rand(2, 77, 1024, dtype=torch.float32) + encoder_hidden_states = torch.rand( + 2, args.max_length, 1024, dtype=torch.float32 + ) turbine_output = run_unet( args.device, @@ -134,6 +78,7 @@ def forward(self, sample, timestep, encoder_hidden_states, guidance_scale): args.hf_model_name, args.hf_auth_token, args.external_weight_path, + 
iree_dtypes[args.precision], ) print( "TURBINE OUTPUT:", @@ -145,6 +90,7 @@ def forward(self, sample, timestep, encoder_hidden_states, guidance_scale): if args.compare_vs_torch: print("generating torch output: ") from turbine_models.custom_models.sd_inference import utils + from turbine_models.custom_models.sd_inference.sd_cmd_opts import args torch_output = run_torch_unet( args.hf_model_name, diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index a90824dae..b9afc6de8 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -1,49 +1,144 @@ +from urllib.request import urlopen import iree.compiler as ireec import numpy as np import os import safetensors +import safetensors.numpy as safe_numpy import re from diffusers import ( PNDMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, + # DPMSolverSDEScheduler, ) # If flags are verified to work on a specific model and improve performance without regressing numerics, add them to this dictionary. If you are working with bleeding edge flags, please add them manually with the --ireec_flags argument. 
-gfx94X_flags = { +MI_flags = { "all": [ "--iree-global-opt-propagate-transposes=true", + "--iree-opt-const-eval=false", "--iree-opt-outer-dim-concat=true", - "--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode", "--iree-vm-target-truncate-unsupported-floats", "--iree-llvmgpu-enable-prefetch=true", - "--verify=false", - "--iree-rocm-waves-per-eu=2", "--iree-opt-data-tiling=false", - "--iree-codegen-log-swizzle-tile=4", - "--iree-llvmgpu-promote-filter=true", - "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics)", + "--iree-codegen-gpu-native-math-precision=true", + "--iree-rocm-waves-per-eu=2", + "--iree-flow-inline-constants-max-byte-length=1", + ], + "pad_attention": [ + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,128,0,32,0}))", + ], + "preprocess_default": [ + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics))", ], "unet": [ - "--iree-codegen-llvmgpu-use-conv-vector-distribute-pipeline", - "--iree-codegen-llvmgpu-reduce-skinny-matmuls", - "--iree-codegen-gpu-native-math-precision=true", - "--iree-codegen-llvmgpu-use-vector-distribution", - "--iree-codegen-winograd-use-forall", + "--iree-flow-enable-aggressive-fusion", + "--iree-global-opt-enable-fuse-horizontal-contractions=true", + "--iree-opt-aggressively-propagate-transposes=true", + "--iree-codegen-llvmgpu-use-vector-distribution=true", ], "clip": [ - "--iree-codegen-llvmgpu-use-vector-distribution", - "--iree-codegen-llvmgpu-reduce-skinny-matmuls", - "--iree-global-opt-only-sink-transposes=true", + "--iree-flow-enable-aggressive-fusion", + "--iree-global-opt-enable-fuse-horizontal-contractions=true", + 
"--iree-opt-aggressively-propagate-transposes=true", ], "vae": [ - "--iree-codegen-llvmgpu-use-conv-vector-distribute-pipeline", - "--iree-codegen-llvmgpu-use-vector-distribution", - "--iree-global-opt-only-sink-transposes=true", - "--iree-codegen-winograd-use-forall", + "--iree-flow-enable-aggressive-fusion", + "--iree-codegen-llvmgpu-use-vector-distribution=true", + ], + "winograd": [""], +} +GFX11_flags = { + "all": [ + "--iree-global-opt-propagate-transposes=true", + "--iree-opt-outer-dim-concat=true", + "--iree-vm-target-truncate-unsupported-floats", + "--iree-llvmgpu-enable-prefetch=true", + "--iree-opt-data-tiling=false", + "--iree-opt-const-eval=false", + "--iree-opt-aggressively-propagate-transposes=true", + "--iree-flow-enable-aggressive-fusion", + "--iree-global-opt-enable-fuse-horizontal-contractions=true", + "--iree-codegen-gpu-native-math-precision=true", + "--iree-codegen-llvmgpu-use-vector-distribution=true", + "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false", + ], + "pad_attention": [ + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))", + ], + "preprocess_default": [ + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics))", + ], + "unet": [""], + "clip": [""], + "vae": [""], + "winograd": [""], +} +znver4_flags = { + "all": [ + "--iree-llvmcpu-target-cpu=znver4", + "--iree-opt-const-eval=false", + "--iree-llvmcpu-enable-ukernels=mmt4d,pack,unpack", + "--iree-flow-collapse-reduction-dims", + "--iree-opt-const-expr-max-size-increase-threshold=1000000000000000", + "--iree-flow-enable-fuse-padding-into-linalg-consumer-ops", + ], + "bf16": [ + 
"--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-demote-contraction-inputs-to-bf16))", + ], + "winograd": [ + "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-linalg-ext-convert-conv2d-to-winograd{replace-all-convs=true},iree-global-opt-demote-contraction-inputs-to-bf16))" ], } +_IREE_DRIVER_MAP = { + "cpu": "local-task", + "cpu-task": "local-task", + "cpu-sync": "local-sync", + "cuda": "cuda", + "vulkan": "vulkan", + "metal": "metal", + "rocm": "hip", + "rocm-legacy": "rocm", + "hip": "hip", + "intel-gpu": "level_zero", +} + +_IREE_BACKEND_MAP = { + "cpu": "llvm-cpu", + "local-task": "llvm-cpu", + "local-sync": "llvm-cpu", + "rocm": "rocm", + "rocm-legacy": "rocm", + "hip": "rocm", + "cuda": "cuda", + "vulkan": "vulkan-spirv", + "metal": "metal", +} + + +def iree_device_map(device): + uri_parts = device.split("://", 2) + iree_driver = ( + _IREE_DRIVER_MAP[uri_parts[0]] + if uri_parts[0] in _IREE_DRIVER_MAP + else uri_parts[0] + ) + if len(uri_parts) == 1: + return iree_driver + else: + return f"{iree_driver}://{uri_parts[1]}" + + +def iree_backend_map(device): + uri_parts = device.split("://", 2) + iree_device = ( + _IREE_BACKEND_MAP[uri_parts[0]] + if uri_parts[0] in _IREE_BACKEND_MAP + else uri_parts[0] + ) + return iree_device + def compile_to_vmfb( module_str, @@ -55,10 +150,15 @@ def compile_to_vmfb( const_expr_hoisting=True, mlir_source="str", max_alloc="4294967296", - save_mlir=False, + save_mlir=True, attn_spec=None, + winograd=False, + masked_attention=False, + debug=False, ): flags = [] + if mlir_source == "file" and not isinstance(module_str, str): + module_str = str(module_str) if target_triple in ["", None]: if device == "cpu": target_triple = "x86_64-linux-gnu" @@ -66,17 +166,25 @@ def compile_to_vmfb( raise ValueError( "target_triple must be set. Usually this can be fixed by setting --iree_target_triple in the CLI." 
) - if device == "cpu": - flags.extend( - [ - "--iree-llvmcpu-target-triple=" + target_triple, - "--iree-llvmcpu-target-cpu-features=host", - "--iree-llvmcpu-fail-on-out-of-bounds-stack-allocation=false", - "--iree-llvmcpu-distribution-size=32", - ] - ) + if device in ["cpu", "llvm-cpu"]: + if target_triple == "znver4": + flags.extend(znver4_flags["all"]) + if winograd: + flags.extend(znver4_flags["winograd"]) + else: + flags.extend( + [ + "--iree-llvmcpu-target-triple=" + target_triple, + "--iree-llvmcpu-target-cpu-features=host", + "--iree-llvmcpu-fail-on-out-of-bounds-stack-allocation=false", + "--iree-llvmcpu-distribution-size=32", + "--iree-opt-const-eval=false", + "--iree-llvmcpu-enable-ukernels=all", + "--iree-global-opt-enable-quantized-matmul-reassociation", + ] + ) device = "llvm-cpu" - elif device == "vulkan": + elif device in ["vulkan", "vulkan-spirv"]: flags.extend( [ "--iree-hal-target-backends=vulkan-spirv", @@ -88,15 +196,16 @@ def compile_to_vmfb( ] ) device = "vulkan-spirv" - elif device == "rocm": + elif device in ["rocm", "hip"]: flags.extend( [ "--iree-hal-target-backends=rocm", "--iree-rocm-target-chip=" + target_triple, - "--verify=false", - "--iree-opt-const-eval=false", + "--iree-vm-bytecode-module-output-format=flatbuffer-binary", ] ) + if target_triple == "gfx942": + flags.extend(["--iree-rocm-waves-per-eu=2"]) elif device == "cuda": flags.extend( [ @@ -113,29 +222,85 @@ def compile_to_vmfb( elif ireec_flags == None: ireec_flags = [] - for i, flag in enumerate(ireec_flags): - k = flag.strip().split("=")[0] - for idx, default in enumerate(flags): - if k == default.split("=")[0]: - flags[idx] = flag - ireec_flags[i] = "" - if flag not in [None, "", " "]: - flags.append(flag) + if debug: + flags.extend( + ["--iree-hal-dump-executable-files-to=" + safe_name + "_dispatches"] + ) - if target_triple in ["gfx940", "gfx941", "gfx942"]: + if target_triple in ["gfx940", "gfx941", "gfx942", "gfx90a"]: if "unet" in safe_name: - 
flags.extend(gfx94X_flags["unet"]) + flags.extend(MI_flags["unet"]) elif any(x in safe_name for x in ["clip", "prompt_encoder"]): - flags.extend(gfx94X_flags["clip"]) + flags.extend(MI_flags["clip"]) elif "vae" in safe_name: - flags.extend(gfx94X_flags["vae"]) - flags.extend(gfx94X_flags["all"]) + flags.extend(MI_flags["vae"]) + flags.extend(MI_flags["all"]) + if masked_attention: + flags.extend(GFX11_flags["pad_attention"]) + else: + flags.extend(GFX11_flags["preprocess_default"]) + + if "gfx11" in target_triple: + flags.extend(GFX11_flags["all"]) + if masked_attention: + flags.extend(GFX11_flags["pad_attention"]) + else: + flags.extend(GFX11_flags["preprocess_default"]) + + # Currently, we need a transform dialect script to be applied to the compilation through IREE in certain cases. + # This 'attn_spec' handles a linalg_ext.attention op lowering to mfma instructions for capable targets. + # This is a temporary solution, and should be removed or largely disabled once the functionality of + # the TD spec is implemented in C++. 
- if attn_spec not in [None, "", " "]: + if attn_spec in ["default", "mfma", "punet"]: + use_punet = True if attn_spec in ["punet", "i8"] else False + attn_spec = get_mfma_spec_path( + target_triple, + os.path.dirname(safe_name), + masked_attention, + use_punet=use_punet, + ) + flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) + + elif attn_spec in ["wmma"] or ("gfx11" in target_triple and not attn_spec): + attn_spec = get_wmma_spec_path( + target_triple, os.path.dirname(safe_name), masked_attention + ) + if attn_spec: + flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) + elif attn_spec and attn_spec != "None": flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) + for i, flag in enumerate(ireec_flags): + k = flag.strip().split("=")[0] + for idx, default in enumerate(flags): + if default == None: + flags.pop(idx) + continue + elif k == default.split("=")[0]: + flags[idx] = flag if flag.split("=")[-1] not in ["None", ""] else None + flag = None + if flags[idx] == None: + flags.pop(idx) + continue + if flag not in [None, "", " "] and flag.split("=")[-1] not in ["None", ""]: + flags.append(flag) + + for idx, flag in enumerate(flags): + if flag is None: + flags.pop(idx) print("Compiling to", device, "with flags:", flags) + # Forces a standard for naming files: + # If safe_name has target triple in it, get rid of target triple in mlir name + # + if target_triple not in safe_name: + safe_vmfb_name = safe_name + "_" + target_triple + safe_mlir_name = safe_name + else: + safe_vmfb_name = safe_name + safe_mlir_name = "".join(safe_name.split(target_triple)) + if mlir_source == "file": flatbuffer_blob = ireec.compile_file( module_str, @@ -145,9 +310,9 @@ def compile_to_vmfb( ) elif mlir_source == "str": if save_mlir: - with open(f"{safe_name}.mlir", "w+") as f: + with open(f"{safe_mlir_name}.mlir", "w+") as f: f.write(module_str) - print("Saved to", safe_name + ".mlir") + print("Saved to", safe_mlir_name + 
".mlir") flatbuffer_blob = ireec.compile_str( module_str, target_backends=[device], @@ -156,33 +321,80 @@ def compile_to_vmfb( ) else: raise ValueError("mlir_source must be either 'file' or 'str'") - with open(f"{safe_name}.vmfb", "wb+") as f: + with open(f"{safe_vmfb_name}.vmfb", "wb+") as f: f.write(flatbuffer_blob) - print("Saved to", safe_name + ".vmfb") + print(f"Saved to {safe_vmfb_name}.vmfb") if return_path == True: - return safe_name + ".vmfb" + return safe_vmfb_name + ".vmfb" -def create_safe_name(hf_model_name, model_name_str): +def create_safe_name(hf_model_name, model_name_str=""): + if not model_name_str: + model_name_str = "" + if model_name_str != "" and (not model_name_str.startswith("_")): + model_name_str = "_" + model_name_str + safe_name = hf_model_name.split("/")[-1].strip() + model_name_str safe_name = re.sub("-", "_", safe_name) safe_name = re.sub("\.", "_", safe_name) return safe_name +def get_mfma_spec_path(target_chip, save_dir, masked_attention=False, use_punet=False): + if use_punet: + suffix = "_punet" + url = "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/specs/attention_and_matmul_spec.mlir" + elif not masked_attention: + suffix = "" + url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/no_pad/attention_and_matmul_spec_mfma.mlir" + else: + suffix = "_pad" + url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx942.mlir" + attn_spec = urlopen(url).read().decode("utf-8") + spec_path = os.path.join(save_dir, f"attention_and_matmul_spec_mfma{suffix}.mlir") + with open(spec_path, "w") as f: + f.write(attn_spec) + return spec_path + + +def get_wmma_spec_path(target_chip, save_dir, masked_attention=False): + if not masked_attention: + url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/no_pad/attention_and_matmul_spec_wmma.mlir" + elif target_chip == "gfx1100": + url = 
"https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx1100.mlir" + elif target_chip in ["gfx1103", "gfx1150"]: + url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx1150.mlir" + else: + return None + attn_spec = urlopen(url).read().decode("utf-8") + suffix = "masked" if masked_attention else "" + spec_path = os.path.join(save_dir, f"attention_and_matmul_spec_wmma{suffix}.mlir") + with open(spec_path, "w") as f: + f.write(attn_spec) + return spec_path + + def save_external_weights( mapper, model, external_weights=None, external_weight_file=None, + force_format=False, ): if external_weights is not None: if external_weights in ["safetensors", "irpa"]: mod_params = dict(model.named_parameters()) + mod_buffers = dict(model.named_buffers()) + mod_params.update(mod_buffers) for name in mod_params: mapper["params." + name] = name if external_weight_file and not os.path.isfile(external_weight_file): - safetensors.torch.save_file(mod_params, external_weight_file) + if not force_format: + safetensors.torch.save_file(mod_params, external_weight_file) + else: + for x in mod_params.keys(): + mod_params[x] = mod_params[x].numpy() + safe_numpy.save_file(mod_params, external_weight_file) print("Saved params to", external_weight_file) @@ -208,12 +420,18 @@ def get_schedulers(model_id): model_id, subfolder="scheduler", ) - schedulers["Euler"] = EulerDiscreteScheduler.from_pretrained( + schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained( model_id, subfolder="scheduler", ) - schedulers["EulerA"] = EulerAncestralDiscreteScheduler.from_pretrained( - model_id, - subfolder="scheduler", + schedulers["EulerAncestralDiscrete"] = ( + EulerAncestralDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) ) + # schedulers["DPMSolverSDE"] = DPMSolverSDEScheduler.from_pretrained( + # model_id, + # subfolder="scheduler", + # ) return schedulers diff --git 
a/models/turbine_models/custom_models/sd_inference/vae.py b/models/turbine_models/custom_models/sd_inference/vae.py index 0916acda0..7ccd12c48 100644 --- a/models/turbine_models/custom_models/sd_inference/vae.py +++ b/models/turbine_models/custom_models/sd_inference/vae.py @@ -7,56 +7,22 @@ import os import sys -from iree import runtime as ireert from iree.compiler.ir import Context import numpy as np from shark_turbine.aot import * from shark_turbine.dynamo.passes import ( DEFAULT_DECOMPOSITIONS, ) +from shark_turbine.transforms.general.add_metadata import AddMetadataPass from turbine_models.custom_models.sd_inference import utils import torch import torch._dynamo as dynamo +from huggingface_hub import hf_hub_download +from safetensors import safe_open from diffusers import AutoencoderKL import argparse from turbine_models.turbine_tank import turbine_tank -parser = argparse.ArgumentParser() -parser.add_argument( - "--hf_model_name", - type=str, - help="HF model name", - default="CompVis/stable-diffusion-v1-4", -) -parser.add_argument( - "--batch_size", type=int, default=1, help="Batch size for inference" -) -parser.add_argument( - "--height", type=int, default=512, help="Height of Stable Diffusion" -) -parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion") -parser.add_argument( - "--precision", type=str, default="fp32", help="Precision of Stable Diffusion" -) -parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb") -parser.add_argument("--external_weight_path", type=str, default="") -parser.add_argument( - "--external_weights", - type=str, - default=None, - help="saves ir/vmfb without global weights for size and readability, options [safetensors]", -) -parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm") -# TODO: Bring in detection for target triple -parser.add_argument( - "--iree_target_triple", - type=str, - default="", - help="Specify vulkan target triple or rocm/cuda target 
device.", -) -parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") -parser.add_argument("--variant", type=str, default="decode") - class VaeModel(torch.nn.Module): def __init__( @@ -72,17 +38,19 @@ def __init__( subfolder="vae", ) elif not isinstance(custom_vae, dict): - try: - # custom HF repo with no vae subfolder - self.vae = AutoencoderKL.from_pretrained( - custom_vae, - ) - except: - # some larger repo with vae subfolder - self.vae = AutoencoderKL.from_pretrained( - custom_vae, - subfolder="vae", - ) + self.vae = AutoencoderKL.from_pretrained( + hf_model_name, + subfolder="vae", + ) + fp16_weights = hf_hub_download( + repo_id=custom_vae, + filename="vae/vae.safetensors", + ) + with safe_open(fp16_weights, framework="pt", device="cpu") as f: + state_dict = {} + for key in f.keys(): + state_dict[key] = f.get_tensor(key) + self.vae.load_state_dict(state_dict) else: # custom vae as a HF state dict self.vae = AutoencoderKL.from_pretrained( @@ -91,104 +59,225 @@ def __init__( ) self.vae.load_state_dict(custom_vae) - def decode_inp(self, inp): - inp = 1 / 0.18215 * inp + def decode(self, inp): + inp = 1 / self.vae.config.scaling_factor * inp x = self.vae.decode(inp, return_dict=False)[0] return (x / 2 + 0.5).clamp(0, 1) - def encode_inp(self, inp): + def encode(self, inp): latents = self.vae.encode(inp).latent_dist.sample() - return 0.18215 * latents + return self.vae.config.scaling_factor * latents + + +class SD3VaeModel(torch.nn.Module): + def __init__( + self, + hf_model_name, + ): + super().__init__() + self.vae = AutoencoderKL.from_pretrained( + hf_model_name, + subfolder="vae", + ) + + def decode(self, inp): + inp = (inp / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(inp, return_dict=False)[0] + image = image.float() + image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] + return image + + def encode(self, inp): + image_np = inp / 255.0 + image_np = np.moveaxis(image_np, 2, 0) + 
batch_images = np.expand_dims(image_np, axis=0).repeat(1, axis=0) + image_torch = torch.from_numpy(batch_images) + image_torch = 2.0 * image_torch - 1.0 + image_torch = image_torch + latent = self.vae.encode(image_torch) + return latent def export_vae_model( - vae_model, hf_model_name, batch_size, height, width, precision, compile_to="torch", + num_channels=4, external_weights=None, external_weight_path=None, device=None, - target_triple=None, - max_alloc=None, - variant="decode", + target=None, + ireec_flags=None, + decomp_attn=False, + exit_on_vmfb=False, + pipeline_dir=None, + attn_spec=None, + input_mlir=None, + weights_only=False, upload_ir=False, - decomp_attn=True, ): - mapper = {} - decomp_list = DEFAULT_DECOMPOSITIONS + dtype = torch.float16 if precision == "fp16" else torch.float32 + np_dtype = "float16" if precision == "fp16" else "float32" + safe_name = utils.create_safe_name( + hf_model_name, + f"_bs{batch_size}_{height}x{width}_{precision}_vae", + ) if decomp_attn: - decomp_list.extend( - [ - torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, - torch.ops.aten._scaled_dot_product_flash_attention.default, - ] + safe_name += "_decomp_attn" + elif not attn_spec: + if "gfx9" in target: + attn_spec = "mfma" + elif "gfx11" in target: + attn_spec = "wmma" + + if pipeline_dir: + safe_name = os.path.join(pipeline_dir, safe_name) + + if input_mlir: + vmfb_path = utils.compile_to_vmfb( + input_mlir, + device, + target, + ireec_flags, + safe_name, + mlir_source="file", + return_path=not exit_on_vmfb, + attn_spec=attn_spec, ) - dtype = torch.float16 if precision == "fp16" else torch.float32 - vae_model = vae_model.to(dtype) + return vmfb_path + + if "stable-diffusion-3" in hf_model_name: + vae_model = SD3VaeModel(hf_model_name) + else: + if "xl" in hf_model_name.lower() and precision == "fp16": + custom_vae = "amd-shark/sdxl-quant-models" + else: + custom_vae = None + vae_model = VaeModel(hf_model_name, custom_vae=custom_vae) + + if dtype == torch.float16: 
+ vae_model = vae_model.half() + mapper = {} utils.save_external_weights( mapper, vae_model, external_weights, external_weight_path ) + if weights_only: + return external_weight_path - sample = (batch_size, 4, height // 8, width // 8) - if variant == "encode": - sample = (batch_size, 3, height, width) - - class CompiledVae(CompiledModule): - params = export_parameters(vae_model) - - def main(self, inp=AbstractTensor(*sample, dtype=dtype)): - if variant == "decode": - return jittable(vae_model.decode_inp, decompose_ops=decomp_list)(inp) - elif variant == "encode": - return jittable(vae_model.encode_inp, decompose_ops=decomp_list)(inp) - - import_to = "INPUT" if compile_to == "linalg" else "IMPORT" - inst = CompiledVae(context=Context(), import_to=import_to) - - module_str = str(CompiledModule.get_mlir_module(inst)) - safe_name = utils.create_safe_name(hf_model_name, "-vae") - if upload_ir: - with open(f"{safe_name}.mlir", "w+") as f: - f.write(module_str) - model_name_upload = hf_model_name.replace("/", "_") - model_name_upload = model_name_upload + "-vae-" + variant - blob_name = turbine_tank.uploadToBlobStorage( - str(os.path.abspath(f"{safe_name}.mlir")), - f"{model_name_upload}/{model_name_upload}.mlir", + input_image_shape = (height, width, 3) + input_latents_shape = (batch_size, num_channels, height // 8, width // 8) + encode_args = [ + torch.empty( + input_image_shape, + dtype=torch.float32, + ) + ] + decode_args = [ + torch.empty( + input_latents_shape, + dtype=dtype, ) + ] + decomp_list = [] + if decomp_attn == True: + decomp_list = [ + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten.scaled_dot_product_attention, + ] + with decompositions.extend_aot_decompositions( + from_current=True, + add_ops=decomp_list, + ): + fxb = FxProgramsBuilder(vae_model) + + # TODO: fix issues with exporting the encode function. 
+ # @fxb.export_program(args=(encode_args,)) + # def _encode(module, inputs,): + # return module.encode(*inputs) + + @fxb.export_program(args=(decode_args,)) + def _decode(module, inputs): + return module.decode(*inputs) + + class CompiledVae(CompiledModule): + decode = _decode + + if external_weights: + externalize_module_parameters(vae_model) + + inst = CompiledVae(context=Context(), import_to="IMPORT") + + module = CompiledModule.get_mlir_module(inst) + + model_metadata_decode = { + "model_name": "vae_decode", + "input_shapes": [input_latents_shape], + "input_dtypes": [np_dtype], + "output_shapes": [(3, width, height) * batch_size], + "output_dtypes": ["float32"], + } + model_metadata_encode = { + "model_name": "vae_encode", + "input_shapes": [input_image_shape], + "input_dtypes": [np_dtype], + "output_shapes": [input_latents_shape], + "output_dtypes": [np_dtype], + } + module = AddMetadataPass(module, model_metadata_decode, "decode").run() + if compile_to != "vmfb": - return module_str + return str(module) else: - utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name) - if upload_ir: - return blob_name + vmfb_path = utils.compile_to_vmfb( + str(module), + device, + target, + ireec_flags, + safe_name, + return_path=not exit_on_vmfb, + attn_spec=attn_spec, + ) + return vmfb_path if __name__ == "__main__": - args = parser.parse_args() - vae_model = VaeModel( - args.hf_model_name, - ) + from turbine_models.custom_models.sd_inference.sd_cmd_opts import args + + if args.input_mlir: + vae_model = None + else: + vae_model = VaeModel( + args.hf_model_name, + custom_vae=None, + ) mod_str = export_vae_model( vae_model, args.hf_model_name, args.batch_size, - args.height, - args.width, - args.precision, - args.compile_to, - args.external_weights, - args.external_weight_path, - args.device, - args.iree_target_triple, - args.vulkan_max_allocation, - args.variant, + height=args.height, + width=args.width, + precision=args.precision, + 
compile_to=args.compile_to, + external_weights=args.external_weights, + external_weight_path=args.external_weight_path, + device=args.device, + target=args.iree_target_triple, + ireec_flags=args.ireec_flags + args.attn_flags + args.vae_flags, + variant=args.vae_variant, + decomp_attn=args.decomp_attn, + attn_spec=args.attn_spec, + input_mlir=args.input_mlir, + ) + if args.input_mlir or (args.compile_to == "vmfb"): + exit() + safe_name = utils.create_safe_name( + args.hf_model_name, + f"_bs{str(args.batch_size)}_{args.height}x{args.width}_{args.precision}_vae_{args.vae_variant}", ) - safe_name = utils.create_safe_name(args.hf_model_name, "-vae") with open(f"{safe_name}.mlir", "w+") as f: f.write(mod_str) print("Saved to", safe_name + ".mlir") diff --git a/models/turbine_models/custom_models/sd_inference/vae_runner.py b/models/turbine_models/custom_models/sd_inference/vae_runner.py index cce53c118..166021631 100644 --- a/models/turbine_models/custom_models/sd_inference/vae_runner.py +++ b/models/turbine_models/custom_models/sd_inference/vae_runner.py @@ -4,56 +4,20 @@ from iree import runtime as ireert import torch -parser = argparse.ArgumentParser() - -# TODO move common runner flags to generic flag file -parser.add_argument( - "--vmfb_path", type=str, default="", help="path to vmfb containing compiled module" -) -parser.add_argument( - "--external_weight_path", - type=str, - default="", - help="path to external weight parameters if model compiled without them", -) -parser.add_argument( - "--compare_vs_torch", - action="store_true", - help="Runs both turbine vmfb and a torch model to compare results", -) -parser.add_argument( - "--hf_model_name", - type=str, - help="HF model name", - default="CompVis/stable-diffusion-v1-4", -) -parser.add_argument( - "--device", - type=str, - default="local-task", - help="local-sync, local-task, cuda, vulkan, rocm", -) -parser.add_argument( - "--batch_size", type=int, default=1, help="Batch size for inference" -) 
-parser.add_argument( - "--height", type=int, default=512, help="Height of Stable Diffusion" -) -parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion") -parser.add_argument("--variant", type=str, default="decode") - - -def run_vae(device, example_input, vmfb_path, hf_model_name, external_weight_path): + +def run_vae_decode( + device, example_input, vmfb_path, hf_model_name, external_weight_path +): runner = vmfbRunner(device, vmfb_path, external_weight_path) inputs = [ireert.asdevicearray(runner.config.device, example_input)] - results = runner.ctx.modules.compiled_vae["main"](*inputs).to_host() + results = runner.ctx.modules.compiled_vae["decode"](*inputs).to_host() return results -def run_torch_vae(hf_model_name, variant, example_input): +def run_torch_vae_decode(hf_model_name, variant, example_input): from diffusers import AutoencoderKL class VaeModel(torch.nn.Module): @@ -114,7 +78,8 @@ def encode_inp(self, inp): if __name__ == "__main__": - args = parser.parse_args() + from turbine_models.custom_models.sd_inference.sd_cmd_opts import args + if args.variant == "decode": example_input = torch.rand( args.batch_size, 4, args.height // 8, args.width // 8, dtype=torch.float32 @@ -124,7 +89,7 @@ def encode_inp(self, inp): args.batch_size, 3, args.height, args.width, dtype=torch.float32 ) print("generating turbine output:") - turbine_results = run_vae( + turbine_results = run_vae_decode( args.device, example_input, args.vmfb_path, @@ -141,7 +106,7 @@ def encode_inp(self, inp): print("generating torch output: ") from turbine_models.custom_models.sd_inference import utils - torch_output = run_torch_vae( + torch_output = run_torch_vae_decode( args.hf_model_name, args.hf_auth_token, args.variant, example_input ) print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) diff --git a/models/turbine_models/custom_models/sdxl_inference/README.md b/models/turbine_models/custom_models/sdxl_inference/README.md index 
19783c146..f6cea1b21 100644 --- a/models/turbine_models/custom_models/sdxl_inference/README.md +++ b/models/turbine_models/custom_models/sdxl_inference/README.md @@ -1,29 +1,86 @@ -# Stable Diffusion Commands +# Stable Diffusion XL with SHARK-Turbine -## Run and benchmark the entire SDXL pipeline on MI300 - - note: the command below is specifically for use on the ppac-pla-s22-35 instance. you may need to tweak paths accordingly. - - follow "setup repository" in the next section - - optional: set HF_HOME to save dl time/ disk usage +## Support + +Following is a table that shows current status of turbine SDXL inference support for a few AMDGPU targets. This is not an exhaustive list of supported targets. + +| Target Chip | Attention Decomposed? | CLIP | UNet | VAE Decode | Txt2Img | +|-------------|-----------------------|---------------|--------------------------------|--------------------------------|----------------| +| gfx1100 | Yes | 💚 | 💛 (numerics with vector distribution)| 💚 | 💚 | +| | No | | 💔 (Attn lowering) | 💔 (Attn lowering) | 💔 | +| gfx90a | Yes | 💚 | 💚 | 💚 | 💚 | +| | No | | 💛 (Numerics with mfma) | 💚 | 💛 | +| gfx942 | Yes | 💚 | 💚 | 💚 | 💚 | +| | No | | 💚 | 💚 | 💚 | + +## Setup SHARK-Turbine for importing and running the SDXL pipeline or submodels. + +Linux: +```shell +python -m venv turbine_venv +source turbine_venv/bin/activate +python -m pip install --upgrade pip +cd .. +git clone https://iree-org/iree-turbine +cd iree-turbine +pip install -r pytorch-cpu-requirements.txt +pip install -e . +cd ../SHARK-Turbine +pip install --pre --upgrade -e models -r models/requirements.txt +``` + +Windows: +```shell +python -m venv turbine_venv +turbine_venv/Scripts/activate +python -m pip install --upgrade pip +cd .. +git clone https://iree-org/iree-turbine +cd iree-turbine +pip install -r pytorch-cpu-requirements.txt +pip install -e . 
+cd ../SHARK-Turbine +pip install --pre --upgrade -e models -r models/requirements.txt ``` -export HF_HOME=/mnt/dcgpuval/huggingface/ #ppac -export HF_HOME=/data/huggingface-cache #banff + +## Run tests +ROCM: +``` +pytest models/turbine_models/tests/sdxl_test.py --device=rocm --rt_device= --iree_target_triple=gfx --external_weights=safetensors ``` - - make sure you have ROCM working with IREE, check `iree-run-module --dump_devices` - - make a file called "mfma_spec.mlir" and drop in the contents of the TD script https://github.com/nod-ai/2024-q1-sdxl-sprint/tree/main/specs. -### Newest pipeline command, weights (as of [SHARK-Turbine@ean-sd-fp16:6251fbef9233c406093dab056a08cd42cfc54a0b](https://github.com/nod-ai/SHARK-Turbine/commit/6251fbef9233c406093dab056a08cd42cfc54a0b)): +CPU: +``` +pytest models/turbine_models/tests/sdxl_test.py --device=cpu --rt_device=local-task --iree_target_triple=x86-64_linux_gnu --external_weights=safetensors --precision=fp32 +``` +## Run image generation pipeline -gfx940: +ROCM: ``` -python SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py --precision=fp16 --external_weights=safetensors --device=rocm --rt_device=rocm --iree_target_triple=gfx942 --scheduler_id=PNDM --num_inference_steps=30 --pipeline_dir=./sdxl_fp16_1024x1024_gfx940/ --external_weights_dir=./weights_fp16/ --attn_spec=default +python models\turbine_models\custom_models\sdxl_inference\sdxl_compiled_pipeline.py --iree_target_triple=gfx1100 --device=rocm --rt_device=hip --external_weights=safetensors ``` +For mfma-capable hardware, use `--attn_spec=default` to lower attention ops to MFMA instructions. 
-gfx942: +CPU: ``` -python SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py --precision=fp16 --external_weights=safetensors --device=rocm --rt_device=rocm --iree_target_triple=gfx940 --scheduler_id=PNDM --num_inference_steps=30 --pipeline_dir=./sdxl_fp16_1024x1024_gfx940/ --external_weights_dir=./weights_fp16/ --attn_spec=default +pytest models/turbine_models/tests/sdxl_test.py --device=cpu --rt_device=local-task --iree_target_triple=x86-64_linux_gnu --external_weights=safetensors --precision=fp32 ``` +## Shared CLI options + - `--iree_target_triple`: use gfx1100 for 7900xt, gfx90a for MI210/MI250, gfx940 for MI300A, gfx942 for MI300X. For CPU, use x86_64-linux-gnu if you aren't sure. On Vulkan, this is something like `rdna3-7900-windows`. + - `--rt_device`: if using pip install, `hip` will work correctly, but `rocm` will not. Source builds of IREE can support rocm with the `-DIREE_HAL_DRIVER_ROCM=ON -DIREE_EXTERNAL_HAL_DRIVERS="rocm"`, but that option is soon to be deprecated in favor of the HIP driver. + - `--compiled_pipeline`: run one-shot SDXL in a MLIR wrapper, removing model glue from python execution layer + - `--pipeline_dir`: directory in which to save or look for .vmfb files. + - `--external_weights_dir`: directory in which to save or look for weights. + - `--ireec_flags`: extra ireec flags to use for _all_ submodels. + - `--unet_flags / --vae_flags / --clip_flags`: extra ireec flags for individual submodels. + - `--precision`: fp16 or fp32. Default is fp16 and you should only use fp32 for cpu. + - `--num_inference_steps`: (default 30) number of unet iterations to run. + - `--batch_count`: Not compatible with `--compiled_pipeline`. Uses the same clip output to generate a set of images in a batch, with different image latents. + - `--prompt / --negative_prompt`: prompts for stable diffusion inference + + Note: the following "prompt_encoder_f16.irpa" contains weights for both clip1 and clip2. 
The pipeline script will look for these filenames in the specified "external_weights_dir" under "prompt_encoder.irpa", "vae_decode.irpa", "scheduled_unet.irpa". It's not ideal in current state, but will be smoothed out now that general pipeline structure and file management needs are stable. diff --git a/models/turbine_models/custom_models/sdxl_inference/clip.py b/models/turbine_models/custom_models/sdxl_inference/clip.py index 20b0aa7ae..2740745ed 100644 --- a/models/turbine_models/custom_models/sdxl_inference/clip.py +++ b/models/turbine_models/custom_models/sdxl_inference/clip.py @@ -67,7 +67,7 @@ def export_clip_model( safe_name = os.path.join(pipeline_dir, "clip_" + str(index)) else: safe_name = utils.create_safe_name( - hf_model_name, f"-{str(max_length)}-{precision}-clip-{index}-{device}" + hf_model_name, f"_{str(max_length)}-{precision}-clip-{index}-{device}" ) if input_mlir: vmfb_path = utils.compile_to_vmfb( diff --git a/models/turbine_models/custom_models/sdxl_inference/pipeline_ir.py b/models/turbine_models/custom_models/sdxl_inference/pipeline_ir.py new file mode 100644 index 000000000..cb2b62bea --- /dev/null +++ b/models/turbine_models/custom_models/sdxl_inference/pipeline_ir.py @@ -0,0 +1,82 @@ +tokens_to_image = r""" +module @sdxl_compiled_pipeline {{ + func.func private @compiled_scheduled_unet.run_initialize(%arg0: tensor<{batch_size}x4x{lw}x{lh}x{precision}>) -> (tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<{bd}x6x{precision}>, tensor) attributes {{torch.args_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}, {{\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}}]}}]", torch.return_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, 
\22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}]"}} + func.func private @compiled_scheduled_unet.run_forward(%arg0: tensor<{batch_size}x4x{lw}x{lh}x{precision}>, %arg1: tensor<{bd}x{max_length}x2048x{precision}>, %arg2: tensor<{bd}x1280x{precision}>, %arg3: tensor<{bd}x6x{precision}>, %arg4: tensor<1x{precision}>, %arg5: tensor<1xi64>) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}> attributes {{torch.args_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}, {{\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}}]}}]", torch.return_schema = "[1, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]"}} + func.func private @compiled_clip.encode_prompts(%arg0: tensor<{batch_size}x{max_length}xi64>, %arg1: tensor<{batch_size}x{max_length}xi64>, %arg2: tensor<{batch_size}x{max_length}xi64>, %arg3: tensor<{batch_size}x{max_length}xi64>) -> (tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>) attributes {{torch.args_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, 
{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}, {{\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}}]}}]", torch.return_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}]"}} + func.func private @{vae_fn_name}.main(%arg0: tensor<{batch_size}x4x{lw}x{lh}x{precision}>) -> tensor<{batch_size}x3x{width}x{height}x{precision}> attributes {{torch.args_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}, {{\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}}]}}]", torch.return_schema = "[1, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]"}} + + func.func @tokens_to_image(%sample: tensor<{batch_size}x4x{lw}x{lh}x{precision}>, %guidance_scale: tensor<1x{precision}>, %t_ids_1: tensor<{batch_size}x{max_length}xi64>, %t_ids_2: tensor<{batch_size}x{max_length}xi64>, %u_ids_1: tensor<{batch_size}x{max_length}xi64>, %u_ids_2: tensor<{batch_size}x{max_length}xi64>) -> tensor<{batch_size}x3x{width}x{height}x{precision}> {{ + %p_embeds, %t_embeds = func.call @compiled_clip.encode_prompts(%t_ids_1, %t_ids_2, %u_ids_1, %u_ids_2) : (tensor<{batch_size}x{max_length}xi64>, tensor<{batch_size}x{max_length}xi64>, tensor<{batch_size}x{max_length}xi64>, tensor<{batch_size}x{max_length}xi64>) -> (tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>) + %noisy_sample, %time_ids, %steps = func.call @compiled_scheduled_unet.run_initialize(%sample) : (tensor<{batch_size}x4x{lw}x{lh}x{precision}>) -> 
(tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<{bd}x6x{precision}>, tensor) + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %steps_int = tensor.extract %steps[] : tensor + %n_steps = arith.index_cast %steps_int: i64 to index + %res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (tensor<{batch_size}x4x{lw}x{lh}x{precision}>) {{ + %step_64 = arith.index_cast %arg0 : index to i64 + %this_step = tensor.from_elements %step_64 : tensor<1xi64> + %inner = func.call @compiled_scheduled_unet.run_forward(%arg, %p_embeds, %t_embeds, %time_ids, %guidance_scale, %this_step) : (tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>, tensor<1xi64>) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}> + scf.yield %inner : tensor<{batch_size}x4x{lw}x{lh}x{precision}> + }} + %image = func.call @{vae_fn_name}.main(%res): (tensor<{batch_size}x4x{lw}x{lh}x{precision}>) -> tensor<{batch_size}x3x{width}x{height}x{precision}> + return %image : tensor<{batch_size}x3x{width}x{height}x{precision}> + }} +}} +""" + +unet_loop = r""" +module @sdxl_compiled_pipeline {{ + func.func private @compiled_scheduled_unet.run_initialize(%arg0: tensor<{batch_size}x4x{lw}x{lh}x{precision}>) -> (tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<{bd}x6x{precision}>, tensor) attributes {{torch.args_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}, {{\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}}]}}]", torch.return_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, 
\22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}]"}} + func.func private @compiled_scheduled_unet.run_forward(%arg0: tensor<{batch_size}x4x{lw}x{lh}x{precision}>, %arg1: tensor<{bd}x{max_length}x2048x{precision}>, %arg2: tensor<{bd}x1280x{precision}>, %arg3: tensor<{bd}x6x{precision}>, %arg4: tensor<1x{precision}>, %arg5: tensor<1xi64>) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}> attributes {{torch.args_schema = "[1, {{\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]}}, {{\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}}]}}]", torch.return_schema = "[1, {{\22type\22: null, \22context\22: null, \22children_spec\22: []}}]"}} + + func.func @produce_image_latents(%sample: tensor<{batch_size}x4x{lw}x{lh}x{precision}>, %p_embeds: tensor<{bd}x{max_length}x2048x{precision}>, %t_embeds: tensor<{bd}x1280x{precision}>, %guidance_scale: tensor<1x{precision}>) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}> {{ + %noisy_sample, %time_ids, %steps = func.call @compiled_scheduled_unet.run_initialize(%sample) : (tensor<{batch_size}x4x{lw}x{lh}x{precision}>) -> (tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<{bd}x6x{precision}>, tensor) + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %steps_int = tensor.extract %steps[] : tensor + %n_steps = arith.index_cast %steps_int: i64 to index + %res = scf.for %arg0 = %c0 to %n_steps step %c1 
iter_args(%arg = %noisy_sample) -> (tensor<{batch_size}x4x{lw}x{lh}x{precision}>) {{ + %step_64 = arith.index_cast %arg0 : index to i64 + %this_step = tensor.from_elements %step_64 : tensor<1xi64> + %inner = func.call @compiled_scheduled_unet.run_forward(%arg, %p_embeds, %t_embeds, %time_ids, %guidance_scale, %this_step) : (tensor<{batch_size}x4x{lw}x{lh}x{precision}>, tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>, tensor<1xi64>) -> tensor<{batch_size}x4x{lw}x{lh}x{precision}> + scf.yield %inner : tensor<{batch_size}x4x{lw}x{lh}x{precision}> + }} + return %res : tensor<{batch_size}x4x{lw}x{lh}x{precision}> + }} +}} +""" + + +def get_pipeline_ir( + width: int, + height: int, + precision: str, + batch_size: int, + max_length: int, + type: str, + vae_fn_name: str = "compiled_vae", +): + precision = "f32" if precision == "fp32" else "f16" + if type == "tokens_to_image": + return tokens_to_image.format( + width=width, + height=height, + lw=int(width / 8), + lh=int(height / 8), + bd=int(batch_size * 2), + precision=precision, + batch_size=batch_size, + max_length=max_length, + vae_fn_name=vae_fn_name, + ) + elif type == "unet_loop": + return unet_loop.format( + width=width, + height=height, + lw=int(width / 8), + lh=int(height / 8), + bd=int(batch_size * 2), + precision=precision, + batch_size=batch_size, + max_length=max_length, + ) diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py index f2faa0323..368fb0d74 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py @@ -116,6 +116,27 @@ def is_valid_file(arg): help="path to vmfb containing compiled meta-module", ) +p.add_argument( + "--scheduler_vmfb_path", + type=str, + default="", + help="path to vmfb containing compiled scheduler", +) + 
+p.add_argument( + "--split_scheduler", + default=False, + action="store_true", + help="Use a decoupled unet and scheduler for better QOL.", +) + +p.add_argument( + "--cpu_scheduling", + default=False, + action="store_true", + help="Run scheduling on torch cpu (will be slower due to data movement costs).", +) + p.add_argument( "--external_weight_file", type=str, @@ -137,6 +158,62 @@ def is_valid_file(arg): help="Do one-shot inference from tokens to image in a shrink-wrapped pipeline binary.", ) +p.add_argument( + "--vae_precision", + type=str, + default="fp16", + help="Precision of VAE weights and graph.", +) + +p.add_argument( + "--npu_delegate_path", + type=str, + default=None, + help="Path to npu executable plugin .dll for running VAE on NPU.", +) + +p.add_argument( + "--clip_device", + default=None, + type=str, + help="Device to run CLIP on. If None, defaults to the device specified in args.device.", +) + +p.add_argument( + "--unet_device", + default=None, + type=str, + help="Device to run unet on. If None, defaults to the device specified in args.device.", +) + +p.add_argument( + "--vae_device", + default=None, + type=str, + help="Device to run VAE on. If None, defaults to the device specified in args.device.", +) + +p.add_argument( + "--clip_target", + default=None, + type=str, + help="IREE target for CLIP compilation. If None, defaults to the target specified by --iree_target_triple.", +) + +p.add_argument( + "--unet_target", + default=None, + type=str, + help="IREE target for unet compilation. If None, defaults to the target specified by --iree_target_triple.", +) + +p.add_argument( + "--vae_target", + default=None, + type=str, + help="IREE target for vae compilation. If None, defaults to the target specified by --iree_target_triple.", +) + ############################################################################## # SDXL Modelling Options # These options are used to control model defining parameters for SDXL. 
@@ -145,6 +222,13 @@ def is_valid_file(arg): ############################################################################## p.add_argument("--batch_size", type=int, default=1, help="Batch size for inference") +p.add_argument( + "--batch_prompt_input", + type=bool, + default=False, + help="If batch size > 1 this enables batching the prompt encoder input rather than concating prompt encoders output", +) + p.add_argument( "--height", type=int, default=1024, help="Height of Stable Diffusion output image." ) @@ -170,7 +254,7 @@ def is_valid_file(arg): p.add_argument( "--vae_decomp_attn", type=bool, - default=True, + default=False, help="Decompose attention for VAE decode only at fx graph level", ) diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py index f17a17f60..ec88c525d 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py @@ -1,4 +1,4 @@ -# Copyright 2023 Nod Labs, Inc +# Copyright 2024 Advanced Micro Devices, inc. # # Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. 
@@ -10,9 +10,13 @@ sdxl_prompt_encoder, sdxl_scheduled_unet, vae, + unet, ) import iree.runtime as ireert -from turbine_models.custom_models.sd_inference import utils +from turbine_models.custom_models.sd_inference import utils, schedulers +from turbine_models.custom_models.sdxl_inference.pipeline_ir import ( + get_pipeline_ir, +) from turbine_models.utils.sdxl_benchmark import run_benchmark from turbine_models.model_runner import vmfbRunner from transformers import CLIPTokenizer @@ -21,29 +25,23 @@ import os import numpy as np import time +import copy from datetime import datetime as dt -device_list = [ - "cpu", - "vulkan", - "cuda", - "rocm", -] - -rt_device_list = [ - "local-task", - "local-sync", - "vulkan", - "cuda", - "rocm", -] - empty_pipe_dict = { "vae_decode": None, "prompt_encoder": None, "scheduled_unet": None, - "pipeline": None, - "full_pipeline": None, + "unetloop": None, + "fullpipeline": None, +} + +EMPTY_FLAGS = { + "clip": None, + "unet": None, + "vae": None, + "unetloop": None, + "fullpipeline": None, } @@ -51,22 +49,26 @@ class SharkSDXLPipeline: def __init__( self, hf_model_name: str, - scheduler_id: str, height: int, width: int, precision: str, max_length: int, batch_size: int, num_inference_steps: int, - device: str, - iree_target_triple: str, - ireec_flags: dict, + device: str | dict[str], + iree_target_triple: str | dict[str], + scheduler_id: str = "EulerDiscrete", + ireec_flags: dict = EMPTY_FLAGS, attn_spec: str = None, decomp_attn: bool = False, pipeline_dir: str = "./shark_vmfbs", external_weights_dir: str = "./shark_weights", external_weights: str = "safetensors", - vae_decomp_attn: bool = True, + vae_decomp_attn: bool = False, + custom_vae: str = "", + cpu_scheduling: bool = False, + vae_precision: str = "fp32", + batch_prompt_input: bool = False, ): self.hf_model_name = hf_model_name self.scheduler_id = scheduler_id @@ -75,16 +77,61 @@ def __init__( self.precision = precision self.max_length = max_length self.batch_size = batch_size 
+ self.batch_prompt_input = batch_prompt_input self.num_inference_steps = num_inference_steps - self.device = device - self.iree_target_triple = iree_target_triple - self.ireec_flags = ireec_flags + self.devices = {} + if isinstance(device, dict): + assert isinstance( + iree_target_triple, dict + ), "Device and target triple must be both dicts or both strings." + self.devices["clip"] = { + "device": device["clip"], + "driver": utils.iree_device_map(device["clip"]), + "target": iree_target_triple["clip"], + } + self.devices["unet"] = { + "device": device["unet"], + "driver": utils.iree_device_map(device["unet"]), + "target": iree_target_triple["unet"], + } + self.devices["vae"] = { + "device": device["vae"], + "driver": utils.iree_device_map(device["vae"]), + "target": iree_target_triple["vae"], + } + else: + assert isinstance( + iree_target_triple, str + ), "Device and target triple must be both dicts or both strings." + self.devices["unet"] = { + "device": device, + "driver": utils.iree_device_map(device), + "target": iree_target_triple, + } + self.devices["clip"] = self.devices["unet"] + self.devices["vae"] = self.devices["unet"] + self.ireec_flags = ireec_flags if ireec_flags else EMPTY_FLAGS self.attn_spec = attn_spec self.decomp_attn = decomp_attn self.pipeline_dir = pipeline_dir self.external_weights_dir = external_weights_dir self.external_weights = external_weights self.vae_decomp_attn = vae_decomp_attn + self.vae_precision = vae_precision + self.vae_dtype = "float32" if vae_precision == "fp32" else "float16" + self.custom_vae = custom_vae + if self.custom_vae: + self.vae_dir = os.path.join( + self.pipeline_dir, utils.create_safe_name(custom_vae, "") + ) + if not os.path.exists(self.vae_dir): + os.makedirs(self.vae_dir) + self.cpu_scheduling = cpu_scheduling + self.compiled_pipeline = False + self.split_scheduler = False + # TODO: set this based on user-inputted guidance scale and negative prompt. 
+ self.do_classifier_free_guidance = True # False if any(x in hf_model_name for x in ["turbo", "lightning"]) else True + self._interrupt = False # FILE MANAGEMENT AND PIPELINE SETUP @@ -108,53 +155,95 @@ def check_prepared( if do_continue.lower() == "y": for submodel in vmfbs.keys(): if vmfbs[submodel] == None: + print("Fetching: ", submodel) vmfb, weight = self.export_submodel(submodel, input_mlir=mlirs) + if self._interrupt: + self._interrupt = False + return None, None vmfbs[submodel] = vmfb if weights[submodel] is None: weights[submodel] = weight - elif weights[submodel] is None and "pipeline" not in submodel: + elif weights[submodel] is None and not any( + x in submodel for x in ["unetloop", "scheduler"] + ): _, weight = self.export_submodel(submodel, weights_only=True) weights[submodel] = weight ready, vmfbs, weights = self.is_prepared(vmfbs, weights) if ready: - print("All necessary files found. Generating images.") + print("All necessary files found.") return vmfbs, weights else: print("There was an error generating the necessary files.") exit() else: - print("All necessary files found. 
Generating images.") + print("All necessary files found.") return vmfbs, weights def is_prepared(self, vmfbs, weights): missing = [] + dims = f"{str(self.width)}x{str(self.height)}" + pipeline_dir = self.pipeline_dir for key in vmfbs: if key == "scheduled_unet": - val = f"{self.scheduler_id}_unet_{self.num_inference_steps}" - default_filepath = os.path.join(self.pipeline_dir, val + ".vmfb") - else: - val = vmfbs[key] - default_filepath = os.path.join(self.pipeline_dir, key + ".vmfb") - if vmfbs[key] is not None and os.path.exists(vmfbs[key]): + keywords = [ + "DiffusionModule", + self.scheduler_id, + str(self.num_inference_steps), + self.precision, + self.max_length, + dims, + ] + device_key = "unet" + elif key == "scheduler": continue - elif vmfbs[key] == None and os.path.exists(default_filepath): - vmfbs[key] = default_filepath - elif val is None: - missing.append(key + ".vmfb") + elif key == "vae_decode": + keywords = ["vae", self.vae_precision, dims] + device_key = "vae" + if self.custom_vae: + pipeline_dir = self.vae_dir + elif key == "prompt_encoder": + keywords = ["prompt_encoder", self.precision, self.max_length] + device_key = "clip" else: - missing.append(val + ".vmfb") + keywords = [key, self.precision, self.max_length, dims] + device_key = "unet" + keywords.extend( + [ + utils.create_safe_name(self.hf_model_name.split("/")[-1], ""), + "vmfb", + "bs" + str(self.batch_size), + self.devices[device_key]["target"], + ] + ) + avail_files = os.listdir(pipeline_dir) + for filename in avail_files: + if all(str(x) in filename for x in keywords): + vmfbs[key] = os.path.join(pipeline_dir, filename) + if not vmfbs[key]: + missing.append(key + " vmfb") + for w_key in weights: - if "pipeline" in w_key: - continue - if weights[w_key] is not None and os.path.exists(weights[w_key]): + if any(x in w_key for x in ["fullpipeline", "unetloop", "scheduler"]) or ( + self.external_weights is None + ): continue - default_name = os.path.join( - self.external_weights_dir, w_key + 
"." + self.external_weights - ) - if weights[w_key] is None and os.path.exists(default_name): - weights[w_key] = os.path.join(default_name) - else: - missing.append(w_key + "." + self.external_weights) + elif weights[w_key] is not None: + print("Weights already found for ", w_key, "at: ", weights[w_key]) + elif w_key == "vae_decode": + keywords = ["vae", self.vae_precision] + elif w_key in ["prompt_encoder", "clip"]: + keywords = ["prompt_encoder", self.precision] + elif w_key in ["scheduled_unet", "unet"]: + keywords = ["unet", self.precision] + avail_weights = os.listdir(self.external_weights_dir) + for filename in avail_weights: + if all(str(x) in filename for x in keywords): + weights[w_key] = os.path.join(self.external_weights_dir, filename) + if not weights[w_key]: + missing.append( + " ".join([keywords[0], keywords[1], self.external_weights]) + ) + if len(missing) > 0: print(f"Missing files: " + ", ".join(missing)) return False, vmfbs, weights @@ -197,11 +286,18 @@ def get_torch_models(self, submodel): self.hf_model_name, custom_vae=( "madebyollin/sdxl-vae-fp16-fix" - if self.precision == "fp16" - else None + if self.vae_precision == "fp16" + else self.custom_vae ), ) return vae_torch + case "unet": + unet_torch = unet.UnetModel( + self.hf_model_name, + None, + self.precision, + ) + return unet_torch def export_submodel( self, @@ -209,19 +305,24 @@ def export_submodel( input_mlir: str = None, weights_only: bool = False, ): + if self._interrupt: + return None, None if not os.path.exists(self.pipeline_dir): os.makedirs(self.pipeline_dir) - if self.external_weights_dir: + if self.external_weights and self.external_weights_dir: if not os.path.exists(self.external_weights_dir): - os.makedirs(external_weights_dir, exist_ok=True) + os.makedirs(self.external_weights_dir, exist_ok=True) vae_external_weight_path = os.path.join( - self.external_weights_dir, "vae_decode." + self.external_weights + self.external_weights_dir, + f"vae_decode_{self.vae_precision}." 
+ self.external_weights, ) unet_external_weight_path = os.path.join( - self.external_weights_dir, "scheduled_unet." + self.external_weights + self.external_weights_dir, + f"unet_{self.precision}." + self.external_weights, ) prompt_encoder_external_weight_path = os.path.join( - self.external_weights_dir, "prompt_encoder." + self.external_weights + self.external_weights_dir, + f"prompt_encoder_{self.precision}." + self.external_weights, ) elif self.external_weights is None: print( @@ -234,25 +335,27 @@ def export_submodel( print( f"No external weights directory specified using --external_weights_dir, we assume you have your own weights in {self.pipeline_dir}." ) - external_weights_dir = self.pipeline_dir if not os.path.exists(self.pipeline_dir): os.makedirs(self.pipeline_dir, exist_ok=True) vae_external_weight_path = os.path.join( - self.pipeline_dir, "vae_decode." + self.external_weights + self.pipeline_dir, + f"vae_decode_{self.vae_precision}." + self.external_weights, ) unet_external_weight_path = os.path.join( - self.pipeline_dir, "scheduled_unet." + self.external_weights + self.pipeline_dir, f"unet_{self.precision}." + self.external_weights ) prompt_encoder_external_weight_path = os.path.join( - self.pipeline_dir, "prompt_encoder." + self.external_weights + self.pipeline_dir, + f"prompt_encoder_{self.precision}." 
+ self.external_weights, ) if weights_only: input_mlir = { "vae_decode": None, "prompt_encoder": None, "scheduled_unet": None, - "pipeline": None, - "full_pipeline": None, + "unet": None, + "unetloop": None, + "fullpipeline": None, } match submodel: case "scheduled_unet": @@ -274,8 +377,8 @@ def export_submodel( "vmfb", self.external_weights, unet_external_weight_path, - self.device, - self.iree_target_triple, + self.devices["unet"]["driver"], + self.devices["unet"]["target"], self.ireec_flags["unet"], self.decomp_attn, exit_on_vmfb=False, @@ -284,33 +387,93 @@ def export_submodel( input_mlir=input_mlir["scheduled_unet"], weights_only=weights_only, ) + del scheduled_unet_torch + return unet_vmfb, unet_external_weight_path + case "unet": + if not input_mlir[submodel]: + unet_torch = self.get_torch_models("unet") + else: + unet_torch = None + unet_vmfb = unet.export_unet_model( + unet_torch, + self.hf_model_name, + self.batch_size, + self.height, + self.width, + self.precision, + self.max_length, + None, + "vmfb", + self.external_weights, + unet_external_weight_path, + self.devices["unet"]["driver"], + self.devices["unet"]["target"], + self.ireec_flags["unet"], + self.decomp_attn, + exit_on_vmfb=False, + pipeline_dir=self.pipeline_dir, + attn_spec=self.attn_spec, + input_mlir=input_mlir["unet"], + weights_only=weights_only, + ) + del unet_torch return unet_vmfb, unet_external_weight_path + case "scheduler": + if self.cpu_scheduling: + return None, None + else: + scheduler_vmfb = schedulers.export_scheduler_model( + self.hf_model_name, + self.scheduler_id, + self.batch_size, + self.height, + self.width, + self.num_inference_steps, + self.precision, + "vmfb", + self.devices["unet"]["driver"], + self.devices["unet"]["target"], + self.ireec_flags["scheduler"], + exit_on_vmfb=False, + pipeline_dir=self.pipeline_dir, + input_mlir=None, + ) + return scheduler_vmfb, None case "vae_decode": if not input_mlir[submodel]: vae_torch = self.get_torch_models("vae_decode") else: 
vae_torch = None + if self.custom_vae: + vae_external_weight_path = os.path.join( + self.vae_dir, + f"vae_decode_{self.vae_precision}." + self.external_weights, + ) + vae_dir = self.vae_dir + else: + vae_dir = self.pipeline_dir vae_decode_vmfb = vae.export_vae_model( vae_torch, self.hf_model_name, self.batch_size, self.height, self.width, - self.precision, + self.vae_precision, "vmfb", self.external_weights, vae_external_weight_path, - self.device, - self.iree_target_triple, + self.devices["vae"]["driver"], + self.devices["vae"]["target"], self.ireec_flags["vae"], "decode", self.vae_decomp_attn, exit_on_vmfb=False, - pipeline_dir=self.pipeline_dir, + pipeline_dir=vae_dir, attn_spec=self.attn_spec, input_mlir=input_mlir["vae_decode"], weights_only=weights_only, ) + del vae_torch return vae_decode_vmfb, vae_external_weight_path case "prompt_encoder": _, prompt_encoder_vmfb = sdxl_prompt_encoder.export_prompt_encoder( @@ -321,52 +484,72 @@ def export_submodel( "vmfb", self.external_weights, prompt_encoder_external_weight_path, - self.device, - self.iree_target_triple, + self.devices["clip"]["driver"], + self.devices["clip"]["target"], self.ireec_flags["clip"], exit_on_vmfb=False, - pipeline_dir=self.pipeline_dir, + pipeline_dir=( + self.pipeline_dir if not self.custom_vae else self.vae_dir + ), input_mlir=input_mlir["prompt_encoder"], attn_spec=self.attn_spec, weights_only=weights_only, + batchsize=self.batch_size, + batch_input=self.batch_prompt_input, ) return prompt_encoder_vmfb, prompt_encoder_external_weight_path - case "pipeline": - pipeline_file = ( - "sdxl_sched_unet_bench_" + "f32" - if self.precision == "fp32" - else "sdxl_sched_unet_bench_" + "f16" + case "unetloop": + pipeline_file = get_pipeline_ir( + self.width, + self.height, + self.precision, + self.batch_size, + self.max_length, + "unet_loop", ) + pipeline_keys = [ + utils.create_safe_name(self.hf_model_name.split("/")[-1], ""), + "bs" + str(self.batch_size), + f"{str(self.width)}x{str(self.height)}", 
+ self.precision, + str(self.max_length), + "unetloop", + ] pipeline_vmfb = utils.compile_to_vmfb( - os.path.join( - os.path.realpath(os.path.dirname(__file__)), - pipeline_file + ".mlir", - ), - self.device, - self.iree_target_triple, - self.ireec_flags["pipeline"], - os.path.join(self.pipeline_dir, "pipeline"), + pipeline_file, + self.devices["unet"]["driver"], + self.devices["unet"]["target"], + self.ireec_flags["unetloop"], + os.path.join(self.pipeline_dir, "_".join(pipeline_keys)), return_path=True, - mlir_source="file", + mlir_source="str", ) return pipeline_vmfb, None - case "full_pipeline": - pipeline_file = ( - "sdxl_pipeline_bench_" + "f32" - if self.precision == "fp32" - else "sdxl_pipeline_bench_" + "f16" + case "fullpipeline": + pipeline_file = get_pipeline_ir( + self.width, + self.height, + self.precision, + self.batch_size, + self.max_length, + "tokens_to_image", ) + pipeline_keys = [ + utils.create_safe_name(self.hf_model_name.split("/")[-1], ""), + "bs" + str(self.batch_size), + f"{str(self.width)}x{str(self.height)}", + self.precision, + str(self.max_length), + "fullpipeline", + ] pipeline_vmfb = utils.compile_to_vmfb( - os.path.join( - os.path.realpath(os.path.dirname(__file__)), - pipeline_file + ".mlir", - ), - self.device, - self.iree_target_triple, - self.ireec_flags["pipeline"], - os.path.join(self.pipeline_dir, "full_pipeline"), + pipeline_file, + self.devices["unet"]["driver"], + self.devices["unet"]["target"], + self.ireec_flags["unetloop"], + os.path.join(self.pipeline_dir, "_".join(pipeline_keys)), return_path=True, - mlir_source="file", + mlir_source="str", ) return pipeline_vmfb, None @@ -376,19 +559,59 @@ def load_pipeline( self, vmfbs: dict, weights: dict, - rt_device: str = "local-task", - compiled_pipeline: bool = True, + compiled_pipeline: bool = False, + split_scheduler: bool = True, + extra_device_args: dict = {}, ): + if "npu_delegate_path" in extra_device_args.keys(): + delegate = extra_device_args["npu_delegate_path"] + 
else: + delegate = None self.runners = {} runners = {} - if compiled_pipeline: + load_start = time.time() + if split_scheduler: + # We get scheduler at generate time and set steps then. + self.num_inference_steps = None + self.split_scheduler = True runners["pipe"] = vmfbRunner( - rt_device, + self.devices["unet"]["driver"], + vmfbs["unet"], + weights["unet"], + ) + unet_loaded = time.time() + print("\n[LOG] Unet loaded in ", unet_loaded - load_start, "sec") + runners["vae_decode"] = vmfbRunner( + self.devices["vae"]["driver"], + vmfbs["vae_decode"], + weights["vae_decode"], + extra_plugin=delegate, + ) + vae_loaded = time.time() + print("\n[LOG] VAE Decode loaded in ", vae_loaded - unet_loaded, "sec") + runners["prompt_encoder"] = vmfbRunner( + self.devices["clip"]["driver"], + vmfbs["prompt_encoder"], + weights["prompt_encoder"], + ) + clip_loaded = time.time() + print("\n[LOG] CLIP loaded in ", clip_loaded - vae_loaded, "sec") + elif compiled_pipeline: + assert ( + self.devices["unet"]["device"] + == self.devices["clip"]["device"] + == self.devices["vae"]["device"] + ), "Compiled pipeline requires all submodels to be on the same device." + assert ( + self.precision == self.vae_precision + ), "Compiled pipeline requires all submodels to have the same precision for now." 
+ runners["pipe"] = vmfbRunner( + self.devices["unet"]["driver"], [ vmfbs["scheduled_unet"], vmfbs["prompt_encoder"], vmfbs["vae_decode"], - vmfbs["full_pipeline"], + vmfbs["fullpipeline"], ], [ weights["scheduled_unet"], @@ -397,18 +620,34 @@ def load_pipeline( None, ], ) + pipe_loaded = time.time() + print( + "\n[LOG] Compiled Pipeline loaded in ", pipe_loaded - load_start, "sec" + ) + else: runners["pipe"] = vmfbRunner( - rt_device, - [vmfbs["scheduled_unet"], vmfbs["pipeline"]], - [weights["scheduled_unet"], None], - ) - runners["vae_decode"] = vmfbRunner( - rt_device, vmfbs["vae_decode"], weights["vae_decode"] + self.devices["unet"]["driver"], + [ + vmfbs["scheduled_unet"], + vmfbs["unetloop"], + vmfbs["vae_decode"], + vmfbs["prompt_encoder"], + ], + [ + weights["scheduled_unet"], + None, + weights["vae_decode"], + weights["prompt_encoder"], + ], ) - runners["prompt_encoder"] = vmfbRunner( - rt_device, vmfbs["prompt_encoder"], weights["prompt_encoder"] + runners["vae_decode"] = runners["pipe"] + runners["prompt_encoder"] = runners["pipe"] + pipe_loaded = time.time() + print( + "\n[LOG] Compiled Pipeline loaded in ", pipe_loaded - load_start, "sec" ) + tok_start = time.time() runners["tokenizer_1"] = CLIPTokenizer.from_pretrained( self.hf_model_name, subfolder="tokenizer", @@ -417,6 +656,8 @@ def load_pipeline( self.hf_model_name, subfolder="tokenizer_2", ) + tok_loaded = time.time() + print("\n[LOG] Tokenizers loaded in ", tok_loaded - tok_start, "sec") self.runners = runners self.compiled_pipeline = compiled_pipeline print("Successfully loaded pipeline.") @@ -430,9 +671,59 @@ def generate_images( batch_count: int = 1, guidance_scale: float = 7.5, seed: float = -1, + return_imgs: bool = False, + steps: int = None, + cpu_scheduling: bool = False, + scheduler_id: str = "EulerDiscrete", + progress=None, ): + needs_new_scheduler = ( + (steps and steps != self.num_inference_steps) + or (cpu_scheduling != self.cpu_scheduling) + and self.split_scheduler + ) + + 
self.cpu_scheduling = cpu_scheduling + + if steps and not self.compiled_pipeline and needs_new_scheduler: + self.num_inference_steps = steps + if ( + steps + and not self.cpu_scheduling + and not self.compiled_pipeline + and needs_new_scheduler + ): + self.runners["scheduler"] = None + self.num_inference_steps = steps + self.scheduler_id = scheduler_id + scheduler_path = f"{scheduler_id}Scheduler_{self.num_inference_steps}" + if not os.path.exists(scheduler_path): + scheduler_path, _ = self.export_submodel("scheduler") + try: + self.runners["scheduler"] = schedulers.SharkSchedulerWrapper( + self.devices["unet"]["driver"], + scheduler_path, + ) + except: + print("JIT export of scheduler failed. Loading CPU scheduler.") + self.cpu_scheduling = True + if self.cpu_scheduling and needs_new_scheduler: + scheduler = schedulers.get_scheduler(self.hf_model_name, scheduler_id) + self.runners["scheduler"] = schedulers.SharkSchedulerCPUWrapper( + scheduler, + self.batch_size, + self.num_inference_steps, + self.runners["pipe"].config.device, + latents_dtype="float32" if self.precision == "fp32" else "float16", + ) + # TODO: implement case where this is false e.g. 
in SDXL Turbo - # do_classifier_free_guidance = True + do_classifier_free_guidance = True + + # Workaround for turbo support (guidance_scale 0) + if guidance_scale == 0: + negative_prompt = prompt + prompt = "" iree_dtype = "float32" if self.precision == "fp32" else "float16" torch_dtype = torch.float32 if self.precision == "fp32" else torch.float16 @@ -529,30 +820,111 @@ def generate_images( else: encode_prompts_start = time.time() - prompt_embeds, add_text_embeds = self.runners[ - "prompt_encoder" - ].ctx.modules.compiled_clip["encode_prompts"]( - *text_input_ids_list, *uncond_input_ids_list - ) + if self.do_classifier_free_guidance == False: + prompt_embeds, add_text_embeds = self.runners[ + "prompt_encoder" + ].ctx.modules.compiled_clip["encode_prompts_turbo"]( + *text_input_ids_list + ) + else: + prompt_embeds, add_text_embeds = self.runners[ + "prompt_encoder" + ].ctx.modules.compiled_clip["encode_prompts"]( + *text_input_ids_list, *uncond_input_ids_list + ) encode_prompts_end = time.time() for i in range(batch_count): + if self._interrupt: + self._interrupt = False + return None unet_start = time.time() - - latents = self.runners["pipe"].ctx.modules.sdxl_compiled_pipeline[ - "produce_image_latents" - ](samples[i], prompt_embeds, add_text_embeds, guidance_scale) - + if self.split_scheduler: + if self.cpu_scheduling: + sample, time_ids, steps, timesteps = self.runners[ + "scheduler" + ].initialize(samples[i], self.num_inference_steps) + else: + sample, time_ids, steps, timesteps = self.runners[ + "scheduler" + ].initialize(samples[i]) + iree_inputs = [ + sample, + ireert.asdevicearray( + self.runners["pipe"].config.device, prompt_embeds + ), + ireert.asdevicearray( + self.runners["pipe"].config.device, add_text_embeds + ), + time_ids, + None, + ] + for s in range(steps): + if self._interrupt: + self._interrupt = False + return None + # print(f"step {s}") + if self.cpu_scheduling: + step_index = s + else: + step_index = ireert.asdevicearray( + 
self.runners["scheduler"].runner.config.device, + torch.tensor([s]), + "int64", + ) + latents, t = self.runners["scheduler"].scale_model_input( + sample, + step_index, + timesteps, + ) + noise_pred = self.runners["pipe"].ctx.modules.compiled_unet[ + "run_forward" + ]( + latents, + t, + iree_inputs[1], + iree_inputs[2], + iree_inputs[3], + ) + sample = self.runners["scheduler"].step( + noise_pred, + t, + sample, + guidance_scale, + step_index, + ) + if isinstance(sample, torch.Tensor): + latents = ireert.asdevicearray( + self.runners["vae_decode"].config.device, + sample, + dtype=self.vae_dtype, + ) + else: + latents = sample + else: + latents = self.runners["pipe"].ctx.modules.sdxl_compiled_pipeline[ + "produce_image_latents" + ](samples[i], prompt_embeds, add_text_embeds, guidance_scale) + if ( + self.devices["unet"]["driver"] != self.devices["vae"]["driver"] + or self.precision != self.vae_precision + ): + latents = ireert.asdevicearray( + self.runners["vae_decode"].config.device, + latents.to_host(), + dtype=self.vae_dtype, + ) vae_start = time.time() + # print(latents.to_host()[0,0,:]) vae_out = self.runners["vae_decode"].ctx.modules.compiled_vae["main"]( latents ) + # print(vae_out.to_host()[0,0,:]) pipe_end = time.time() image = vae_out.to_host() - numpy_images.append(image) print("Batch #", i + 1, "\n") print( @@ -577,7 +949,6 @@ def generate_images( end = time.time() print("Total CLIP time:", encode_prompts_end - encode_prompts_start, "sec") print("Total tokenize time:", encode_prompts_start - tokenize_start, "sec") - print("Loading time: ", encode_prompts_start - pipe_start, "sec") if batch_count > 1: print( f"Total inference time ({batch_count} batch(es)):", @@ -585,12 +956,27 @@ def generate_images( "sec", ) timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") + images = [] for idx, image in enumerate(numpy_images): image = torch.from_numpy(image).cpu().permute(0, 2, 3, 1).float().numpy() image = numpy_to_pil_image(image) - img_path = "sdxl_output_" + 
timestamp + "_" + str(idx) + ".png" - image[0].save(img_path) - print(img_path, "saved") + images.append(image) + if return_imgs: + return images + for idx_batch, image_batch in enumerate(images): + for idx, image in enumerate(image_batch): + img_path = ( + "sdxl_output_" + + timestamp + + "_" + + str(idx_batch) + + "_" + + str(idx) + + ".png" + ) + image.save(img_path) + print(img_path, "saved") + return def numpy_to_pil_image(images): @@ -602,56 +988,74 @@ def numpy_to_pil_image(images): images = (images * 255).round().astype("uint8") if images.shape[-1] == 1: # special case for grayscale (single channel) images - pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + pil_images = [] + for batched_image in images: + for image in range(0, batched_image.size(dim=0)): + pil_images.append(Image.fromarray(image.squeeze(), mode="L")) else: - pil_images = [Image.fromarray(image) for image in images] - + pil_images = [] + for image in images: + pil_images.append(Image.fromarray(image)) return pil_images if __name__ == "__main__": from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args - mlirs = { - "vae_decode": None, - "prompt_encoder": None, - "scheduled_unet": None, - "pipeline": None, - "full_pipeline": None, - } - vmfbs = { - "vae_decode": None, - "prompt_encoder": None, - "scheduled_unet": None, - "pipeline": None, - "full_pipeline": None, + map = empty_pipe_dict + if args.split_scheduler: + map["unet"] = None + map.pop("scheduled_unet") + map.pop("unetloop") + map.pop("fullpipeline") + mlirs = copy.deepcopy(map) + vmfbs = copy.deepcopy(map) + weights = copy.deepcopy(map) + + if any(x for x in [args.clip_device, args.unet_device, args.vae_device]): + assert all( + x for x in [args.clip_device, args.unet_device, args.vae_device] + ), "Please specify device for all submodels or pass --device for all submodels." 
+ assert all( + x for x in [args.clip_target, args.unet_target, args.vae_target] + ), "Please specify target triple for all submodels or pass --iree_target_triple for all submodels." + args.device = "hybrid" + args.iree_target_triple = "_".join( + [args.clip_target, args.unet_target, args.vae_target] + ) + else: + args.clip_device = args.device + args.unet_device = args.device + args.vae_device = args.device + args.clip_target = args.iree_target_triple + args.unet_target = args.iree_target_triple + args.vae_target = args.iree_target_triple + + devices = { + "clip": args.clip_device, + "unet": args.unet_device, + "vae": args.vae_device, } - weights = { - "vae_decode": None, - "prompt_encoder": None, - "scheduled_unet": None, - "pipeline": None, - "full_pipeline": None, + targets = { + "clip": args.clip_target, + "unet": args.unet_target, + "vae": args.vae_target, } + ireec_flags = { + "clip": args.ireec_flags + args.clip_flags, "unet": args.ireec_flags + args.unet_flags, "vae": args.ireec_flags + args.vae_flags, - "clip": args.ireec_flags + args.clip_flags, - "pipeline": args.ireec_flags, + "unetloop": args.ireec_flags, + "scheduler": args.ireec_flags, } if not args.pipeline_dir: - pipe_id_list = [ - "sdxl_1_0", - str(args.height), - str(args.width), - str(args.max_length), - args.precision, - args.device, - ] args.pipeline_dir = os.path.join( ".", - "_".join(pipe_id_list), + utils.create_safe_name(args.hf_model_name, ""), ) + if not os.path.exists(args.pipeline_dir): + os.makedirs(args.pipeline_dir, exist_ok=True) if args.input_mlir: user_mlir_list = args.input_mlir.split(",") else: @@ -661,18 +1065,17 @@ def numpy_to_pil_image(images): mlirs[submodel_id] = mlir_path if not args.external_weights_dir and args.external_weights: args.external_weights_dir = args.pipeline_dir - sdxl_pipe = SharkSDXLPipeline( args.hf_model_name, - args.scheduler_id, args.height, args.width, args.precision, args.max_length, args.batch_size, args.num_inference_steps, - args.device, - 
args.iree_target_triple, + devices, + targets, + args.scheduler_id, ireec_flags, args.attn_spec, args.decomp_attn, @@ -680,14 +1083,33 @@ def numpy_to_pil_image(images): args.external_weights_dir, args.external_weights, args.vae_decomp_attn, + custom_vae=None, + vae_precision=args.vae_precision, + batch_prompt_input=args.batch_prompt_input, ) + vmfbs, weights = sdxl_pipe.check_prepared(mlirs, vmfbs, weights) - sdxl_pipe.load_pipeline(vmfbs, weights, args.rt_device, args.compiled_pipeline) + + if args.npu_delegate_path: + extra_device_args = {"npu_delegate_path": args.npu_delegate_path} + else: + extra_device_args = {} + sdxl_pipe.load_pipeline( + vmfbs, + weights, + args.compiled_pipeline, + args.split_scheduler, + extra_device_args, + ) sdxl_pipe.generate_images( args.prompt, args.negative_prompt, args.batch_count, args.guidance_scale, args.seed, + False, + args.num_inference_steps, + cpu_scheduling=args.cpu_scheduling, + scheduler_id=args.scheduler_id, ) print("Image generation complete.") diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_pipeline_bench_f16.mlir b/models/turbine_models/custom_models/sdxl_inference/sdxl_pipeline_bench_f16.mlir deleted file mode 100644 index 523d09fa6..000000000 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_pipeline_bench_f16.mlir +++ /dev/null @@ -1,23 +0,0 @@ -module @sdxl_compiled_pipeline { - func.func private @compiled_scheduled_unet.run_initialize(%arg0: tensor<1x4x128x128xf16>) -> (tensor<1x4x128x128xf16>, tensor<2x6xf16>, tensor) attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, 
\22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}]"} - func.func private @compiled_scheduled_unet.run_forward(%arg0: tensor<1x4x128x128xf16>, %arg1: tensor<2x64x2048xf16>, %arg2: tensor<2x1280xf16>, %arg3: tensor<2x6xf16>, %arg4: tensor<1xf16>, %arg5: tensor<1xi64>) -> tensor<1x4x128x128xf16> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} - func.func private @compiled_clip.encode_prompts(%arg0: tensor<1x64xi64>, %arg1: tensor<1x64xi64>, %arg2: tensor<1x64xi64>, %arg3: tensor<1x64xi64>) -> (tensor<2x64x2048xf16>, tensor<2x1280xf16>) attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: 
[]}]}]", torch.return_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}]"} - func.func private @compiled_vae.main(%arg0: tensor<1x4x128x128xf16>) -> tensor<1x3x1024x1024xf16> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} - - func.func @tokens_to_image(%sample: tensor<1x4x128x128xf16>, %guidance_scale: tensor<1xf16>, %t_ids_1: tensor<1x64xi64>, %t_ids_2: tensor<1x64xi64>, %u_ids_1: tensor<1x64xi64>, %u_ids_2: tensor<1x64xi64>) -> tensor<1x3x1024x1024xf16> { - %p_embeds, %t_embeds = func.call @compiled_clip.encode_prompts(%t_ids_1, %t_ids_2, %u_ids_1, %u_ids_2) : (tensor<1x64xi64>, tensor<1x64xi64>, tensor<1x64xi64>, tensor<1x64xi64>) -> (tensor<2x64x2048xf16>, tensor<2x1280xf16>) - %noisy_sample, %time_ids, %steps = func.call @compiled_scheduled_unet.run_initialize(%sample) : (tensor<1x4x128x128xf16>) -> (tensor<1x4x128x128xf16>, tensor<2x6xf16>, tensor) - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %steps_int = tensor.extract %steps[] : tensor - %n_steps = arith.index_cast %steps_int: i64 to index - %res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (tensor<1x4x128x128xf16>) { - %step_64 = arith.index_cast %arg0 : index to i64 - %this_step = tensor.from_elements %step_64 : tensor<1xi64> - %inner = func.call @compiled_scheduled_unet.run_forward(%arg, %p_embeds, %t_embeds, %time_ids, %guidance_scale, %this_step) : 
(tensor<1x4x128x128xf16>, tensor<2x64x2048xf16>, tensor<2x1280xf16>, tensor<2x6xf16>, tensor<1xf16>, tensor<1xi64>) -> tensor<1x4x128x128xf16> - scf.yield %inner : tensor<1x4x128x128xf16> - } - %image = func.call @compiled_vae.main(%res): (tensor<1x4x128x128xf16>) -> tensor<1x3x1024x1024xf16> - return %image : tensor<1x3x1024x1024xf16> - } -} \ No newline at end of file diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_pipeline_bench_f32.mlir b/models/turbine_models/custom_models/sdxl_inference/sdxl_pipeline_bench_f32.mlir deleted file mode 100644 index 669df73b2..000000000 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_pipeline_bench_f32.mlir +++ /dev/null @@ -1,23 +0,0 @@ -module @sdxl_compiled_pipeline { - func.func private @compiled_scheduled_unet.run_initialize(%arg0: tensor<1x4x128x128xf32>) -> (tensor<1x4x128x128xf32>, tensor<2x6xf32>, tensor) attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}]"} - func.func private @compiled_scheduled_unet.run_forward(%arg0: tensor<1x4x128x128xf32>, %arg1: tensor<2x64x2048xf32>, %arg2: tensor<2x1280xf32>, %arg3: tensor<2x6xf32>, %arg4: tensor<1xf32>, %arg5: tensor<1xi64>) -> tensor<1x4x128x128xf32> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, 
\22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} - func.func private @compiled_clip.encode_prompts(%arg0: tensor<1x64xi64>, %arg1: tensor<1x64xi64>, %arg2: tensor<1x64xi64>, %arg3: tensor<1x64xi64>) -> (tensor<2x64x2048xf32>, tensor<2x1280xf32>) attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}]"} - func.func private @compiled_vae.main(%arg0: tensor<1x4x128x128xf32>) -> tensor<1x3x1024x1024xf32> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: 
\22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} - - func.func @tokens_to_image(%sample: tensor<1x4x128x128xf32>, %guidance_scale: tensor<1xf32>, %t_ids_1: tensor<1x64xi64>, %t_ids_2: tensor<1x64xi64>, %u_ids_1: tensor<1x64xi64>, %u_ids_2: tensor<1x64xi64>) -> tensor<1x3x1024x1024xf32> { - %p_embeds, %t_embeds = func.call @compiled_clip.encode_prompts(%t_ids_1, %t_ids_2, %u_ids_1, %u_ids_2) : (tensor<1x64xi64>, tensor<1x64xi64>, tensor<1x64xi64>, tensor<1x64xi64>) -> (tensor<2x64x2048xf32>, tensor<2x1280xf32>) - %noisy_sample, %time_ids, %steps = func.call @compiled_scheduled_unet.run_initialize(%sample) : (tensor<1x4x128x128xf32>) -> (tensor<1x4x128x128xf32>, tensor<2x6xf32>, tensor) - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %steps_int = tensor.extract %steps[] : tensor - %n_steps = arith.index_cast %steps_int: i64 to index - %res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (tensor<1x4x128x128xf32>) { - %step_64 = arith.index_cast %arg0 : index to i64 - %this_step = tensor.from_elements %step_64 : tensor<1xi64> - %inner = func.call @compiled_scheduled_unet.run_forward(%arg, %p_embeds, %t_embeds, %time_ids, %guidance_scale, %this_step) : (tensor<1x4x128x128xf32>, tensor<2x64x2048xf32>, tensor<2x1280xf32>, tensor<2x6xf32>, tensor<1xf32>, tensor<1xi64>) -> tensor<1x4x128x128xf32> - scf.yield %inner : tensor<1x4x128x128xf32> - } - %image = func.call @compiled_vae.main(%res): (tensor<1x4x128x128xf32>) -> tensor<1x3x1024x1024xf32> - return %image : tensor<1x3x1024x1024xf32> - } -} \ No newline at end of file diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py index 1c6b6331c..d579c3419 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py +++ 
b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py @@ -12,6 +12,8 @@ from iree.compiler.ir import Context import numpy as np from shark_turbine.aot import * +from shark_turbine.transforms.general.add_metadata import AddMetadataPass + from turbine_models.custom_models.sd_inference import utils import torch from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer @@ -24,6 +26,8 @@ def __init__( precision, hf_auth_token=None, do_classifier_free_guidance=True, + batch_size=1, + batch_input=False, ): super().__init__() self.torch_dtype = torch.float16 if precision == "fp16" else torch.float32 @@ -37,7 +41,9 @@ def __init__( subfolder="text_encoder_2", token=hf_auth_token, ) - self.do_classifier_free_guidance = do_classifier_free_guidance + self.do_classifier_free_guidance = True + self.batch_size = batch_size + self.batch_input = batch_input def forward( self, text_input_ids_1, text_input_ids_2, uncond_input_ids_1, uncond_input_ids_2 @@ -76,23 +82,71 @@ def forward( neg_prompt_embeds = torch.concat(neg_prompt_embeds_list, dim=-1) bs_embed, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, 1, 1) prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1) pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view( bs_embed * 1, -1 ) + if not self.batch_input: + prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1) add_text_embeds = pooled_prompt_embeds + if not self.batch_input: + add_text_embeds = add_text_embeds.repeat(self.batch_size, 1) if self.do_classifier_free_guidance: - neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(1, 1).view( - 1, -1 - ) + if not self.batch_input: + neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat( + 1, 1 + ).view(1, -1) neg_prompt_embeds = neg_prompt_embeds.repeat(1, 1, 1) neg_prompt_embeds = neg_prompt_embeds.view(bs_embed * 1, seq_len, -1) + if not self.batch_input: + neg_prompt_embeds = neg_prompt_embeds.repeat(self.batch_size, 1, 1) 
prompt_embeds = torch.cat([neg_prompt_embeds, prompt_embeds], dim=0) + if not self.batch_input: + neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat( + self.batch_size, 1 + ) add_text_embeds = torch.cat( [neg_pooled_prompt_embeds, add_text_embeds], dim=0 ) + add_text_embeds = add_text_embeds.to(self.torch_dtype) + prompt_embeds = prompt_embeds.to(self.torch_dtype) + return prompt_embeds, add_text_embeds + + def forward_turbo(self, text_input_ids_1, text_input_ids_2): + with torch.no_grad(): + prompt_embeds_1 = self.text_encoder_model_1( + text_input_ids_1, + output_hidden_states=True, + ) + prompt_embeds_2 = self.text_encoder_model_2( + text_input_ids_2, + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds_2[0] + + prompt_embeds_list = [ + prompt_embeds_1.hidden_states[-2], + prompt_embeds_2.hidden_states[-2], + ] + # neg_prompt_embeds_list = [ + # torch.zeros_like(prompt_embeds_list[0]), # dummy tensor + # torch.zeros_like(prompt_embeds_list[1]), # dummy tensor + # ] + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + bs_embed, seq_len, _ = prompt_embeds.shape + + prompt_embeds = prompt_embeds.repeat(1, 1, 1) + prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view( + bs_embed * 1, -1 + ) + prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1) + add_text_embeds = pooled_prompt_embeds + add_text_embeds = add_text_embeds.repeat(self.batch_size, 1) add_text_embeds = add_text_embeds.to(self.torch_dtype) prompt_embeds = prompt_embeds.to(self.torch_dtype) @@ -103,42 +157,36 @@ def export_prompt_encoder( hf_model_name, hf_auth_token=None, max_length=64, + batch_size=1, precision="fp16", compile_to="torch", external_weights=None, external_weight_path=None, device=None, - target_triple=None, + target=None, ireec_flags=None, - exit_on_vmfb=True, + exit_on_vmfb=False, 
pipeline_dir=None, input_mlir=None, attn_spec=None, weights_only=False, + batch_input=False, + decomp_attn=False, # Compatibility ): - if "turbo" in hf_model_name: - do_classifier_free_guidance = False - else: - do_classifier_free_guidance = True - - if (attn_spec in ["default"]) and ("gfx94" in target_triple): - attn_spec = os.path.join( - os.path.realpath(os.path.dirname(__file__)), "default_mfma_attn_spec.mlir" - ) - else: - attn_spec = None + do_classifier_free_guidance = True + safe_name = utils.create_safe_name( + hf_model_name, + f"_bs{batch_size}_{str(max_length)}-{precision}-prompt-encoder-{device}", + ) if pipeline_dir not in [None, ""]: - safe_name = os.path.join(pipeline_dir, "prompt_encoder") - else: - safe_name = utils.create_safe_name( - hf_model_name, f"-{str(max_length)}-{precision}-prompt-encoder-{device}" - ) + safe_name = os.path.join(pipeline_dir, safe_name) + if input_mlir: vmfb_path = utils.compile_to_vmfb( input_mlir, device, - target_triple, + target, ireec_flags, safe_name, mlir_source="file", @@ -162,8 +210,18 @@ def export_prompt_encoder( ) tokenizers = [tokenizer_1, tokenizer_2] prompt_encoder_module = PromptEncoderModule( - hf_model_name, precision, hf_auth_token, do_classifier_free_guidance + hf_model_name, + precision, + hf_auth_token, + do_classifier_free_guidance, + batch_size=batch_size, + batch_input=batch_input, ) + + input_batchsize = 1 + if batch_input: + input_batchsize = batch_size + if precision == "fp16": prompt_encoder_module = prompt_encoder_module.half() mapper = {} @@ -188,34 +246,50 @@ class CompiledClip(CompiledModule): def encode_prompts( self, - t_ids_1=AbstractTensor(1, max_length, dtype=torch.int64), - t_ids_2=AbstractTensor(1, max_length, dtype=torch.int64), - uc_ids_1=AbstractTensor(1, max_length, dtype=torch.int64), - uc_ids_2=AbstractTensor(1, max_length, dtype=torch.int64), + t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64), + t_ids_2=AbstractTensor(input_batchsize, max_length, 
dtype=torch.int64), + uc_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64), + uc_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64), ): return jittable(prompt_encoder_module.forward)( t_ids_1, t_ids_2, uc_ids_1, uc_ids_2 ) + def encode_prompts_turbo( + self, + t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64), + t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64), + ): + return jittable(prompt_encoder_module.forward_turbo)(t_ids_1, t_ids_2) + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" inst = CompiledClip(context=Context(), import_to=import_to) - module_str = str(CompiledModule.get_mlir_module(inst)) + module = CompiledModule.get_mlir_module(inst) + + model_metadata_encode = { + "model_name": hf_model_name + "_text_encoder", + "input_shapes": [str((input_batchsize, max_length)) for i in range(4)], + "input_dtypes": ["int64" for i in range(4)], + "use_attention_mask": False, + } + module = AddMetadataPass(module, model_metadata_encode, "encode_prompts").run() + module_str = str(module) if compile_to != "vmfb": - return module_str, tokenizers + return module_str else: vmfb_path = utils.compile_to_vmfb( module_str, device, - target_triple, + target, ireec_flags, safe_name, return_path=not exit_on_vmfb, const_expr_hoisting=True, attn_spec=attn_spec, ) - return module_str, vmfb_path + return vmfb_path if __name__ == "__main__": @@ -225,6 +299,7 @@ def encode_prompts( args.hf_model_name, args.hf_auth_token, args.max_length, + args.batch_size, args.precision, args.compile_to, args.external_weights, @@ -240,7 +315,7 @@ def encode_prompts( if args.input_mlir: exit() safe_name_1 = safe_name = utils.create_safe_name( - args.hf_model_name, f"_{str(args.max_length)}_{args.precision}_prompt_encoder" + args.hf_model_name, f"{str(args.max_length)}_{args.precision}_prompt_encoder" ) with open(f"{safe_name}.mlir", "w+") as f: f.write(mod_str) diff --git 
a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder_runner.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder_runner.py index 50c01e964..8f633a6f8 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder_runner.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder_runner.py @@ -5,58 +5,18 @@ import numpy as np -def run_torch_clip(hf_model_name, hf_auth_token, prompt, max_length=64): - # TODO: Integrate with HFTransformerBuilder - from turbine_models.custom_models.sdxl_inference.clip import ClipModel - - model_1 = ClipModel(hf_model_name, hf_auth_token, index=1) - model_2 = ClipModel(hf_model_name, hf_auth_token, index=2) - tokenizer_1 = CLIPTokenizer.from_pretrained( - hf_model_name, - subfolder="tokenizer", - token=hf_auth_token, - ) - tokenizer_2 = CLIPTokenizer.from_pretrained( - hf_model_name, - subfolder="tokenizer_2", - token=hf_auth_token, - ) - text_input_1 = tokenizer_1( - prompt, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - text_input_2 = tokenizer_2( - prompt, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - example_input_1 = text_input_1.input_ids - example_input_2 = text_input_2.input_ids - - results_1 = model_1.forward(example_input_1) - results_2 = model_2.forward(example_input_2) - np_torch_output_1 = results_1[0].detach().cpu().numpy().astype(np.float16) - np_torch_output_2 = results_2[0].detach().cpu().numpy().astype(np.float16) - return np_torch_output_1, np_torch_output_2 - - def run_prompt_encoder( - args, + vmfb_path, + device, + external_weight_path, input_ids, uncond_input_ids, ): - prompt_encoder_runner = vmfbRunner( - args.device, args.vmfb_path, args.external_weight_path - ) - np.save("input0.npy", input_ids[0].numpy()) - np.save("input1.npy", input_ids[1].numpy()) - np.save("input2.npy", uncond_input_ids[0].numpy()) - np.save("input3.npy", 
uncond_input_ids[1].numpy()) + prompt_encoder_runner = vmfbRunner(device, vmfb_path, external_weight_path) + # np.save("input0.npy", input_ids[0].numpy()) + # np.save("input1.npy", input_ids[1].numpy()) + # np.save("input2.npy", uncond_input_ids[0].numpy()) + # np.save("input3.npy", uncond_input_ids[1].numpy()) prompt_encoder_inputs = [ ireert.asdevicearray(prompt_encoder_runner.config.device, input_ids[0]), ireert.asdevicearray(prompt_encoder_runner.config.device, input_ids[1]), @@ -66,23 +26,19 @@ def run_prompt_encoder( encoded_outputs = prompt_encoder_runner.ctx.modules.compiled_clip["encode_prompts"]( *prompt_encoder_inputs ) + for i in encoded_outputs: + i = i.to_host() del prompt_encoder_inputs return encoded_outputs -if __name__ == "__main__": - from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args - - tokenizer_1 = CLIPTokenizer.from_pretrained( - args.hf_model_name, - subfolder="tokenizer", - token=args.hf_auth_token, - ) - tokenizer_2 = CLIPTokenizer.from_pretrained( - args.hf_model_name, - subfolder="tokenizer_2", - token=args.hf_auth_token, - ) +def run_tokenize( + tokenizer_1, + tokenizer_2, + prompt, + negative_prompt, + max_length=64, +): text_input_ids_list = [] uncond_input_ids_list = [] @@ -90,16 +46,16 @@ def run_prompt_encoder( tokenizers = [tokenizer_1, tokenizer_2] for tokenizer in tokenizers: text_inputs = tokenizer( - args.prompt, + prompt, padding="max_length", - max_length=args.max_length, + max_length=max_length, truncation=True, return_tensors="pt", ) uncond_input = tokenizer( - args.negative_prompt, + negative_prompt, padding="max_length", - max_length=args.max_length, + max_length=max_length, truncation=True, return_tensors="pt", ) @@ -108,9 +64,34 @@ def run_prompt_encoder( text_input_ids_list.extend([text_input_ids]) uncond_input_ids_list.extend([uncond_input_ids]) + return text_input_ids_list, uncond_input_ids_list + +if __name__ == "__main__": + from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts 
import args + + tokenizer_1 = CLIPTokenizer.from_pretrained( + args.hf_model_name, + subfolder="tokenizer", + token=args.hf_auth_token, + ) + tokenizer_2 = CLIPTokenizer.from_pretrained( + args.hf_model_name, + subfolder="tokenizer_2", + token=args.hf_auth_token, + ) + + text_input_ids_list, uncond_input_ids_list = run_tokenize( + tokenizer_1, + tokenizer_2, + args.prompt, + args.negative_prompt, + args.max_length, + ) turbine_output1, turbine_output2 = run_prompt_encoder( - args, + args.vmfb_path, + args.rt_device, + args.external_weight_path, text_input_ids_list, uncond_input_ids_list, ) diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_sched_unet_bench_f16.mlir b/models/turbine_models/custom_models/sdxl_inference/sdxl_sched_unet_bench_f16.mlir deleted file mode 100644 index b12fc82b9..000000000 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_sched_unet_bench_f16.mlir +++ /dev/null @@ -1,19 +0,0 @@ -module @sdxl_compiled_pipeline { - func.func private @compiled_scheduled_unet.run_initialize(%arg0: tensor<1x4x128x128xf16>) -> (tensor<1x4x128x128xf16>, tensor<2x6xf16>, tensor) attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}]"} - func.func private @compiled_scheduled_unet.run_forward(%arg0: tensor<1x4x128x128xf16>, %arg1: tensor<2x64x2048xf16>, %arg2: tensor<2x1280xf16>, %arg3: tensor<2x6xf16>, %arg4: tensor<1xf16>, %arg5: 
tensor<1xi64>) -> tensor<1x4x128x128xf16> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} - - func.func @produce_image_latents(%sample: tensor<1x4x128x128xf16>, %p_embeds: tensor<2x64x2048xf16>, %t_embeds: tensor<2x1280xf16>, %guidance_scale: tensor<1xf16>) -> tensor<1x4x128x128xf16> { - %noisy_sample, %time_ids, %steps = func.call @compiled_scheduled_unet.run_initialize(%sample) : (tensor<1x4x128x128xf16>) -> (tensor<1x4x128x128xf16>, tensor<2x6xf16>, tensor) - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %steps_int = tensor.extract %steps[] : tensor - %n_steps = arith.index_cast %steps_int: i64 to index - %res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (tensor<1x4x128x128xf16>) { - %step_64 = arith.index_cast %arg0 : index to i64 - %this_step = tensor.from_elements %step_64 : tensor<1xi64> - %inner = func.call @compiled_scheduled_unet.run_forward(%arg, %p_embeds, %t_embeds, %time_ids, %guidance_scale, %this_step) : (tensor<1x4x128x128xf16>, tensor<2x64x2048xf16>, tensor<2x1280xf16>, tensor<2x6xf16>, tensor<1xf16>, tensor<1xi64>) -> tensor<1x4x128x128xf16> - scf.yield %inner : tensor<1x4x128x128xf16> - } - return %res : tensor<1x4x128x128xf16> - } -} \ No newline at end of file diff 
--git a/models/turbine_models/custom_models/sdxl_inference/sdxl_sched_unet_bench_f32.mlir b/models/turbine_models/custom_models/sdxl_inference/sdxl_sched_unet_bench_f32.mlir deleted file mode 100644 index fbc69f854..000000000 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_sched_unet_bench_f32.mlir +++ /dev/null @@ -1,19 +0,0 @@ -module @sdxl_compiled_pipeline { - func.func private @compiled_scheduled_unet.run_initialize(%arg0: tensor<1x4x128x128xf32>) -> (tensor<1x4x128x128xf32>, tensor<2x6xf32>, tensor) attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}]"} - func.func private @compiled_scheduled_unet.run_forward(%arg0: tensor<1x4x128x128xf32>, %arg1: tensor<2x64x2048xf32>, %arg2: tensor<2x1280xf32>, %arg3: tensor<2x6xf32>, %arg4: tensor<1xf32>, %arg5: tensor<1xi64>) -> tensor<1x4x128x128xf32> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, 
\22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} - - func.func @produce_image_latents(%sample: tensor<1x4x128x128xf32>, %p_embeds: tensor<2x64x2048xf32>, %t_embeds: tensor<2x1280xf32>, %guidance_scale: tensor<1xf32>) -> tensor<1x4x128x128xf32> { - %noisy_sample, %time_ids, %steps = func.call @compiled_scheduled_unet.run_initialize(%sample) : (tensor<1x4x128x128xf32>) -> (tensor<1x4x128x128xf32>, tensor<2x6xf32>, tensor) - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %steps_int = tensor.extract %steps[] : tensor - %n_steps = arith.index_cast %steps_int: i64 to index - %res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg_s = %noisy_sample) -> (tensor<1x4x128x128xf32>) { - %step_64 = arith.index_cast %arg0 : index to i64 - %this_step = tensor.from_elements %step_64 : tensor<1xi64> - %inner = func.call @compiled_scheduled_unet.run_forward(%arg_s, %p_embeds, %t_embeds, %time_ids, %guidance_scale, %this_step) : (tensor<1x4x128x128xf32>, tensor<2x64x2048xf32>, tensor<2x1280xf32>, tensor<2x6xf32>, tensor<1xf32>, tensor<1xi64>) -> tensor<1x4x128x128xf32> - scf.yield %inner : tensor<1x4x128x128xf32> - } - return %res : tensor<1x4x128x128xf32> - } -} \ No newline at end of file diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py index f74c707e7..fd9adaa8f 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py @@ -6,20 +6,25 @@ # from @aviator19941's gist : https://gist.github.com/aviator19941/4e7967bd1787c83ee389a22637c6eea7 +import copy import os import sys +import numpy as np + +# os.environ["TORCH_LOGS"] = "+dynamo" + +import torch +import 
torch._dynamo as dynamo from iree import runtime as ireert from iree.compiler.ir import Context -import numpy as np + from shark_turbine.aot import * +import shark_turbine.ops as ops + from turbine_models.custom_models.sd_inference import utils -import torch -import torch._dynamo as dynamo +from turbine_models.custom_models.sd_inference.schedulers import get_scheduler from diffusers import UNet2DConditionModel -from shark_turbine.dynamo.passes import ( - DEFAULT_DECOMPOSITIONS, -) class SDXLScheduledUnet(torch.nn.Module): @@ -36,13 +41,19 @@ def __init__( return_index=False, ): super().__init__() + self.do_classifier_free_guidance = True + # if any(key in hf_model_name for key in ["turbo", "lightning"]): + # self.do_classifier_free_guidance = False self.dtype = torch.float16 if precision == "fp16" else torch.float32 self.scheduler = utils.get_schedulers(hf_model_name)[scheduler_id] - if scheduler_id == "PNDM": - num_inference_steps = num_inference_steps - 1 + # if scheduler_id == "PNDM": + # num_inference_steps = num_inference_steps - 1 self.scheduler.set_timesteps(num_inference_steps) self.scheduler.is_scale_input_called = True self.return_index = return_index + self.height = height + self.width = width + self.batch_size = batch_size if precision == "fp16": try: @@ -69,15 +80,16 @@ def __init__( ) def initialize(self, sample): - height = sample.shape[-2] * 8 - width = sample.shape[-1] * 8 + height = self.height + width = self.width original_size = (height, width) target_size = (height, width) crops_coords_top_left = (0, 0) add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_time_ids = torch.tensor([add_time_ids]) - add_time_ids = torch.cat([add_time_ids] * 2, dim=0) - add_time_ids = add_time_ids.repeat(sample.shape[0], 1).type(self.dtype) + add_time_ids = torch.tensor([add_time_ids], dtype=self.dtype) + if self.do_classifier_free_guidance: + add_time_ids = torch.cat([add_time_ids] * 2, dim=0) + add_time_ids = 
add_time_ids.repeat(self.batch_size, 1).type(self.dtype) timesteps = self.scheduler.timesteps step_indexes = torch.tensor(len(timesteps)) sample = sample * self.scheduler.init_noise_sigma @@ -86,31 +98,50 @@ def initialize(self, sample): def forward( self, sample, prompt_embeds, text_embeds, time_ids, guidance_scale, step_index ): - with torch.no_grad(): - added_cond_kwargs = { - "text_embeds": text_embeds, - "time_ids": time_ids, - } - t = self.scheduler.timesteps[step_index] + added_cond_kwargs = { + "time_ids": time_ids, + "text_embeds": text_embeds, + } + t = self.scheduler.timesteps[step_index] + if self.do_classifier_free_guidance: latent_model_input = torch.cat([sample] * 2) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - noise_pred = self.unet.forward( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=None, - added_cond_kwargs=added_cond_kwargs, - return_dict=False, - )[0] - + else: + latent_model_input = sample + # ops.iree.trace_tensor(f"latent_model_input_{step_index}", latent_model_input) + + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ).type(self.dtype) + print( + latent_model_input.shape, + t.shape, + sample.shape, + prompt_embeds.shape, + added_cond_kwargs, + guidance_scale, + step_index, + ) + # ops.iree.trace_tensor(f"latent_model_input_scaled_{step_index}", latent_model_input) + noise_pred = self.unet.forward( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=None, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + # ops.iree.trace_tensor(f"noise_pred_{step_index}", noise_pred) + + if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_text - noise_pred_uncond ) - sample = self.scheduler.step(noise_pred, t, sample, return_dict=False)[0] - return sample.type(self.dtype) + sample = 
self.scheduler.step(noise_pred, t, sample, return_dict=False)[0] + return sample.type(self.dtype) +@torch.no_grad() def export_scheduled_unet_model( scheduled_unet_model, scheduler_id, @@ -135,30 +166,26 @@ def export_scheduled_unet_model( input_mlir=None, weights_only=False, ): - if "turbo" in hf_model_name: - do_classifier_free_guidance = False - else: - do_classifier_free_guidance = True - if ( - (attn_spec in ["default"]) - and decomp_attn == False - and ("gfx9" in iree_target_triple) - ): - attn_spec = os.path.join( - os.path.realpath(os.path.dirname(__file__)), "default_mfma_attn_spec.mlir" - ) - elif decomp_attn: - attn_spec = None - + # if "turbo" in hf_model_name: + # do_classifier_free_guidance = False + # else: + # do_classifier_free_guidance = True + do_classifier_free_guidance = True + filename_keys = [ + f"bs{batch_size}", + str(max_length), + f"{height}x{width}", + precision, + scheduler_id, + "DiffusionModule", + str(num_inference_steps), + ] + safe_name = utils.create_safe_name( + hf_model_name, + "_".join(filename_keys), + ) if pipeline_dir: - safe_name = os.path.join( - pipeline_dir, f"{scheduler_id}_unet_{str(num_inference_steps)}" - ) - else: - safe_name = utils.create_safe_name( - hf_model_name, - f"_{max_length}_{height}x{width}_{precision}_scheduled_unet_{device}", - ) + safe_name = os.path.join(pipeline_dir, safe_name) if input_mlir: vmfb_path = utils.compile_to_vmfb( @@ -166,85 +193,90 @@ def export_scheduled_unet_model( device, iree_target_triple, ireec_flags, - safe_name, + safe_name + "_" + iree_target_triple, mlir_source="file", return_path=not exit_on_vmfb, attn_spec=attn_spec, ) return vmfb_path - mapper = {} - - decomp_list = DEFAULT_DECOMPOSITIONS - if decomp_attn == True: - decomp_list.extend( - [ - torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, - torch.ops.aten._scaled_dot_product_flash_attention.default, - ] - ) - dtype = torch.float16 if precision == "fp16" else torch.float32 if precision == "fp16": 
scheduled_unet_model = scheduled_unet_model.half() + mapper = {} utils.save_external_weights( mapper, scheduled_unet_model, external_weights, external_weight_path ) - if weights_only: return external_weight_path - sample = ( + if do_classifier_free_guidance: + init_batch_dim = 2 + else: + init_batch_dim = 1 + + sample_shape = [ batch_size, scheduled_unet_model.unet.config.in_channels, height // 8, width // 8, + ] + time_ids_shape = [init_batch_dim * batch_size, 6] + prompt_embeds_shape = [init_batch_dim * batch_size, max_length, 2048] + text_embeds_shape = [init_batch_dim * batch_size, 1280] + + fxb = FxProgramsBuilder(scheduled_unet_model) + + example_init_args = [torch.empty(sample_shape, dtype=dtype)] + example_forward_args = [ + torch.empty(sample_shape, dtype=dtype), + torch.empty(prompt_embeds_shape, dtype=dtype), + torch.empty(text_embeds_shape, dtype=dtype), + torch.empty(time_ids_shape, dtype=dtype), + torch.empty(1, dtype=dtype), # guidance_scale + torch.empty(1, dtype=torch.int64), # timestep + ] + + @fxb.export_program( + args=(example_init_args,), ) - if do_classifier_free_guidance: - init_batch_dim = 2 - else: - init_batch_dim = 1 + def _initialize(module, sample): + return module.initialize(*sample) + + @fxb.export_program( + args=(example_forward_args,), + ) + def _forward( + module, + inputs, + ): + return module.forward(*inputs) + + decomp_list = [] + if decomp_attn == True: + decomp_list.extend( + [ + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten._scaled_dot_product_flash_attention.default, + ] + ) + with decompositions.extend_aot_decompositions( + from_current=True, + add_ops=decomp_list, + ): - time_ids_shape = (init_batch_dim * batch_size, 6) - prompt_embeds_shape = (init_batch_dim * batch_size, max_length, 2048) - text_embeds_shape = (init_batch_dim * batch_size, 1280) + class CompiledScheduledUnet(CompiledModule): + run_initialize = _initialize + run_forward = _forward - class 
CompiledScheduledUnet(CompiledModule): if external_weights: - params = export_parameters( - scheduled_unet_model, - external=True, - external_scope="", - name_mapper=mapper.get, - ) - else: - params = export_parameters(scheduled_unet_model) - - def run_initialize( - self, - sample=AbstractTensor(*sample, dtype=dtype), - ): - return jittable(scheduled_unet_model.initialize)(sample) - - def run_forward( - self, - sample=AbstractTensor(*sample, dtype=dtype), - prompt_embeds=AbstractTensor(*prompt_embeds_shape, dtype=dtype), - text_embeds=AbstractTensor(*text_embeds_shape, dtype=dtype), - time_ids=AbstractTensor(*time_ids_shape, dtype=dtype), - guidance_scale=AbstractTensor(1, dtype=dtype), - step_index=AbstractTensor(1, dtype=torch.int64), - ): - return jittable(scheduled_unet_model.forward, decompose_ops=decomp_list)( - sample, prompt_embeds, text_embeds, time_ids, guidance_scale, step_index - ) + externalize_module_parameters(scheduled_unet_model) - import_to = "INPUT" if compile_to == "linalg" else "IMPORT" - inst = CompiledScheduledUnet(context=Context(), import_to=import_to) + inst = CompiledScheduledUnet(context=Context(), import_to="IMPORT") - module_str = str(CompiledModule.get_mlir_module(inst)) + module_str = str(CompiledModule.get_mlir_module(inst)) if compile_to != "vmfb": return module_str @@ -254,7 +286,7 @@ def run_forward( device, iree_target_triple, ireec_flags, - safe_name, + safe_name + "_" + iree_target_triple, return_path=True, attn_spec=attn_spec, ) @@ -264,31 +296,43 @@ def run_forward( def export_pipeline_module(args): - pipeline_file = ( - "sdxl_sched_unet_bench_" + "f32" - if args.precision == "fp32" - else "sdxl_sched_unet_bench_" + "f16" + from turbine_models.custom_models.sdxl_inference.pipeline_ir import get_pipeline_ir + + pipeline_file = get_pipeline_ir( + args.width, + args.height, + args.precision, + args.batch_size, + args.max_length, + "unet_loop", ) - if "turbo" in args.hf_model_name: - pipe_prefix = "sdxl_turbo_pipeline_bench_" - 
else: - pipe_prefix = "sdxl_pipeline_bench_" - full_pipeline_file = ( - pipe_prefix + "f32" if args.precision == "fp32" else pipe_prefix + "f16" + pipeline_vmfb = utils.compile_to_vmfb( + pipeline_file, + args.device, + args.iree_target_triple, + None, + os.path.join(args.pipeline_dir, "pipeline"), + return_path=True, + mlir_source="str", + ) + full_pipeline_file = get_pipeline_ir( + args.width, + args.height, + args.precision, + args.batch_size, + args.max_length, + "tokens_to_image", ) - full_pipeline_vmfb_path = utils.compile_to_vmfb( - os.path.join( - os.path.realpath(os.path.dirname(__file__)), full_pipeline_file + ".mlir" - ), + full_pipeline_vmfb = utils.compile_to_vmfb( + pipeline_file, args.device, args.iree_target_triple, - args.ireec_flags, - "sdxl_full_pipeline_" + args.precision + "_" + args.iree_target_triple, + None, + os.path.join(args.pipeline_dir, "pipeline"), return_path=True, - const_expr_hoisting=False, - mlir_source="file", + mlir_source="str", ) - return full_pipeline_vmfb_path + return full_pipeline_vmfb if __name__ == "__main__": @@ -308,7 +352,7 @@ def export_pipeline_module(args): args.num_inference_steps, args.return_index, ) - if args.compile_to == "vmfb": + if args.compile_to == "vmfb" and args.pipeline_dir is not None: pipeline_vmfb_path = export_pipeline_module(args) mod_str = export_scheduled_unet_model( scheduled_unet_model, @@ -337,7 +381,7 @@ def export_pipeline_module(args): exit() safe_name = utils.create_safe_name( args.hf_model_name + "_" + args.scheduler_id, - f"_{args.max_length}_{args.height}x{args.width}_{args.precision}_unet_{str(args.num_inference_steps)}", + f"_bs{args.batch_size}_{args.max_length}_{args.height}x{args.width}_{args.precision}_unet_{str(args.num_inference_steps)}", ) with open(f"{safe_name}.mlir", "w+") as f: f.write(mod_str) diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet_runner.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet_runner.py 
index 8945d274a..5e90596d9 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet_runner.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet_runner.py @@ -1,155 +1,25 @@ import argparse from turbine_models.model_runner import vmfbRunner -from turbine_models.custom_models.sd_inference import utils +from turbine_models.custom_models.sd_inference import utils, schedulers from iree import runtime as ireert import torch import numpy as np from tqdm.auto import tqdm +from shark_turbine.ops.iree import trace_tensor torch.random.manual_seed(0) -def run_unet_hybrid( - sample, - prompt_embeds, - text_embeds, - args, -): - runner = vmfbRunner(args.device, args.vmfb_path, args.external_weight_path) - init_inp = [ - ireert.asdevicearray(runner.config.device, sample), - ] - sample, time_ids, steps = runner.ctx.modules.compiled_scheduled_unet[ - "run_initialize" - ]( - *init_inp, - ) - dtype = "float16" if args.precision == "fp16" else "float32" - inputs = [ - sample, - ireert.asdevicearray(runner.config.device, prompt_embeds), - ireert.asdevicearray(runner.config.device, text_embeds), - time_ids, - ireert.asdevicearray( - runner.config.device, np.asarray([args.guidance_scale]), dtype=dtype - ), - None, - ] - for i in range(steps.to_host()): - inputs[0] = sample - inputs[5] = ireert.asdevicearray( - runner.config.device, torch.tensor([i]), dtype="int64" - ) - sample = runner.ctx.modules.compiled_scheduled_unet["run_forward"](*inputs) - return sample - - +@torch.no_grad() def run_torch_scheduled_unet( sample, prompt_embeds, text_embeds, args, ): - from diffusers import UNet2DConditionModel - - class SDXLScheduledUnet(torch.nn.Module): - def __init__( - self, - hf_model_name, - scheduler_id, - height, - width, - batch_size, - hf_auth_token=None, - precision="fp32", - num_inference_steps=1, - return_index=False, - ): - super().__init__() - self.dtype = torch.float16 if precision == "fp16" else torch.float32 - self.scheduler = 
utils.get_schedulers(hf_model_name)[scheduler_id] - self.scheduler.set_timesteps(num_inference_steps) - self.scheduler.is_scale_input_called = True - self.return_index = return_index - - if precision == "fp16": - try: - self.unet = UNet2DConditionModel.from_pretrained( - hf_model_name, - subfolder="unet", - auth_token=hf_auth_token, - low_cpu_mem_usage=False, - variant="fp16", - ) - except: - self.unet = UNet2DConditionModel.from_pretrained( - hf_model_name, - subfolder="unet", - auth_token=hf_auth_token, - low_cpu_mem_usage=False, - ) - else: - self.unet = UNet2DConditionModel.from_pretrained( - hf_model_name, - subfolder="unet", - auth_token=hf_auth_token, - low_cpu_mem_usage=False, - ) - - def initialize(self, sample): - height = sample.shape[-2] * 8 - width = sample.shape[-1] * 8 - original_size = (height, width) - target_size = (height, width) - crops_coords_top_left = (0, 0) - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_time_ids = torch.tensor([add_time_ids]) - add_time_ids = torch.cat([add_time_ids] * 2, dim=0) - add_time_ids = add_time_ids.repeat(sample.shape[0], 1).type(self.dtype) - timesteps = self.scheduler.timesteps - step_indexes = torch.tensor(len(timesteps)) - sample = sample * self.scheduler.init_noise_sigma - return sample.type(self.dtype), add_time_ids, step_indexes - - def forward( - self, - sample, - prompt_embeds, - text_embeds, - time_ids, - guidance_scale, - step_index, - ): - with torch.no_grad(): - added_cond_kwargs = { - "text_embeds": text_embeds, - "time_ids": time_ids, - } - t = self.scheduler.timesteps[step_index] - latent_model_input = torch.cat([sample] * 2) - latent_model_input = self.scheduler.scale_model_input( - latent_model_input, t - ) - noise_pred = self.unet.forward( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=None, - added_cond_kwargs=added_cond_kwargs, - return_dict=False, - )[0] - - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - 
noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - sample = self.scheduler.step(noise_pred, t, sample, return_dict=False)[ - 0 - ] - if self.return_index: - return sample.type(self.dtype), step_index - else: - return sample.type(self.dtype) + from turbine_models.custom_models.sdxl_inference.sdxl_scheduled_unet import ( + SDXLScheduledUnet, + ) unet_model = SDXLScheduledUnet( args.hf_model_name, @@ -158,9 +28,9 @@ def forward( args.width, args.batch_size, args.hf_auth_token, - args.precision, + "fp32", args.num_inference_steps, - ) + ).float() sample, add_time_ids, steps = unet_model.initialize(sample) for i in range(steps): sample = unet_model.forward( @@ -168,13 +38,13 @@ def forward( prompt_embeds.float(), text_embeds.float(), add_time_ids.float(), - args.guidance_scale, + torch.tensor(args.guidance_scale, dtype=torch.float32), i, ) return sample -def run_scheduled_unet( +def run_scheduled_unet_compiled( sample, prompt_embeds, text_embeds, @@ -194,7 +64,6 @@ def run_scheduled_unet( pipe_runner.config.device, np.asarray([args.guidance_scale]), dtype=dtype ), ] - print(inputs) latents = pipe_runner.ctx.modules.sdxl_compiled_pipeline["produce_image_latents"]( *inputs, ) @@ -202,6 +71,144 @@ def run_scheduled_unet( return latents +def run_scheduled_unet_initialize( + sample, + unet_runner, + args, +): + inputs = [ + ireert.asdevicearray(unet_runner.config.device, sample), + ] + sample, time_ids, steps = unet_runner.ctx.modules.compiled_scheduled_unet[ + "run_initialize" + ]( + *inputs, + ) + return sample, time_ids, steps + + +def run_scheduled_unet_forward( + inputs, + unet_runner, + args, +): + sample = unet_runner.ctx.modules.compiled_scheduled_unet["run_forward"](*inputs) + return sample + + +def run_scheduled_unet_python( + sample, + prompt_embeds, + text_embeds, + args, +): + unet_runner = vmfbRunner( + args.device, + args.vmfb_path, + args.external_weight_path, + ) + dtype = "float16" if args.precision == "fp16" 
else "float32" + sample, time_ids, steps = run_scheduled_unet_initialize( + sample, + unet_runner, + args, + ) + iree_inputs = [ + sample, + ireert.asdevicearray(unet_runner.config.device, prompt_embeds), + ireert.asdevicearray(unet_runner.config.device, text_embeds), + time_ids, + ireert.asdevicearray( + unet_runner.config.device, np.asarray([args.guidance_scale]), dtype=dtype + ), + None, + ] + for i in range(steps.to_host()): + iree_inputs[0] = sample + iree_inputs[5] = ireert.asdevicearray( + unet_runner.config.device, torch.tensor([i]), dtype="int64" + ) + sample = run_scheduled_unet_forward( + iree_inputs, + unet_runner, + args, + ) + return sample + + +def run_unet_split_scheduled( + sample, + prompt_embeds, + text_embeds, + args, +): + dtype = "float16" if args.precision == "fp16" else "float32" + torch_dtype = torch.float16 if args.precision == "fp16" else torch.float32 + unet_runner = vmfbRunner( + args.device, + args.vmfb_path, + args.external_weight_path, + ) + if not args.scheduler_vmfb_path: + print("--scheduler_vmfb_path not supplied. 
Using cpu scheduling.") + scheduler = schedulers.get_scheduler(args.hf_model_name, args.scheduler_id) + scheduler = schedulers.SharkSchedulerCPUWrapper( + scheduler, + args.batch_size, + args.num_inference_steps, + unet_runner.config.device, + dtype, + ) + guidance_scale = torch.tensor([args.guidance_scale]) + else: + scheduler = schedulers.SharkSchedulerWrapper( + args.device, + args.scheduler_vmfb_path, + ) + guidance_scale = ireert.asdevicearray( + scheduler.runner.config.device, + np.asarray([args.guidance_scale]), + dtype=dtype, + ) + sample, time_ids, steps, timesteps = scheduler.initialize(sample) + iree_inputs = [ + sample, + ireert.asdevicearray(unet_runner.config.device, prompt_embeds), + ireert.asdevicearray(unet_runner.config.device, text_embeds), + time_ids, + None, + ] + for i in range(steps.to_host()): + # print(f"step {i}") + if args.scheduler_vmfb_path: + step_index = ireert.asdevicearray( + unet_runner.config.device, torch.tensor([i]), dtype="int64" + ) + else: + step_index = i + latents, t = scheduler.scale_model_input( + sample, + step_index, + timesteps, + ) + noise_pred = unet_runner.ctx.modules.compiled_unet["run_forward"]( + latents, + t, + iree_inputs[1], + iree_inputs[2], + iree_inputs[3], + ) + sample = scheduler.step( + noise_pred, + t, + sample, + guidance_scale, + step_index, + ) + return sample + + +@torch.no_grad() def run_torch_diffusers_loop( sample, prompt_embeds, @@ -215,40 +222,48 @@ def run_torch_diffusers_loop( args.hf_auth_token, precision="fp32", ) - scheduler = utils.get_schedulers(args.hf_model_name)[args.scheduler_id] - + scheduler = schedulers.get_scheduler(args.hf_model_name, args.scheduler_id) + if args.scheduler_id == "PNDM": + scheduler.config.skip_prk_steps = True scheduler.set_timesteps(args.num_inference_steps) - scheduler.is_scale_input_called = True + timesteps = scheduler.timesteps + print(timesteps) sample = sample * scheduler.init_noise_sigma - height = sample.shape[-2] * 8 - width = sample.shape[-1] * 8 + 
height = args.height + width = args.width original_size = (height, width) target_size = (height, width) crops_coords_top_left = (0, 0) add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_time_ids = torch.tensor([add_time_ids, add_time_ids], dtype=torch.float32) + add_time_ids = torch.tensor([add_time_ids], dtype=torch.float32) + add_time_ids = torch.cat([add_time_ids] * 2, dim=0) add_time_ids = add_time_ids.repeat(args.batch_size * 1, 1) sample = sample.to(torch.float32) prompt_embeds = prompt_embeds.to(torch.float32) text_embeds = text_embeds.to(torch.float32) - for i in range(args.num_inference_steps): - timestep = scheduler.timesteps[i] - - latent_model_input = scheduler.scale_model_input(sample, timestep) + for idx, t in enumerate(timesteps): + print(t) + latent_model_input = torch.cat([sample] * 2) + latent_model_input = scheduler.scale_model_input(latent_model_input, t) noise_pred = unet_model.forward( latent_model_input, - timestep, + t, prompt_embeds, text_embeds, add_time_ids, - args.guidance_scale, + ) + # print("NOISE_PRED: ", noise_pred) + # print("STEP_INDEX : ", idx) + noise_preds = noise_pred.chunk(2) + noise_pred = noise_preds[0] + args.guidance_scale * ( + noise_preds[1] - noise_preds[0] ) sample = scheduler.step( noise_pred, - timestep, + t, sample, return_dict=False, )[0] @@ -263,96 +278,92 @@ def run_torch_diffusers_loop( dtype = torch.float16 else: dtype = torch.float32 + + init_batch_dim = 2 sample = torch.rand( args.batch_size, 4, args.height // 8, args.width // 8, dtype=dtype ) timestep = torch.zeros(1, dtype=torch.int64) - prompt_embeds = torch.rand(2 * args.batch_size, args.max_length, 2048, dtype=dtype) - text_embeds = torch.rand(2 * args.batch_size, 1280, dtype=dtype) - - turbine_output = run_scheduled_unet( - sample, - prompt_embeds, - text_embeds, - args, - ) - print( - "TURBINE OUTPUT:", - turbine_output.to_host(), - turbine_output.to_host().shape, - turbine_output.to_host().dtype, + prompt_embeds = 
torch.rand( + init_batch_dim * args.batch_size, args.max_length, 2048, dtype=dtype ) - - if args.compare_vs_torch: - from turbine_models.custom_models.sd_inference import utils - - print("generating output with python/torch scheduling unet: ") - hybrid_output = run_unet_hybrid( + text_embeds = torch.rand(init_batch_dim * args.batch_size, 1280, dtype=dtype) + time_ids = torch.rand(init_batch_dim * args.batch_size, 6, dtype=dtype) + if args.compiled_pipeline: + assert ( + args.pipeline_vmfb_path is not None + ), "--pipeline_vmfb_path is required for compiled pipeline run" + turbine_compiled_output = run_scheduled_unet_compiled( sample, prompt_embeds, text_embeds, args, + ).to_host() + print( + "TURBINE COMPILED OUTPUT:", + turbine_compiled_output, + turbine_compiled_output.shape, + turbine_compiled_output.dtype, ) - print("generating torch output: ") - torch_output = run_torch_scheduled_unet( + turbine_output = turbine_compiled_output + elif args.split_scheduler: + turbine_split_output = run_unet_split_scheduled( sample, prompt_embeds, text_embeds, args, ) - print("generating torch+diffusers output: ") - diff_output = run_torch_diffusers_loop( + if args.scheduler_vmfb_path: + turbine_split_output = turbine_split_output.to_host() + print( + "TURBINE SPLIT OUTPUT:", + turbine_split_output, + turbine_split_output.shape, + turbine_split_output.dtype, + ) + turbine_output = turbine_split_output + else: + turbine_python_output = run_scheduled_unet_python( sample, prompt_embeds, text_embeds, args, - ) + ).to_host() print( - "diffusers-like OUTPUT:", diff_output, diff_output.shape, diff_output.dtype + "TURBINE PYTHON OUTPUT:", + turbine_python_output, + turbine_python_output.shape, + turbine_python_output.dtype, ) - print("torch OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) + turbine_output = turbine_python_output - print( - "HYBRID OUTPUT:", - hybrid_output.to_host(), - hybrid_output.to_host().shape, - hybrid_output.to_host().dtype, - ) - 
print("Comparing... \n(turbine pipelined unet to torch unet): ") - try: - np.testing.assert_allclose( - turbine_output, torch_output, rtol=4e-2, atol=4e-2 + if args.compare_vs_torch: + if args.scheduler_id == "EulerAncestralDiscrete" and args.scheduler_vmfb_path: + print( + f"WARNING: {args.scheduler_id} scheduler adds random noise to results and we haven't piped through a torch generator yet to fix the seed. Expect mismatch results." ) - except AssertionError as err: - print(err) - print("\n(turbine pipelined unet to hybrid unet): ") - try: - np.testing.assert_allclose( - hybrid_output, turbine_output, rtol=4e-2, atol=4e-2 + if args.scheduler_id == "PNDM" and args.scheduler_vmfb_path: + print( + f"WARNING: {args.scheduler_id} scheduler normally uses data-dependent control flow with counters and other data dependence. Expect different results after 1 step." ) - print("passed!") - except AssertionError as err: - print(err) - print("\n(hybrid unet to diff unet): ") - try: - np.testing.assert_allclose(diff_output, hybrid_output, rtol=4e-2, atol=4e-2) - print("passed!") - except AssertionError as err: - print(err) - print("\n(turbine loop to diffusers loop): ") + print("generating torch output: ") + torch_output = run_torch_diffusers_loop( + sample, + prompt_embeds, + text_embeds, + args, + ) + print("torch OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) + + print("\n(torch (diffusers) image latents to iree image latents): ") try: np.testing.assert_allclose( - turbine_output, diff_output, rtol=4e-2, atol=4e-2 + turbine_output, torch_output, rtol=4e-2, atol=4e-2 ) print("passed!") except AssertionError as err: + if args.scheduler_id == "EulerAncestralDiscrete": + print( + "Expected failure matching numerics due to intentionally random noise in results." 
+ ) print(err) - print("\n(torch sched unet loop to diffusers loop): ") - try: - np.testing.assert_allclose(torch_output, diff_output, rtol=4e-2, atol=4e-2) - print("passed!") - except AssertionError as err: - print(err) - - # TODO: Figure out why we occasionally segfault without unlinking output variables - turbine_output = None diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_schedulers.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_schedulers.py deleted file mode 100644 index a3ae29595..000000000 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_schedulers.py +++ /dev/null @@ -1,197 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# from @aviator19941's gist : https://gist.github.com/aviator19941/4e7967bd1787c83ee389a22637c6eea7 - -import os -import sys - -from iree import runtime as ireert -from iree.compiler.ir import Context -import numpy as np -from shark_turbine.aot import * -from turbine_models.custom_models.sd_inference import utils -import torch -import torch._dynamo as dynamo -from diffusers import UNet2DConditionModel -from shark_turbine.dynamo.passes import ( - DEFAULT_DECOMPOSITIONS, -) - -import safetensors - - -class SDXLScheduler(torch.nn.Module): - def __init__( - self, - hf_model_name, - num_inference_steps, - scheduler, - hf_auth_token=None, - precision="fp32", - ): - super().__init__() - self.scheduler = scheduler - self.scheduler.set_timesteps(num_inference_steps) - self.guidance_scale = 7.5 - if precision == "fp16": - try: - self.unet = UNet2DConditionModel.from_pretrained( - hf_model_name, - subfolder="unet", - auth_token=hf_auth_token, - low_cpu_mem_usage=False, - variant="fp16", - ) - except: - self.unet = UNet2DConditionModel.from_pretrained( - hf_model_name, - subfolder="unet", - auth_token=hf_auth_token, - 
low_cpu_mem_usage=False, - ) - else: - self.unet = UNet2DConditionModel.from_pretrained( - hf_model_name, - subfolder="unet", - auth_token=hf_auth_token, - low_cpu_mem_usage=False, - ) - - def forward(self, sample, prompt_embeds, text_embeds, time_ids): - sample = sample * self.scheduler.init_noise_sigma - for t in self.scheduler.timesteps: - with torch.no_grad(): - added_cond_kwargs = { - "text_embeds": text_embeds, - "time_ids": time_ids, - } - latent_model_input = torch.cat([sample] * 2) - t = t.unsqueeze(0) - # print('UNSQUEEZE T:', t) - latent_model_input = self.scheduler.scale_model_input( - latent_model_input, timestep=t - ) - noise_pred = self.unet.forward( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=None, - added_cond_kwargs=added_cond_kwargs, - return_dict=False, - )[0] - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - sample = self.scheduler.step(noise_pred, t, sample, return_dict=False)[ - 0 - ] - return sample - - -def export_scheduler( - scheduler, - hf_model_name, - batch_size, - height, - width, - hf_auth_token=None, - compile_to="torch", - external_weights=None, - external_weight_path=None, - device=None, - target_triple=None, - ireec_flags=None, -): - mapper = {} - utils.save_external_weights( - mapper, scheduler, external_weights, external_weight_path - ) - - decomp_list = DEFAULT_DECOMPOSITIONS - - decomp_list.extend( - [ - torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, - torch.ops.aten._scaled_dot_product_flash_attention.default, - ] - ) - # tensor shapes for tracing - sample = (batch_size, 4, height // 8, width // 8) - prompt_embeds = (2, 77, 2048) - text_embeds = (2, 1280) - time_ids = (2, 6) - - class CompiledScheduler(CompiledModule): - if external_weights: - params = export_parameters( - scheduler, external=True, external_scope="", name_mapper=mapper.get - ) - else: - 
params = export_parameters(scheduler) - - def main( - self, - sample=AbstractTensor(*sample, dtype=torch.float32), - prompt_embeds=AbstractTensor(*prompt_embeds, dtype=torch.float32), - text_embeds=AbstractTensor(*text_embeds, dtype=torch.float32), - time_ids=AbstractTensor(*time_ids, dtype=torch.float32), - ): - return jittable(scheduler.forward, decompose_ops=decomp_list)( - sample, prompt_embeds, text_embeds, time_ids - ) - - import_to = "INPUT" if compile_to == "linalg" else "IMPORT" - inst = CompiledScheduler(context=Context(), import_to=import_to) - - module_str = str(CompiledModule.get_mlir_module(inst)) - - safe_name = utils.create_safe_name(hf_model_name, "-scheduler") - with open(f"{safe_name}.mlir", "w+") as f: - f.write(module_str) - print("Saved to", safe_name + ".mlir") - - if compile_to != "vmfb": - return module_str - else: - utils.compile_to_vmfb(module_str, device, target_triple, ireec_flags, safe_name) - - -if __name__ == "__main__": - from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args - - hf_model_name = "stabilityai/stable-diffusion-xl-base-1.0" - schedulers = utils.get_schedulers(args.hf_model_name) - scheduler = schedulers[args.scheduler_id] - scheduler_module = SDXLScheduler( - args.hf_model_name, - args.num_inference_steps, - scheduler, - hf_auth_token=None, - precision=args.precision, - ) - - print("export scheduler begin") - mod_str = export_scheduler( - scheduler_module, - args.hf_model_name, - args.batch_size, - args.height, - args.width, - args.hf_auth_token, - args.compile_to, - args.external_weights, - args.external_weight_path, - args.device, - args.iree_target_triple, - args.ireec_flags, - ) - print("export scheduler complete") - safe_name = utils.create_safe_name(args.hf_model_name, "-scheduler") - with open(f"{safe_name}.mlir", "w+") as f: - f.write(mod_str) - print("Saved to", safe_name + ".mlir") diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py 
b/models/turbine_models/custom_models/sdxl_inference/unet.py index e9839ba06..bd36db763 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -4,24 +4,27 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import copy import os import sys +import safetensors from iree import runtime as ireert from iree.compiler.ir import Context import numpy as np from shark_turbine.aot import * -from shark_turbine.dynamo.passes import ( - DEFAULT_DECOMPOSITIONS, -) +from shark_turbine.transforms.general.add_metadata import AddMetadataPass + + from turbine_models.custom_models.sd_inference import utils import torch -import torch._dynamo as dynamo -from diffusers import UNet2DConditionModel +from huggingface_hub import hf_hub_download class UnetModel(torch.nn.Module): def __init__(self, hf_model_name, hf_auth_token=None, precision="fp32"): + from diffusers import UNet2DConditionModel + super().__init__() if precision == "fp16": try: @@ -46,41 +49,114 @@ def __init__(self, hf_model_name, hf_auth_token=None, precision="fp32"): auth_token=hf_auth_token, low_cpu_mem_usage=False, ) - if "turbo" in hf_model_name: - self.do_classifier_free_guidance = False - else: - self.do_classifier_free_guidance = True + self.do_classifier_free_guidance = True def forward( - self, sample, timestep, prompt_embeds, text_embeds, time_ids, guidance_scale + self, + latent_model_input, + timestep, + prompt_embeds, + text_embeds, + time_ids, + guidance_scale, ): - with torch.no_grad(): - added_cond_kwargs = { - "text_embeds": text_embeds, - "time_ids": time_ids, - } - if self.do_classifier_free_guidance: - latent_model_input = torch.cat([sample] * 2) - else: - latent_model_input = sample - noise_pred = self.unet.forward( - latent_model_input, - timestep, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=None, - 
added_cond_kwargs=added_cond_kwargs, - return_dict=False, - )[0] - if self.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) + added_cond_kwargs = { + "text_embeds": text_embeds, + "time_ids": time_ids, + } + if self.do_classifier_free_guidance: + latent_model_input = torch.cat([latent_model_input] * 2) + noise_pred = self.unet.forward( + latent_model_input, + timestep, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=None, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + if self.do_classifier_free_guidance: + noise_preds = noise_pred.chunk(2) + noise_pred = noise_preds[0] + guidance_scale * ( + noise_preds[1] - noise_preds[0] + ) return noise_pred +def get_punet_model(hf_model_name, external_weight_path, precision="i8"): + from sharktank.models.punet.model import ( + Unet2DConditionModel as sharktank_unet2d, + ClassifierFreeGuidanceUnetModel as sharktank_CFGPunetModel, + ) + from sharktank.utils import cli + + if precision == "i8": + repo_id = "amd-shark/sdxl-quant-models" + subfolder = "unet/int8" + revision = "942e771bf0c2657a8b33380103d04747a75dfa4a" + elif precision in ["fp16", "fp32"]: + repo_id = hf_model_name + subfolder = "unet" + revision = "76d28af79639c28a79fa5c6c6468febd3490a37e" + + def download(filename): + return hf_hub_download( + repo_id=repo_id, subfolder=subfolder, filename=filename, revision=revision + ) + + results = { + "config.json": download("config.json"), + "params.safetensors": download("params.safetensors"), + } + output_dir = os.path.dirname(external_weight_path) + + if precision == "i8": + results["quant_params.json"] = download("quant_params.json") + ds_filename = os.path.basename(external_weight_path) + output_path = os.path.join(output_dir, ds_filename) + ds = get_punet_dataset( + results["config.json"], + results["params.safetensors"], + output_path, + 
results["quant_params.json"], + ) + else: + ds_filename = ( + os.path.basename(external_weight_path).split("unet")[0] + + f"punet_dataset_{precision}.irpa" + ) + output_path = os.path.join(output_dir, ds_filename) + ds = get_punet_dataset( + results["config.json"], + results["params.safetensors"], + output_path, + ) + + cond_unet = sharktank_unet2d.from_dataset(ds) + mdl = sharktank_CFGPunetModel(cond_unet) + return mdl + + +def get_punet_dataset( + config_json_path, + params_path, + output_path, + quant_params_path=None, +): + from sharktank.models.punet.tools import import_brevitas_dataset + + ds_import_args = [ + f"--config-json={config_json_path}", + f"--params={params_path}", + f"--output-irpa-file={output_path}", + ] + if quant_params_path: + ds_import_args.extend([f"--quant-params={quant_params_path}"]) + import_brevitas_dataset.main(ds_import_args) + return import_brevitas_dataset.Dataset.load(output_path) + + +@torch.no_grad() def export_unet_model( - unet_model, hf_model_name, batch_size, height, @@ -92,39 +168,41 @@ def export_unet_model( external_weights=None, external_weight_path=None, device=None, - target_triple=None, + target=None, ireec_flags=None, decomp_attn=False, exit_on_vmfb=False, + pipeline_dir=None, attn_spec=None, input_mlir=None, weights_only=False, + use_punet=False, ): - if "turbo" in hf_model_name: - do_classifier_free_guidance = False + if use_punet: + submodel_name = "punet" else: - do_classifier_free_guidance = True - - if ( - (attn_spec in ["default"]) - and decomp_attn == False - and ("gfx9" in target_triple) - ): - attn_spec = os.path.join( - os.path.realpath(os.path.dirname(__file__)), "default_mfma_attn_spec.mlir" - ) - elif decomp_attn: - attn_spec = None - + submodel_name = "unet" + if (not decomp_attn) and use_punet: + attn_spec = "punet" + elif (not decomp_attn) and "gfx9" in target: + attn_spec = "mfma" + elif (not decomp_attn) and "gfx11" in target: + attn_spec = "wmma" safe_name = utils.create_safe_name( - hf_model_name, 
f"_{max_length}_{height}x{width}_{precision}_unet_{device}" + hf_model_name, + f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_{submodel_name}", ) + if pipeline_dir: + safe_name = os.path.join(pipeline_dir, safe_name) + + if decomp_attn == True: + ireec_flags += ",--iree-opt-aggressively-propagate-transposes=False" if input_mlir: vmfb_path = utils.compile_to_vmfb( input_mlir, device, - target_triple, + target, ireec_flags, safe_name, mlir_source="file", @@ -132,81 +210,140 @@ def export_unet_model( attn_spec=attn_spec, ) return vmfb_path + elif use_punet: + unet_model = get_punet_model(hf_model_name, external_weight_path, precision) + else: + unet_model = UnetModel(hf_model_name, hf_auth_token, precision) mapper = {} - decomp_list = DEFAULT_DECOMPOSITIONS - if decomp_attn == True: - decomp_list.extend( - [ - torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, - torch.ops.aten._scaled_dot_product_flash_attention.default, - ] - ) - dtype = torch.float16 if precision == "fp16" else torch.float32 + np_dtypes = { + "fp16": "float16", + "fp32": "float32", + "i8": "int8", + } + torch_dtypes = { + "fp16": torch.float16, + "fp32": torch.float32, + "i8": torch.int8, + } + dtype = torch_dtypes[precision] + np_dtype = np_dtypes[precision] - if precision == "fp16": + if precision == "fp16" and not use_punet: unet_model = unet_model.half() - utils.save_external_weights( - mapper, unet_model, external_weights, external_weight_path - ) + if use_punet: + dtype = torch.float16 + + if not use_punet: + utils.save_external_weights( + mapper, unet_model, external_weights, external_weight_path + ) if weights_only: return external_weight_path - sample = ( + do_classifier_free_guidance = True + init_batch_dim = 2 if do_classifier_free_guidance else 1 + + sample = [ batch_size, - unet_model.unet.config.in_channels, + 4, height // 8, width // 8, - ) - if do_classifier_free_guidance: - init_batch_dim = 2 - else: - init_batch_dim = 1 + ] time_ids_shape = (init_batch_dim * 
batch_size, 6) prompt_embeds_shape = (init_batch_dim * batch_size, max_length, 2048) text_embeds_shape = (init_batch_dim * batch_size, 1280) - - class CompiledUnet(CompiledModule): - if external_weights: - params = export_parameters( - unet_model, external=True, external_scope="", name_mapper=mapper.get + example_forward_args = [ + torch.empty(sample, dtype=dtype), + torch.empty(1, dtype=dtype), + torch.empty(prompt_embeds_shape, dtype=dtype), + torch.empty(text_embeds_shape, dtype=dtype), + torch.empty(time_ids_shape, dtype=dtype), + torch.tensor([7.5], dtype=dtype), + ] + example_forward_args_dict = { + "sample": torch.rand(sample, dtype=dtype), + "timestep": torch.zeros(1, dtype=dtype), + "encoder_hidden_states": torch.rand(prompt_embeds_shape, dtype=dtype), + "text_embeds": torch.rand(text_embeds_shape, dtype=dtype), + "time_ids": torch.zeros(time_ids_shape, dtype=dtype), + "guidance_scale": torch.tensor([7.5], dtype=dtype), + } + decomp_list = [] + if decomp_attn == True: + decomp_list = [ + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten.scaled_dot_product_attention, + ] + with decompositions.extend_aot_decompositions( + from_current=True, + add_ops=decomp_list, + ): + if use_punet: + output = export( + unet_model, + kwargs=example_forward_args_dict, + module_name="compiled_punet", ) + module = output.mlir_module else: - params = export_parameters(unet_model) - - def main( - self, - sample=AbstractTensor(*sample, dtype=dtype), - timestep=AbstractTensor(1, dtype=torch.int64), - prompt_embeds=AbstractTensor(*prompt_embeds_shape, dtype=dtype), - text_embeds=AbstractTensor(*text_embeds_shape, dtype=dtype), - time_ids=AbstractTensor(*time_ids_shape, dtype=dtype), - guidance_scale=AbstractTensor(1, dtype=dtype), - ): - return jittable(unet_model.forward, decompose_ops=decomp_list)( - sample, timestep, prompt_embeds, text_embeds, time_ids, guidance_scale + if external_weights: 
+ externalize_module_parameters(unet_model) + fxb = FxProgramsBuilder(unet_model) + + @fxb.export_program( + args=(example_forward_args,), ) + def _forward( + module, + inputs, + ): + return module.forward(*inputs) + + class CompiledUnet(CompiledModule): + run_forward = _forward + + inst = CompiledUnet(context=Context(), import_to="IMPORT") - import_to = "INPUT" if compile_to == "linalg" else "IMPORT" - inst = CompiledUnet(context=Context(), import_to=import_to) + module = CompiledModule.get_mlir_module(inst) - module_str = str(CompiledModule.get_mlir_module(inst)) + model_metadata_run_forward = { + "model_name": "sd_unet", + "input_shapes": [ + sample, + (1,), + prompt_embeds_shape, + text_embeds_shape, + time_ids_shape, + (1,), + ], + "input_dtypes": [np_dtype for x in range(6)], + "output_shapes": [sample], + "output_dtypes": [np_dtype], + } + module = AddMetadataPass(module, model_metadata_run_forward, "run_forward").run() + + module_str = str(module) if compile_to != "vmfb": return module_str else: - utils.compile_to_vmfb( + vmfb_path = utils.compile_to_vmfb( module_str, device, - target_triple, + target, ireec_flags, safe_name, - return_path=False, + return_path=True, attn_spec=attn_spec, ) + if exit_on_vmfb: + exit() + return vmfb_path if __name__ == "__main__": @@ -246,8 +383,9 @@ def main( exit() safe_name = utils.create_safe_name( args.hf_model_name, - f"_{args.max_length}_{args.height}x{args.width}_{args.precision}_unet", + f"_bs{args.batch_size}_{args.max_length}_{args.height}x{args.width}_{args.precision}_{'p' if args.use_i8_punet else ''}unet", ) - with open(f"{safe_name}.mlir", "w+") as f: - f.write(mod_str) - print("Saved to", safe_name + ".mlir") + if args.compile_to != "vmfb": + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") diff --git a/models/turbine_models/custom_models/sdxl_inference/unet_runner.py b/models/turbine_models/custom_models/sdxl_inference/unet_runner.py index 
197d850a9..c474982d7 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet_runner.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet_runner.py @@ -33,7 +33,7 @@ def run_unet( ireert.asdevicearray(runner.config.device, time_ids), ireert.asdevicearray(runner.config.device, guidance_scale), ] - results = runner.ctx.modules.compiled_unet["main"](*inputs) + results = runner.ctx.modules.compiled_unet["run_forward"](*inputs) return results @@ -57,7 +57,7 @@ def run_unet_steps( ireert.asdevicearray(runner.config.device, prompt_embeds), ireert.asdevicearray(runner.config.device, text_embeds), ireert.asdevicearray(runner.config.device, time_ids), - ireert.asdevicearray(runner.config.device, (guidance_scale,)), + ireert.asdevicearray(runner.config.device, guidance_scale), ] for i, t in tqdm(enumerate(scheduler.timesteps)): timestep = t @@ -69,7 +69,7 @@ def run_unet_steps( inputs[1] = timestep = ireert.asdevicearray( runner.config.device, (timestep,), dtype="int64" ) - noise_pred = runner.ctx.modules.compiled_unet["main"](*inputs).to_host() + noise_pred = runner.ctx.modules.compiled_unet["run_forward"](*inputs).to_host() sample = scheduler.step( torch.from_numpy(noise_pred).cpu(), timestep, @@ -112,15 +112,31 @@ def run_torch_unet( dtype = torch.float16 else: dtype = torch.float32 + + save_inputs = True + sample = torch.rand( args.batch_size, 4, args.height // 8, args.width // 8, dtype=dtype ) - timestep = torch.zeros(1, dtype=torch.int64) + timestep = torch.ones(1, dtype=dtype) prompt_embeds = torch.rand(2 * args.batch_size, args.max_length, 2048, dtype=dtype) text_embeds = torch.rand(2 * args.batch_size, 1280, dtype=dtype) - time_ids = torch.zeros(2 * args.batch_size, 6, dtype=dtype) + time_ids = torch.rand(2 * args.batch_size, 6, dtype=dtype) guidance_scale = torch.tensor([7.5], dtype=dtype) + if save_inputs: + import os + + inputs_dir = "sdxl_unet_inputs_" + args.precision + if not os.path.exists(inputs_dir): + os.mkdir(inputs_dir) + 
np.save("input1.npy", sample) + np.save("input2.npy", timestep) + np.save("input3.npy", prompt_embeds) + np.save("input4.npy", text_embeds) + np.save("input5.npy", time_ids) + np.save("input6.npy", guidance_scale) + turbine_output = run_unet( args.device, sample, @@ -133,12 +149,12 @@ def run_torch_unet( args.hf_model_name, args.hf_auth_token, args.external_weight_path, - ) + ).to_host() print( "TURBINE OUTPUT:", - turbine_output.to_host(), - turbine_output.to_host().shape, - turbine_output.to_host().dtype, + turbine_output, + turbine_output.shape, + turbine_output.dtype, ) if args.compare_vs_torch: @@ -158,9 +174,8 @@ def run_torch_unet( # precision="fp16", ) print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) - err = utils.largest_error(torch_output, turbine_output) - print("Largest Error: ", err) - assert err < 9e-3 - - # TODO: Figure out why we occasionally segfault without unlinking output variables - turbine_output = None + if save_inputs: + np.save("golden_out.npy", torch_output) + atol = 4e-2 + rtol = 4e-1 + np.testing.assert_allclose(turbine_output, torch_output, atol=atol, rtol=rtol) diff --git a/models/turbine_models/custom_models/sdxl_inference/vae.py b/models/turbine_models/custom_models/sdxl_inference/vae.py index 7563eed96..753cbb9e7 100644 --- a/models/turbine_models/custom_models/sdxl_inference/vae.py +++ b/models/turbine_models/custom_models/sdxl_inference/vae.py @@ -4,6 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import copy import os import sys @@ -18,6 +19,7 @@ import torch import torch._dynamo as dynamo from diffusers import AutoencoderKL +import safetensors class VaeModel(torch.nn.Module): @@ -33,6 +35,14 @@ def __init__( hf_model_name, subfolder="vae", ) + elif "safetensors" in custom_vae: + custom_vae = safetensors.torch.load_file(custom_vae) + # custom vae as a HF state dict + self.vae = AutoencoderKL.from_pretrained( + hf_model_name, + subfolder="vae", + ) + self.vae.load_state_dict(custom_vae) elif not isinstance(custom_vae, dict): try: # custom HF repo with no vae subfolder @@ -45,20 +55,13 @@ def __init__( custom_vae, subfolder="vae", ) - else: - # custom vae as a HF state dict - self.vae = AutoencoderKL.from_pretrained( - hf_model_name, - subfolder="vae", - ) - self.vae.load_state_dict(custom_vae) - def decode_inp(self, inp): - inp = 1 / 0.13025 * inp - x = self.vae.decode(inp, return_dict=False)[0] + def decode(self, inp): + img = 1 / 0.13025 * inp + x = self.vae.decode(img, return_dict=False)[0] return (x / 2 + 0.5).clamp(0, 1) - def encode_inp(self, inp): + def encode(self, inp): latents = self.vae.encode(inp).latent_dist.sample() return 0.13025 * latents @@ -84,75 +87,89 @@ def export_vae_model( input_mlir=None, weights_only=False, ): - if ( - (attn_spec in ["default"]) - and decomp_attn == False - and ("gfx9" in target_triple) - ): - attn_spec = os.path.join( - os.path.realpath(os.path.dirname(__file__)), "default_mfma_attn_spec.mlir" - ) - elif decomp_attn: - attn_spec = None - + safe_name = utils.create_safe_name( + hf_model_name, + f"_bs{batch_size}_{height}x{width}_{precision}_vae_{variant}", + ) if pipeline_dir: - safe_name = os.path.join(pipeline_dir, "vae_" + variant) - else: - safe_name = utils.create_safe_name( - hf_model_name, f"_{height}x{width}_{precision}_vae_{variant}_{device}" - ) + safe_name = os.path.join(pipeline_dir, safe_name) + if input_mlir: vmfb_path = utils.compile_to_vmfb( 
input_mlir, device, target_triple, ireec_flags, - safe_name, + safe_name + "_" + target_triple, mlir_source="file", return_path=not exit_on_vmfb, attn_spec=attn_spec, ) return vmfb_path + # if precision == "fp32" and device == "rocm": + # decomp_attn = True + # external_weights = None + # print("Decomposing attention and inlining weights for fp32 VAE on ROCm") + if device == "cpu": + decomp_attn = True - mapper = {} - decomp_list = DEFAULT_DECOMPOSITIONS - if decomp_attn == True: - decomp_list.extend( - [ - torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, - torch.ops.aten._scaled_dot_product_flash_attention.default, - ] - ) dtype = torch.float16 if precision == "fp16" else torch.float32 if precision == "fp16": vae_model = vae_model.half() + + mapper = {} + utils.save_external_weights( mapper, vae_model, external_weights, external_weight_path ) if weights_only: return external_weight_path - sample = (batch_size, 4, height // 8, width // 8) - if variant == "encode": - sample = (batch_size, 3, height, width) - class CompiledVae(CompiledModule): - if external_weights: - params = export_parameters( - vae_model, external=True, external_scope="", name_mapper=mapper.get - ) - else: - params = export_parameters(vae_model) + input_image_shape = (height, width, 3) + input_latents_shape = (batch_size, 4, height // 8, width // 8) + encode_args = [ + torch.empty( + input_image_shape, + dtype=torch.float32, + ) + ] + decode_args = [ + torch.empty( + input_latents_shape, + dtype=dtype, + ) + ] + decomp_list = [] + if decomp_attn == True: + safe_name += "_decomp" + decomp_list = [ + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten.scaled_dot_product_attention, + ] + with decompositions.extend_aot_decompositions( + from_current=True, + add_ops=decomp_list, + ): + fxb = FxProgramsBuilder(vae_model) - def main(self, inp=AbstractTensor(*sample, dtype=dtype)): - if variant == "decode": - 
return jittable(vae_model.decode_inp, decompose_ops=decomp_list)(inp) - elif variant == "encode": - return jittable(vae_model.encode_inp, decompose_ops=decomp_list)(inp) + # @fxb.export_program(args=(encode_args,)) + # def _encode(module, inputs,): + # return module.encode(*inputs) + + @fxb.export_program(args=(decode_args,)) + def _decode(module, inputs): + return module.decode(*inputs) + + class CompiledVae(CompiledModule): + main = _decode + + if external_weights: + externalize_module_parameters(vae_model) - import_to = "INPUT" if compile_to == "linalg" else "IMPORT" - inst = CompiledVae(context=Context(), import_to=import_to) + inst = CompiledVae(context=Context(), import_to="IMPORT") - module_str = str(CompiledModule.get_mlir_module(inst)) + module_str = str(CompiledModule.get_mlir_module(inst)) if compile_to != "vmfb": return module_str @@ -162,7 +179,7 @@ def main(self, inp=AbstractTensor(*sample, dtype=dtype)): device, target_triple, ireec_flags, - safe_name, + safe_name + "_" + target_triple, return_path=not exit_on_vmfb, attn_spec=attn_spec, ) @@ -206,7 +223,7 @@ def main(self, inp=AbstractTensor(*sample, dtype=dtype)): exit() safe_name = utils.create_safe_name( args.hf_model_name, - f"_{args.height}x{args.width}_{args.precision}_vae_{args.vae_variant}", + f"_bs{str(args.batch_size)}_{args.height}x{args.width}_{args.precision}_vae_{args.vae_variant}", ) with open(f"{safe_name}.mlir", "w+") as f: f.write(mod_str) diff --git a/models/turbine_models/custom_models/sdxl_inference/vae_runner.py b/models/turbine_models/custom_models/sdxl_inference/vae_runner.py index 539c99868..01767d322 100644 --- a/models/turbine_models/custom_models/sdxl_inference/vae_runner.py +++ b/models/turbine_models/custom_models/sdxl_inference/vae_runner.py @@ -15,64 +15,22 @@ def run_vae( ): runner = vmfbRunner(device, vmfb_path, external_weight_path) inputs = [ireert.asdevicearray(runner.config.device, example_input)] - results = runner.ctx.modules.compiled_vae["main"](*inputs) + 
results = runner.ctx.modules.compiled_vae["decode"](*inputs) return results def run_torch_vae(hf_model_name, custom_vae, variant, example_input): - from diffusers import AutoencoderKL - - class VaeModel(torch.nn.Module): - def __init__( - self, - hf_model_name, - custom_vae=custom_vae, - ): - super().__init__() - self.vae = None - if custom_vae in ["", None]: - self.vae = AutoencoderKL.from_pretrained( - hf_model_name, - subfolder="vae", - ) - elif not isinstance(custom_vae, dict): - try: - # custom HF repo with no vae subfolder - self.vae = AutoencoderKL.from_pretrained( - custom_vae, - ) - except: - # some larger repo with vae subfolder - self.vae = AutoencoderKL.from_pretrained( - custom_vae, - subfolder="vae", - ) - else: - # custom vae as a HF state dict - self.vae = AutoencoderKL.from_pretrained( - hf_model_name, - subfolder="vae", - ) - self.vae.load_state_dict(custom_vae) - - def decode_inp(self, inp): - inp = inp / 0.13025 - x = self.vae.decode(inp, return_dict=False)[0] - return (x / 2 + 0.5).clamp(0, 1) - - def encode_inp(self, inp): - latents = self.vae.encode(inp).latent_dist.sample() - return 0.13025 * latents + from turbine_models.custom_models.sd_inference.vae import VaeModel vae_model = VaeModel( hf_model_name, ) if variant == "decode": - results = vae_model.decode_inp(example_input) + results = vae_model.decode(example_input) elif variant == "encode": - results = vae_model.encode_inp(example_input) + results = vae_model.encode(example_input) np_torch_output = results.detach().cpu().numpy() return np_torch_output diff --git a/models/turbine_models/custom_models/stateless_llama.py b/models/turbine_models/custom_models/stateless_llama.py index baa4e2348..74fd4d421 100644 --- a/models/turbine_models/custom_models/stateless_llama.py +++ b/models/turbine_models/custom_models/stateless_llama.py @@ -13,6 +13,7 @@ from turbine_models.custom_models.llm_optimizations.streaming_llm.modify_llama import ( enable_llama_pos_shift_attention, ) +from 
turbine_models.custom_models.sd_inference.utils import compile_to_vmfb from turbine_models.custom_models import remap_gguf import safetensors @@ -62,6 +63,11 @@ action="store_true", help="Compile LLM with StreamingLLM optimizations", ) +parser.add_argument( + "--decomp_attn", + action="store_true", + help="Decompose attention ops at fx graph level.", +) def generate_schema(num_layers): @@ -116,14 +122,39 @@ def export_transformer_model( quantization=None, precision=None, device=None, - target_triple=None, + target_triple="x86_64-unknown-linux-gnu", vulkan_max_allocation=None, streaming_llm=False, vmfb_path=None, upload_ir=False, mod=None, tokenizer=None, + decomp_attn=False, + input_mlir=None, ): + safe_name = hf_model_name.replace("-", "_").replace("/", "_") + if streaming_llm: + safe_name += "_streaming" + if not vmfb_path: + vmfb_path = safe_name + "_" + target_triple + + iree_flags = [] + ukernel_supported_arch = {"gfx90a", "gfx940", "gfx1030", "gfx1100"} + if target_triple in ukernel_supported_arch: + iree_flags.extend(["--iree-rocm-enable-ukernels=argmax"]) + if input_mlir is not None: + vmfb_path = compile_to_vmfb( + input_mlir, + device, + target_triple, + ireec_flags=iree_flags, + safe_name=vmfb_path.split(".vmfb")[0], + return_path=True, + const_expr_hoisting=True, + mlir_source="file", + save_mlir=False, + attn_spec="mfma" if "gfx9" in target_triple else "wmma", + ) if tokenizer == None: tokenizer = AutoTokenizer.from_pretrained( hf_model_name, @@ -175,243 +206,261 @@ def export_transformer_model( tensor_mapper = remap_gguf.TensorNameMap(remap_gguf.MODEL_ARCH.LLAMA, HEADS) mapper = tensor_mapper.mapping - class StateUpdateModule(CompiledModule): - if external_weights: - params = export_parameters( - mod, external=True, external_scope="", name_mapper=mapper.get - ) - else: - params = export_parameters(mod) - global_seq_step = export_global(AbstractIndex, mutable=True) - global_k_caches = export_global_tree( - kv_cache_structure, uninitialized=True, 
mutable=True - ) - global_v_caches = export_global_tree( - kv_cache_structure, uninitialized=True, mutable=True - ) + initial_table = decompositions.current_aot_decompositions() + print("Decomposing torch SDPA") + with decompositions.extend_aot_decompositions( + from_current=True, + add_ops=[ + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten.masked_fill_.Scalar, + torch.ops.aten.copy, + ], + ): + current_table = decompositions.current_aot_decompositions() - def run_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)): - init_const = [x.dynamic_dim(1) < MAX_STEP_SEQ] - token, *state = self.initialize(x, constraints=init_const) - self.global_seq_step = IREE.tensor_dim( - state[0], 1 - ) # ? dimension of arbitrarily 0th kv tensor - for i in range(NUM_LAYERS): - slice_of_state = IREE.tensor_reshape( - state[i * 2], 1, self.global_seq_step, HEADS, HIDDEN_DIM + class StateUpdateModule(CompiledModule): + if external_weights: + params = export_parameters( + mod, external=True, external_scope="", name_mapper=mapper.get ) - self.global_k_caches["layer_idx"][i] = IREE.tensor_update( - self.global_k_caches["layer_idx"][i], slice_of_state, 0, 0, 0, 0 - ) - for i in range(NUM_LAYERS): - slice_of_state = IREE.tensor_reshape( - state[i * 2 + 1], 1, self.global_seq_step, HEADS, HIDDEN_DIM - ) - self.global_v_caches["layer_idx"][i] = IREE.tensor_update( - self.global_v_caches["layer_idx"][i], slice_of_state, 0, 0, 0, 0 - ) - return token - - def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)): - state_arg = slice_up_to_step( - self.global_k_caches, - self.global_v_caches, - self.global_seq_step, - HEADS, - HIDDEN_DIM, - NUM_LAYERS, + else: + params = export_parameters(mod) + global_seq_step = export_global(AbstractIndex, mutable=True) + global_k_caches = export_global_tree( + kv_cache_structure, uninitialized=True, mutable=True ) - forw_const = ( - 
[state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ] - + [ - x.dynamic_dim(1) == (state_arg[0].dynamic_dim(1)) - for x in state_arg[1:] - ] - + [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]] + global_v_caches = export_global_tree( + kv_cache_structure, uninitialized=True, mutable=True ) - token, *state_update = self.forward(x, *state_arg, constraints=forw_const) - for i in range(NUM_LAYERS): - update = IREE.tensor_reshape( - state_update[i * 2], 1, 1, HEADS, HIDDEN_DIM - ) - self.global_k_caches["layer_idx"][i] = IREE.tensor_update( - self.global_k_caches["layer_idx"][i], - update, - 0, + + def run_initialize( + self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64) + ): + init_const = [x.dynamic_dim(1) < MAX_STEP_SEQ] + token, *state = self.initialize(x, constraints=init_const) + self.global_seq_step = IREE.tensor_dim( + state[0], 1 + ) # ? dimension of arbitrarily 0th kv tensor + for i in range(NUM_LAYERS): + slice_of_state = IREE.tensor_reshape( + state[i * 2], 1, self.global_seq_step, HEADS, HIDDEN_DIM + ) + self.global_k_caches["layer_idx"][i] = IREE.tensor_update( + self.global_k_caches["layer_idx"][i], slice_of_state, 0, 0, 0, 0 + ) + for i in range(NUM_LAYERS): + slice_of_state = IREE.tensor_reshape( + state[i * 2 + 1], 1, self.global_seq_step, HEADS, HIDDEN_DIM + ) + self.global_v_caches["layer_idx"][i] = IREE.tensor_update( + self.global_v_caches["layer_idx"][i], slice_of_state, 0, 0, 0, 0 + ) + return token + + def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)): + state_arg = slice_up_to_step( + self.global_k_caches, + self.global_v_caches, self.global_seq_step, - 0, - 0, + HEADS, + HIDDEN_DIM, + NUM_LAYERS, ) - for i in range(NUM_LAYERS): - update = IREE.tensor_reshape( - state_update[i * 2 + 1], 1, 1, HEADS, HIDDEN_DIM + forw_const = ( + [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ] + + [ + x.dynamic_dim(1) == (state_arg[0].dynamic_dim(1)) + for x in state_arg[1:] + ] + + [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]] ) - 
self.global_v_caches["layer_idx"][i] = IREE.tensor_update( - self.global_v_caches["layer_idx"][i], - update, - 0, - self.global_seq_step, - 0, - 0, + token, *state_update = self.forward( + x, *state_arg, constraints=forw_const ) - self.global_seq_step = self.global_seq_step + 1 - return token + for i in range(NUM_LAYERS): + update = IREE.tensor_reshape( + state_update[i * 2], 1, 1, HEADS, HIDDEN_DIM + ) + self.global_k_caches["layer_idx"][i] = IREE.tensor_update( + self.global_k_caches["layer_idx"][i], + update, + 0, + self.global_seq_step, + 0, + 0, + ) + for i in range(NUM_LAYERS): + update = IREE.tensor_reshape( + state_update[i * 2 + 1], 1, 1, HEADS, HIDDEN_DIM + ) + self.global_v_caches["layer_idx"][i] = IREE.tensor_update( + self.global_v_caches["layer_idx"][i], + update, + 0, + self.global_seq_step, + 0, + 0, + ) + self.global_seq_step = self.global_seq_step + 1 + return token - def get_seq_step(self): - return self.global_seq_step + def get_seq_step(self): + return self.global_seq_step - @jittable - def initialize(input_ids): - result = mod.forward(input_ids) - state1_flat, _ = pytree.tree_flatten(result.past_key_values) - token1 = torch.argmax(result.logits[:, -1, :], dim=1) - token1 = token1[None, :] - state1_flat = [torch.transpose(x, 1, 2) for x in state1_flat] - return token1, *state1_flat + @jittable + def initialize(input_ids): + result = mod.forward(input_ids) + state1_flat, _ = pytree.tree_flatten(result.past_key_values) + token1 = torch.argmax(result.logits[:, -1, :], dim=1) + token1 = token1[None, :] + state1_flat = [torch.transpose(x, 1, 2) for x in state1_flat] + return token1, *state1_flat - @jittable - def forward(token0: torch.Tensor, *state0_flat): - # Unpad the states. 
- state0_flat = [torch.transpose(x, 1, 2) for x in state0_flat] - state0 = pytree.tree_unflatten(state0_flat, state_schema) - result = mod.forward(token0, past_key_values=state0) - state1_flat, _ = pytree.tree_flatten(result.past_key_values) - state1_flat = [torch.transpose(x[:, :, -1:, :], 1, 2) for x in state1_flat] - token1 = torch.argmax(result.logits[:, -1, :], dim=1) - token1 = token1[None, :] - return token1, *state1_flat - - class StreamingStateUpdateModule(StateUpdateModule): - def run_cached_initialize( - self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64) - ): - state_arg = slice_up_to_step( - self.global_k_caches, - self.global_v_caches, - self.global_seq_step, - HEADS, - HIDDEN_DIM, - NUM_LAYERS, - ) - forw_const = ( - [x.dynamic_dim(1) < MAX_STEP_SEQ] - + [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ] - + [ - x.dynamic_dim(1) == (state_arg[0].dynamic_dim(1)) - for x in state_arg[1:] + @jittable + def forward(token0: torch.Tensor, *state0_flat): + # Unpad the states. + state0_flat = [torch.transpose(x, 1, 2) for x in state0_flat] + state0 = pytree.tree_unflatten(state0_flat, state_schema) + result = mod.forward(token0, past_key_values=state0) + state1_flat, _ = pytree.tree_flatten(result.past_key_values) + state1_flat = [ + torch.transpose(x[:, :, -1:, :], 1, 2) for x in state1_flat ] - + [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]] - ) - token, *state = self.cached_initialize( - x, *state_arg, constraints=forw_const - ) - len_of_new_tokens = IREE.tensor_dim( - state[0], 1 - ) # ? 
dimension of arbitrarily 0th kv tensor - for i in range(NUM_LAYERS): - slice_of_state = IREE.tensor_reshape( - state[i * 2], 1, len_of_new_tokens, HEADS, HIDDEN_DIM - ) - self.global_k_caches["layer_idx"][i] = IREE.tensor_update( - self.global_k_caches["layer_idx"][i], - slice_of_state, - 0, + token1 = torch.argmax(result.logits[:, -1, :], dim=1) + token1 = token1[None, :] + return token1, *state1_flat + + class StreamingStateUpdateModule(StateUpdateModule): + def run_cached_initialize( + self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64) + ): + state_arg = slice_up_to_step( + self.global_k_caches, + self.global_v_caches, self.global_seq_step, - 0, - 0, + HEADS, + HIDDEN_DIM, + NUM_LAYERS, ) - for i in range(NUM_LAYERS): - slice_of_state = IREE.tensor_reshape( - state[i * 2 + 1], 1, len_of_new_tokens, HEADS, HIDDEN_DIM + forw_const = ( + [x.dynamic_dim(1) < MAX_STEP_SEQ] + + [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ] + + [ + x.dynamic_dim(1) == (state_arg[0].dynamic_dim(1)) + for x in state_arg[1:] + ] + + [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]] ) - self.global_v_caches["layer_idx"][i] = IREE.tensor_update( - self.global_v_caches["layer_idx"][i], - slice_of_state, - 0, - self.global_seq_step, - 0, - 0, + token, *state = self.cached_initialize( + x, *state_arg, constraints=forw_const ) - self.global_seq_step = self.global_seq_step + len_of_new_tokens - return token + len_of_new_tokens = IREE.tensor_dim( + state[0], 1 + ) # ? 
dimension of arbitrarily 0th kv tensor + for i in range(NUM_LAYERS): + slice_of_state = IREE.tensor_reshape( + state[i * 2], 1, len_of_new_tokens, HEADS, HIDDEN_DIM + ) + self.global_k_caches["layer_idx"][i] = IREE.tensor_update( + self.global_k_caches["layer_idx"][i], + slice_of_state, + 0, + self.global_seq_step, + 0, + 0, + ) + for i in range(NUM_LAYERS): + slice_of_state = IREE.tensor_reshape( + state[i * 2 + 1], 1, len_of_new_tokens, HEADS, HIDDEN_DIM + ) + self.global_v_caches["layer_idx"][i] = IREE.tensor_update( + self.global_v_caches["layer_idx"][i], + slice_of_state, + 0, + self.global_seq_step, + 0, + 0, + ) + self.global_seq_step = self.global_seq_step + len_of_new_tokens + return token - @jittable - def cached_initialize(input_ids, *state0_flat): - # Unpad the states. - cur_token_len = state0_flat[0].size(1) - state0_flat = [torch.transpose(x, 1, 2) for x in state0_flat] - state0 = pytree.tree_unflatten(state0_flat, state_schema) - result = mod.forward(input_ids, past_key_values=state0) - state1_flat, _ = pytree.tree_flatten(result.past_key_values) - state1_flat = [ - torch.transpose(x[:, :, cur_token_len:, :], 1, 2) for x in state1_flat - ] - token1 = torch.argmax(result.logits[:, -1, :], dim=1) - token1 = token1[None, :] - return token1, *state1_flat + @jittable + def cached_initialize(input_ids, *state0_flat): + # Unpad the states. 
+ cur_token_len = state0_flat[0].size(1) + state0_flat = [torch.transpose(x, 1, 2) for x in state0_flat] + state0 = pytree.tree_unflatten(state0_flat, state_schema) + result = mod.forward(input_ids, past_key_values=state0) + state1_flat, _ = pytree.tree_flatten(result.past_key_values) + state1_flat = [ + torch.transpose(x[:, :, cur_token_len:, :], 1, 2) + for x in state1_flat + ] + token1 = torch.argmax(result.logits[:, -1, :], dim=1) + token1 = token1[None, :] + return token1, *state1_flat - # Streaming-LLM KVCache evict algorithm: - # slice1 = KVCache[0 : sink] - # slice2 = KVCache[seq_len - window_size : seq_len] - # KVCache = torch.cat([slice1, slice2]) - # TODO: Add move to handle overlap of data. - def evict_kvcache_space(self): - # TODO: Replace hardcoded with global variable. - sink_size = 4 - window_size = 252 - most_recent_window = self.global_seq_step + (-window_size) - for i in range(NUM_LAYERS): - update_window_state = IREE.tensor_slice( - self.global_k_caches["layer_idx"][i], - 0, - (most_recent_window, window_size), - (0, HEADS), - (0, HIDDEN_DIM), - ) # sequence context dim - self.global_k_caches["layer_idx"][i] = IREE.tensor_update( - self.global_k_caches["layer_idx"][i], - update_window_state, - 0, - sink_size, - 0, - 0, - ) - for i in range(NUM_LAYERS): - update_window_state = IREE.tensor_slice( - self.global_v_caches["layer_idx"][i], - 0, - (most_recent_window, window_size), - (0, HEADS), - (0, HIDDEN_DIM), - ) # sequence context dim - self.global_v_caches["layer_idx"][i] = IREE.tensor_update( - self.global_v_caches["layer_idx"][i], - update_window_state, - 0, - sink_size, - 0, - 0, - ) - self.global_seq_step.set(window_size + sink_size) - return self.global_seq_step + # Streaming-LLM KVCache evict algorithm: + # slice1 = KVCache[0 : sink] + # slice2 = KVCache[seq_len - window_size : seq_len] + # KVCache = torch.cat([slice1, slice2]) + # TODO: Add move to handle overlap of data. 
+ def evict_kvcache_space(self): + # TODO: Replace hardcoded with global variable. + sink_size = 4 + window_size = 252 + most_recent_window = self.global_seq_step + (-window_size) + for i in range(NUM_LAYERS): + update_window_state = IREE.tensor_slice( + self.global_k_caches["layer_idx"][i], + 0, + (most_recent_window, window_size), + (0, HEADS), + (0, HIDDEN_DIM), + ) # sequence context dim + self.global_k_caches["layer_idx"][i] = IREE.tensor_update( + self.global_k_caches["layer_idx"][i], + update_window_state, + 0, + sink_size, + 0, + 0, + ) + for i in range(NUM_LAYERS): + update_window_state = IREE.tensor_slice( + self.global_v_caches["layer_idx"][i], + 0, + (most_recent_window, window_size), + (0, HEADS), + (0, HIDDEN_DIM), + ) # sequence context dim + self.global_v_caches["layer_idx"][i] = IREE.tensor_update( + self.global_v_caches["layer_idx"][i], + update_window_state, + 0, + sink_size, + 0, + 0, + ) + self.global_seq_step.set(window_size + sink_size) + return self.global_seq_step - import_to = "INPUT" if compile_to == "linalg" else "IMPORT" - if streaming_llm: - print("Compiling with Streaming LLM") - inst = StreamingStateUpdateModule(context=Context(), import_to=import_to) - else: - inst = StateUpdateModule(context=Context(), import_to=import_to) - # TODO: Integrate with external parameters to actually be able to run - # TODO: Make more generalizable to be able to quantize with all compile_to options - if quantization == "int4" and not compile_to == "linalg": - from shark_turbine.transforms.quantization import mm_group_quant + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" + if streaming_llm: + print("Compiling with Streaming LLM") + inst = StreamingStateUpdateModule(context=Context(), import_to=import_to) + else: + inst = StateUpdateModule(context=Context(), import_to=import_to) + # TODO: Integrate with external parameters to actually be able to run + # TODO: Make more generalizable to be able to quantize with all compile_to options + if 
quantization == "int4" and not compile_to == "linalg": + from shark_turbine.transforms.quantization import mm_group_quant - mm_group_quant.MMGroupQuantRewriterPass( - CompiledModule.get_mlir_module(inst).operation - ).run() - module_str = str(CompiledModule.get_mlir_module(inst)) - safe_name = hf_model_name.split("/")[-1].strip() - safe_name = re.sub("-", "_", safe_name) + mm_group_quant.MMGroupQuantRewriterPass( + CompiledModule.get_mlir_module(inst).operation + ).run() + module_str = str(CompiledModule.get_mlir_module(inst)) if upload_ir: with open(f"{safe_name}.mlir", "w+") as f: f.write(module_str) @@ -423,64 +472,21 @@ def evict_kvcache_space(self): if compile_to != "vmfb": return module_str, tokenizer else: - flags = [ - "--iree-input-type=torch", - "--mlir-print-debuginfo", - "--mlir-print-op-on-diagnostic=false", - "--iree-llvmcpu-target-cpu-features=host", - "--iree-llvmcpu-target-triple=x86_64-linux-gnu", - "--iree-stream-resource-index-bits=64", - "--iree-vm-target-index-bits=64", - ] - if device == "cpu" or device == "llvm-cpu": - flags.append("--iree-llvmcpu-enable-ukernels=all") - device = "llvm-cpu" - elif device == "vulkan": - flags.extend( - [ - "--iree-vulkan-target-triple=" + target_triple, - "--iree-stream-resource-max-allocation-size=" - + vulkan_max_allocation, - ] - ) - elif device == "rocm": - flags.extend( - [ - "--iree-rocm-target-chip=" + target_triple, - "--iree-rocm-link-bc=true", - "--iree-vm-bytecode-module-strip-source-map=true", - "--iree-opt-strip-assertions=true", - "--iree-vm-target-truncate-unsupported-floats", - ] - ) - ukernel_supported_arch = {"gfx90a", "gfx940", "gfx1030", "gfx1100"} - if target_triple in ukernel_supported_arch: - flags.extend(["--iree-rocm-enable-ukernels=argmax"]) - elif device == "cuda": - flags.extend( - [ - "--iree-hal-cuda-llvm-target-arch=" + target_triple, - "--iree-vm-bytecode-module-strip-source-map=true", - "--iree-vm-target-truncate-unsupported-floats", - ] - ) - else: - print("Unknown device 
kind: ", device) - import iree.compiler as ireec - - flatbuffer_blob = ireec.compile_str( + blob_name = compile_to_vmfb( module_str, - target_backends=[device], - extra_args=flags, + device, + target_triple, + ireec_flags=iree_flags, + safe_name=vmfb_path.split(".vmfb")[0], + return_path=True, + const_expr_hoisting=True, + mlir_source="str", + save_mlir=False, + attn_spec="mfma" if "gfx9" in target_triple else "wmma", ) - if vmfb_path is None: - vmfb_path = f"{safe_name}.vmfb" - with open(vmfb_path, "wb+") as f: - f.write(flatbuffer_blob) - print("saved to ", safe_name + ".vmfb") if upload_ir: return blob_name - return module_str, tokenizer + return blob_name, tokenizer if __name__ == "__main__": @@ -498,6 +504,8 @@ def evict_kvcache_space(self): args.vulkan_max_allocation, args.streaming_llm, args.vmfb_path, + upload_ir=False, + decomp_attn=args.decomp_attn, ) safe_name = args.hf_model_name.split("/")[-1].strip() safe_name = re.sub("-", "_", safe_name) diff --git a/models/turbine_models/model_runner.py b/models/turbine_models/model_runner.py index a173f3166..1b27ca83b 100644 --- a/models/turbine_models/model_runner.py +++ b/models/turbine_models/model_runner.py @@ -72,8 +72,9 @@ def __init__(self, device, vmfb_path, external_weight_path=None, extra_plugin=No self.config.vm_instance, index.create_provider(scope="model") ) vm_modules.insert(i, param_module) + del param_module del index - del param_module + self.ctx = ireert.SystemContext( vm_modules=vm_modules, config=self.config, diff --git a/models/turbine_models/tests/conftest.py b/models/turbine_models/tests/conftest.py index 7a1f55b1a..4292c7390 100644 --- a/models/turbine_models/tests/conftest.py +++ b/models/turbine_models/tests/conftest.py @@ -18,7 +18,7 @@ def pytest_addoption(parser): action="store", default="blurry, unsaturated, watermark, noisy, grainy, out of focus", ) - parser.addoption("--num_inference_steps", type=int, action="store", default=5) + parser.addoption("--num_inference_steps", type=int, 
action="store", default=2) parser.addoption("--guidance_scale", type=float, action="store", default=7.5) parser.addoption("--seed", type=float, action="store", default=0.0) parser.addoption("--vmfb_path", action="store", default="") @@ -36,7 +36,8 @@ def pytest_addoption(parser): # General Options parser.addoption("--compile_to", action="store", default=None) parser.addoption("--external_weights", action="store", default="safetensors") - parser.addoption("--decomp_attn", action="store", default=True) + parser.addoption("--decomp_attn", action="store", default=False) + parser.addoption("--vae_decomp_attn", action="store", default=False) parser.addoption("--attn_spec", action="store", default="") # Compiler Options parser.addoption("--device", action="store", default="cpu") @@ -50,4 +51,17 @@ def pytest_addoption(parser): parser.addoption("--in_channels", type=int, action="store", default=4) parser.addoption("--benchmark", action="store_true", default=False) parser.addoption("--tracy_profile", action="store_true", default=False) - parser.addoption("--compiled_pipeline", type=bool, default=True) + parser.addoption("--compiled_pipeline", type=bool, default=False) + parser.addoption("--model_path", type=str, action="store", default=None) + parser.addoption("--vae_model_path", type=str, action="store", default=None) + parser.addoption("--pipeline_vmfb_path", type=str, action="store", default=None) + parser.addoption("--scheduler_vmfb_path", type=str, action="store", default=None) + parser.addoption("--split_scheduler", action="store_true", default=True) + parser.addoption("--cpu_scheduling", action="store_true", default=True) + parser.addoption("--npu_delegate_path", type=str, action="store", default=None) + parser.addoption("--clip_precision", type=str, action="store", default=None) + parser.addoption("--mmdit_precision", type=str, action="store", default=None) + parser.addoption("--unet_precision", type=str, action="store", default=None) + 
class TestModule(torch.nn.Module):
    """Minimal two-layer model used as a dummy pipeline component.

    Two ``Linear(10, 10)`` layers applied back to back, with no activation
    in between.
    """

    def __init__(self):
        super().__init__()
        # fc1 is created before fc2 so parameter initialisation consumes
        # the RNG in the same order as before.
        self.fc1 = torch.nn.Linear(10, 10)
        self.fc2 = torch.nn.Linear(10, 10)

    def forward(self, x):
        """Return ``fc2(fc1(x))``."""
        return self.fc2(self.fc1(x))
class PipelineTest(unittest.TestCase):
    """Smoke tests for ``TurbinePipelineBase`` using a single dummy submodel."""

    def setUp(self):
        """Build, prepare and load a one-component pipeline before each test."""
        model_map = {
            "test_model_1": {
                "model_name": "TestModel1",
                "external_weights": None,
                "module_name": "compiled_tester",
                "safe_name": "TestModel2xLinear",
                "keywords": ["Test", "Model", "2x", "Linear"],
                "export_fn": export_dummy_model,
            }
        }
        self.pipe = TestPipeline(
            model_map=model_map,
            device="cpu",
            target="x86_64-unknown-linux-gnu",
            pipeline_dir="./",
            precision="fp32",
        )
        self.pipe.prepare_all()
        self.pipe.load_map()
        self.test_input = [torch.ones(10)]

    def test_pipeline(self):
        """Run the pipeline once and print the host-side output."""
        output = self.pipe.run(self.test_input).to_host()
        print(output)

    def test_pipeline_benchmark(self):
        """Run the pipeline with the component's ``benchmark`` flag enabled."""
        self.pipe.test_model_1.benchmark = True
        output = self.pipe.run(self.test_input).to_host()
        print(output)

    def test_pipeline_metadata(self):
        """Check that the compiled module reports the metadata it was given.

        The expected values are the stringified entries of
        ``model_metadata_forward``.
        """
        metadata = self.pipe.test_model_1.get_metadata("forward")
        # Build a separate dict: stringifying model_metadata_forward in place
        # would mutate the module-level dict that export_dummy_model passes
        # to AddMetadataPass, making results depend on test order.
        expected = {key: str(value) for key, value in model_metadata_forward.items()}
        self.assertEqual(
            expected,
            metadata,
            "Metadata mismatch: expected {}, got {}".format(expected, metadata),
        )
a/models/turbine_models/tests/resnet_test.py +++ b/models/turbine_models/tests/resnet_test.py @@ -3,35 +3,70 @@ from turbine_models.custom_models import resnet_18 import unittest import os -import pytest - -arguments = { - "run_vmfb": True, - "compile_to": None, - "vmfb_path": "", - "device": "local-task", - "iree_target_triple": "", - "vulkan_max_allocation": "4294967296", -} resnet_model = resnet_18.Resnet18Model() class Resnet18Test(unittest.TestCase): - @pytest.mark.xfail( - reason="caused by lack of support for DenseResourceElementsAttr iteration over a generic FloatAttr" - ) - def testExportResnet18Model(self): - with self.assertRaises(SystemExit) as cm: - resnet_18.export_resnet_18_model( - resnet_model, - "vmfb", - "cpu", - ) - self.assertEqual(cm.exception.code, None) - namespace = argparse.Namespace(**arguments) - resnet_18.run_resnet_18_vmfb_comparison(resnet_model, namespace) - os.remove("resnet_18.vmfb") + def testExportResnet18ModelCPU(self): + from turbine_models.tests.testing_cmd_opts import args + + arguments = { + "run_vmfb": True, + "compile_to": "vmfb", + "vmfb_path": "", + "device": "local-task", + "target_triple": "x86_64-unknown-linux-gnu", + "vulkan_max_allocation": "4294967296", + "precision": "fp32", + } + resnet_18.export_resnet_18_model( + resnet_model, + "vmfb", + "cpu", + ) + namespace = AttributeDict(arguments) + err = resnet_18.run_resnet_18_vmfb_comparison(resnet_model, namespace) + assert err < 1e-5 + + def testExportResnet18ModelStaticGFX1100(self): + arguments = { + "run_vmfb": True, + "compile_to": "vmfb", + "vmfb_path": "", + "device": "rocm", + "target_triple": "gfx1100", + "vulkan_max_allocation": "4294967296", + "precision": "fp16", + } + resnet_18.export_static_resnet_18_model( + resnet_model, + "vmfb", + "rocm", + arguments["target_triple"], + ) + namespace = AttributeDict(arguments) + rocm_err = resnet_18.run_resnet_18_vmfb_comparison(resnet_model, namespace) + namespace.device = "hip" + hip_err = 
class AttributeDict(dict):
    """Dictionary whose items can also be read and written as attributes.

    ``d.key`` is equivalent to ``d["key"]`` and ``d.key = v`` stores
    ``d["key"] = v``. The tests use it as a lightweight stand-in for an
    argparse-style namespace.
    """

    def __getattr__(self, attr):
        """Return ``self[attr]``.

        Raises:
            AttributeError: if *attr* is not a key. Raising ``KeyError`` here
                would break ``hasattr``, ``getattr(obj, name, default)`` and
                ``copy.deepcopy``, which all rely on ``AttributeError`` to
                detect a missing attribute.
        """
        try:
            return self[attr]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            ) from None

    def __setattr__(self, attr, value):
        """Store *value* under the key *attr*."""
        self[attr] = value
"stabilityai/stable-diffusion-3-medium-diffusers" + arguments["scheduler_id"] = request.config.getoption("--scheduler_id") + arguments["prompt"] = request.config.getoption("--prompt") + arguments["negative_prompt"] = request.config.getoption("--negative_prompt") + arguments["num_inference_steps"] = int( + request.config.getoption("--num_inference_steps") + ) + arguments["guidance_scale"] = float(request.config.getoption("--guidance_scale")) + arguments["seed"] = float(request.config.getoption("--seed")) + arguments["denoise"] = request.config.getoption("--denoise") + arguments["external_weight_path"] = request.config.getoption( + "--external_weight_path" + ) + arguments["external_weight_dir"] = request.config.getoption("--external_weight_dir") + arguments["external_weight_file"] = request.config.getoption( + "--external_weight_file" + ) + arguments["vmfb_path"] = request.config.getoption("--vmfb_path") + arguments["pipeline_vmfb_path"] = request.config.getoption("--pipeline_vmfb_path") + arguments["scheduler_vmfb_path"] = request.config.getoption("--scheduler_vmfb_path") + arguments["split_scheduler"] = request.config.getoption("--split_scheduler") + arguments["cpu_scheduling"] = request.config.getoption("--cpu_scheduling") + arguments["pipeline_dir"] = request.config.getoption("--pipeline_dir") + arguments["compiled_pipeline"] = request.config.getoption("--compiled_pipeline") + arguments["npu_delegate_path"] = request.config.getoption("--npu_delegate_path") + arguments["batch_size"] = int(request.config.getoption("--batch_size")) + arguments["height"] = int(request.config.getoption("--height")) + arguments["width"] = int(request.config.getoption("--width")) + arguments["precision"] = request.config.getoption("--precision") + arguments["vae_precision"] = request.config.getoption("--vae_precision") + arguments["max_length"] = int(request.config.getoption("--max_length")) + arguments["shift"] = request.config.getoption("--shift") + arguments["vae_decomp_attn"] = 
request.config.getoption("--vae_decomp_attn") + arguments["external_weights"] = request.config.getoption("--external_weights") + arguments["decomp_attn"] = request.config.getoption("--decomp_attn") + arguments["attn_spec"] = request.config.getoption("--attn_spec") + arguments["device"] = utils.iree_device_map(request.config.getoption("--device")) + arguments["backend"] = utils.iree_backend_map( + request.config.getoption("--device").split("://")[0] + ) + arguments["iree_target_triple"] = request.config.getoption("--iree_target_triple") + arguments["ireec_flags"] = request.config.getoption("--ireec_flags") + # TODO (Ean Garvey): align attention spec handling so we don't have to do this. + if not arguments["attn_spec"] and not arguments["decomp_attn"]: + if "gfx9" in arguments["iree_target_triple"]: + arguments["attn_spec"] = "mfma" + elif "gfx11" in arguments["iree_target_triple"]: + arguments["attn_spec"] = "wmma" + + +@pytest.mark.usefixtures("command_line_args") +class StableDiffusion3Test(unittest.TestCase): + def setUp(self): + self.safe_model_name = create_safe_name(arguments["hf_model_name"], "") + + @pytest.mark.xfail(reason="Numerics issues on ~.01 percent of output values") + def test01_ExportPromptEncoder(self): + if arguments["device"] in ["vulkan", "cuda"]: + self.skipTest("Not testing sd3 on vk or cuda") + arguments["external_weight_path"] = ( + self.safe_model_name + "_text_encoders_" + arguments["precision"] + ".irpa" + ) + prompt_encoder_vmfb = sd3_text_encoders.export_text_encoders( + arguments["hf_model_name"], + max_length=arguments["max_length"], + precision=arguments["precision"], + compile_to="vmfb", + external_weights=arguments["external_weights"], + external_weight_path=arguments["external_weight_path"], + device=arguments["backend"], + target_triple=arguments["iree_target_triple"], + ireec_flags=arguments["ireec_flags"], + exit_on_vmfb=False, + pipeline_dir=arguments["pipeline_dir"], + input_mlir=None, + attn_spec=arguments["attn_spec"], + 
batch_size=arguments["batch_size"], + decomp_attn=True, + ) + tokenizer = SD3Tokenizer() + ( + text_input_ids_list, + uncond_input_ids_list, + ) = sd3_text_encoders_runner.run_tokenize( + tokenizer, + arguments["prompt"], + arguments["negative_prompt"], + ) + ( + turbine_output1, + turbine_output2, + ) = sd3_text_encoders_runner.run_prompt_encoder( + prompt_encoder_vmfb, + arguments["device"], + arguments["external_weight_path"], + text_input_ids_list, + uncond_input_ids_list, + ) + torch_encoder_model = TextEncoderModule( + arguments["batch_size"], + ) + torch_output1, torch_output2 = torch_encoder_model.forward( + *text_input_ids_list, *uncond_input_ids_list + ) + rtol = 4e-2 + atol = 4e-2 + np.testing.assert_allclose(torch_output1, turbine_output1, rtol, atol) + np.testing.assert_allclose(torch_output2, turbine_output2, rtol, atol) + + @pytest.mark.xfail( + reason="Runners need secure dedicated access to gated HF repo for imports." + ) + def test02_ExportMMDITModel(self): + if arguments["device"] in ["vulkan", "cuda"]: + self.skipTest("Not testing on vulkan or cuda") + self.mmdit_model = sd3_mmdit.MMDiTModel( + arguments["hf_model_name"], + dtype=torch.float16 if arguments["precision"] == "fp16" else torch.float32, + ) + arguments["external_weight_path"] = ( + self.safe_model_name + + "_" + + arguments["precision"] + + "_mmdit." 
+ + arguments["external_weights"] + ) + sd3_mmdit.export_mmdit_model( + mmdit_model=self.mmdit_model, + hf_model_name=arguments["hf_model_name"], + batch_size=arguments["batch_size"], + height=arguments["height"], + width=arguments["width"], + precision=arguments["precision"], + max_length=arguments["max_length"], + compile_to="vmfb", + external_weights=arguments["external_weights"], + external_weight_path=arguments["external_weight_path"], + device=arguments["backend"], + target_triple=arguments["iree_target_triple"], + ireec_flags=arguments["ireec_flags"], + decomp_attn=arguments["decomp_attn"], + attn_spec=arguments["attn_spec"], + ) + arguments["vmfb_path"] = ( + self.safe_model_name + + "_" + + str(arguments["max_length"]) + + "_" + + str(arguments["height"]) + + "x" + + str(arguments["width"]) + + "_" + + arguments["precision"] + + "_unet_" + + arguments["device"] + + ".vmfb" + ) + dtype = torch.float16 if arguments["precision"] == "fp16" else torch.float32 + + hidden_states = torch.randn( + ( + arguments["batch_size"], + 16, + arguments["height"] // 8, + arguments["width"] // 8, + ), + dtype=dtype, + ) + encoder_hidden_states = torch.randn( + (arguments["batch_size"], arguments["max_length"] * 2, 4096), dtype=dtype + ) + pooled_projections = torch.randn((arguments["batch_size"], 2048), dtype=dtype) + timestep = torch.tensor([0, 0], dtype=dtype) + turbine = sd3_mmdit_runner.run_mmdit_turbine( + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep, + arguments, + ) + torch_output = sd3_mmdit_runner.run_diffusers_mmdit( + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep, + arguments, + ) + # if arguments["benchmark"] or arguments["tracy_profile"]: + # run_benchmark( + # "unet", + # arguments["vmfb_path"], + # arguments["external_weight_path"], + # arguments["rt_device"], + # max_length=arguments["max_length"], + # height=arguments["height"], + # width=arguments["width"], + # batch_size=arguments["batch_size"], + # 
in_channels=arguments["in_channels"], + # precision=arguments["precision"], + # tracy_profile=arguments["tracy_profile"], + # ) + rtol = 4e-2 + atol = 4e-1 + + np.testing.assert_allclose(torch_output, turbine, rtol, atol) + + @pytest.mark.xfail( + reason="Runners need secure dedicated access to gated HF repo for imports." + ) + def test03_ExportVaeModelDecode(self): + if arguments["device"] in ["vulkan", "cuda"]: + self.skipTest("not testing vulkan or cuda") + vae_model = sd3_vae.VaeModel( + # This is a public model, so no auth required + arguments["hf_model_name"], + ) + arguments["external_weight_path"] = ( + self.safe_model_name + + "_" + + arguments["precision"] + + "_vae_decode." + + arguments["external_weights"] + ) + sd3_vae.export_vae_model( + vae_model, + hf_model_name=arguments["hf_model_name"], + batch_size=arguments["batch_size"], + height=arguments["height"], + width=arguments["width"], + precision=arguments["precision"], + compile_to="vmfb", + external_weights=arguments["external_weights"], + external_weight_path=arguments["external_weight_path"], + device=arguments["backend"], + target_triple=arguments["iree_target_triple"], + ireec_flags=arguments["ireec_flags"], + decomp_attn=arguments["decomp_attn"], + attn_spec=arguments["attn_spec"], + ) + arguments["vmfb_path"] = ( + self.safe_model_name + + "_" + + str(arguments["height"]) + + "x" + + str(arguments["width"]) + + "_" + + arguments["precision"] + + "_vae_decode_" + + arguments["device"] + + ".vmfb" + ) + example_input = torch.ones( + arguments["batch_size"], + 16, + arguments["height"] // 8, + arguments["width"] // 8, + dtype=torch.float32, + ) + example_input_torch = example_input + if arguments["precision"] == "fp16": + example_input = example_input.half() + turbine = sd3_vae_runner.run_vae( + arguments["device"], + example_input, + arguments["vmfb_path"], + arguments["hf_model_name"], + arguments["external_weight_path"], + ) + torch_output = sd3_vae_runner.run_torch_vae( + 
@pytest.mark.skip(
    reason="Waiting on inference plumbing for generalized sd pipeline"
)
def test04SDPipeline(self):
    """Build a ``SharkSDPipeline`` end to end and generate one image batch.

    Currently skipped. Works on a private copy of the module-level
    ``arguments`` dict filled in by the ``command_line_args`` fixture, so
    nothing done here leaks into the other tests.
    """
    # Imported locally: the file-level imports of this test module do not
    # include `copy`.
    import copy

    from turbine_models.custom_models.sd_inference.sd_pipeline import (
        SharkSDPipeline,
    )

    # This module defines `arguments`, not `default_arguments`; the old
    # name would raise NameError as soon as the skip marker is removed.
    current_args = copy.deepcopy(arguments)
    decomp_attn = {
        "text_encoder": False,
        "unet": False,
        "vae": current_args["vae_decomp_attn"],
    }
    sd_pipe = SharkSDPipeline(
        current_args["hf_model_name"],
        current_args["height"],
        current_args["width"],
        current_args["batch_size"],
        current_args["max_length"],
        current_args["precision"],
        current_args["device"],
        current_args["iree_target_triple"],
        ireec_flags=None,  # ireec_flags
        attn_spec=current_args["attn_spec"],
        decomp_attn=decomp_attn,
        pipeline_dir="test_vmfbs",  # pipeline_dir
        external_weights_dir="test_weights",  # external_weights_dir
        external_weights=current_args["external_weights"],
        num_inference_steps=current_args["num_inference_steps"],
        cpu_scheduling=True,
        scheduler_id=current_args["scheduler_id"],
        shift=None,  # shift
        use_i8_punet=False,
    )
    sd_pipe.prepare_all()
    sd_pipe.load_map()
    output = sd_pipe.generate_images(
        current_args["prompt"],
        current_args["negative_prompt"],
        current_args["num_inference_steps"],
        1,  # batch count
        current_args["guidance_scale"],
        current_args["seed"],
        current_args["cpu_scheduling"],
        current_args["scheduler_id"],
        True,  # return_img
    )
    assert output is not None
testExportClipModel(self): current_args = copy.deepcopy(default_arguments) - current_args["hf_model_name"] = "google/t5-v1_1-small" - safe_prefix = "t5_v1_1_small" - blob_name = clip.export_clip_model( - hf_model_name=current_args["hf_model_name"], - hf_auth_token=None, - compile_to="vmfb", - external_weights=None, - external_weight_path=None, - device="cpu", - target_triple=None, - max_alloc=None, - upload_ir=UPLOAD_IR, - ) - current_args["vmfb_path"] = safe_prefix + "_clip.vmfb" - turbine = clip_runner.run_clip( - current_args["rt_device"], - current_args["prompt"], - current_args["vmfb_path"], - current_args["hf_model_name"], - current_args["hf_auth_token"], - None, - ) - torch_output = clip_runner.run_torch_clip( - current_args["hf_model_name"], - current_args["hf_auth_token"], - current_args["prompt"], + current_args["hf_model_name"] = "CompVis/stable-diffusion-v1-4" + safe_prefix = utils.create_safe_name( + current_args["hf_model_name"].split("/")[-1], "clip" ) - err = utils.largest_error(torch_output, turbine[0]) - assert err < 9e-4 - if UPLOAD_IR: - new_blob_name = blob_name.split(".") - new_blob_name = new_blob_name[0] + "-pass.mlir" - turbine_tank.changeBlobName(blob_name, new_blob_name) - del current_args - del turbine - - def testExportClipVitLarge14(self): - current_args = copy.deepcopy(default_arguments) - current_args["hf_model_name"] = "openai/clip-vit-large-patch14" - safe_prefix = "clip_vit_large_patch14" blob_name = clip.export_clip_model( hf_model_name=current_args["hf_model_name"], - hf_auth_token=None, + max_length=current_args["max_length"], + precision=current_args["precision"], compile_to="vmfb", external_weights="safetensors", external_weight_path=safe_prefix + ".safetensors", device="cpu", - target_triple=None, - max_alloc=None, + target=current_args["iree_target_triple"], + exit_on_vmfb=False, upload_ir=UPLOAD_IR, ) current_args["external_weight_path"] = safe_prefix + ".safetensors" - current_args["vmfb_path"] = safe_prefix + 
"_clip.vmfb" - turbine = clip_runner.run_clip( - current_args["rt_device"], - current_args["prompt"], - current_args["vmfb_path"], - current_args["hf_model_name"], - current_args["hf_auth_token"], - current_args["external_weight_path"], - ) - torch_output = clip_runner.run_torch_clip( - current_args["hf_model_name"], - current_args["hf_auth_token"], - current_args["prompt"], - ) - err = utils.largest_error(torch_output, turbine[0]) - assert err < 9e-5 - if UPLOAD_IR: - new_blob_name = blob_name.split(".") - new_blob_name = new_blob_name[0] + "-pass.mlir" - turbine_tank.changeBlobName(blob_name, new_blob_name) - if platform.system() != "Windows": - os.remove(current_args["external_weight_path"]) - os.remove(current_args["vmfb_path"]) - del current_args - del turbine - - def testExportClipModel(self): - current_args = copy.deepcopy(default_arguments) - current_args["hf_model_name"] = "CompVis/stable-diffusion-v1-4" - blob_name = clip.export_clip_model( - # This is a public model, so no auth required - "CompVis/stable-diffusion-v1-4", - None, - "vmfb", - "safetensors", - "stable_diffusion_v1_4_clip.safetensors", - "cpu", - upload_ir=UPLOAD_IR, - ) - current_args["external_weight_path"] = "stable_diffusion_v1_4_clip.safetensors" - current_args["vmfb_path"] = "stable_diffusion_v1_4_clip.vmfb" + current_args["vmfb_path"] = blob_name turbine = clip_runner.run_clip( current_args["rt_device"], current_args["prompt"], @@ -199,24 +109,23 @@ def testExportClipModel(self): def testExportUnetModel(self): current_args = copy.deepcopy(default_arguments) blob_name = unet.export_unet_model( - unet_model, - "CompVis/stable-diffusion-v1-4", - current_args["batch_size"], - current_args["height"], - current_args["width"], - current_args["precision"], - current_args["max_length"], - None, - "vmfb", - "safetensors", - "stable_diffusion_v1_4_unet.safetensors", - "cpu", + hf_model_name=current_args["hf_model_name"], + batch_size=current_args["batch_size"], + height=current_args["height"], + 
width=current_args["width"], + precision=current_args["precision"], + max_length=current_args["max_length"], + compile_to="vmfb", + external_weights="safetensors", + external_weight_path="stable_diffusion_unet.safetensors", + device="cpu", + target=current_args["iree_target_triple"], upload_ir=UPLOAD_IR, ) - current_args["external_weight_path"] = "stable_diffusion_v1_4_unet.safetensors" - current_args["vmfb_path"] = "stable_diffusion_v1_4_unet.vmfb" + current_args["external_weight_path"] = "stable_diffusion_unet.safetensors" + current_args["vmfb_path"] = blob_name sample = torch.rand( - current_args["batch_size"], + current_args["batch_size"] * 2, current_args["in_channels"], current_args["height"] // 8, current_args["width"] // 8, @@ -225,9 +134,13 @@ def testExportUnetModel(self): timestep = torch.zeros(1, dtype=torch.float32) if current_args["hf_model_name"] == "CompVis/stable-diffusion-v1-4": - encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32) + encoder_hidden_states = torch.rand( + 2, current_args["max_length"], 768, dtype=torch.float32 + ) elif current_args["hf_model_name"] == "stabilityai/stable-diffusion-2-1-base": - encoder_hidden_states = torch.rand(2, 77, 1024, dtype=torch.float32) + encoder_hidden_states = torch.rand( + 2, current_args["max_length"], 1024, dtype=torch.float32 + ) guidance_scale = torch.tensor( [current_args["guidance_scale"]], dtype=torch.float32 ) @@ -242,6 +155,7 @@ def testExportUnetModel(self): current_args["hf_model_name"], current_args["hf_auth_token"], current_args["external_weight_path"], + "float32", ) torch_output = unet_runner.run_torch_unet( current_args["hf_model_name"], @@ -257,30 +171,29 @@ def testExportUnetModel(self): new_blob_name = blob_name.split(".") new_blob_name = new_blob_name[0] + "-pass.mlir" turbine_tank.changeBlobName(blob_name, new_blob_name) - os.remove("stable_diffusion_v1_4_unet.safetensors") - os.remove("stable_diffusion_v1_4_unet.vmfb") + os.remove("stable_diffusion_unet.safetensors") 
+ os.remove(blob_name) del torch_output del turbine def testExportVaeModelDecode(self): current_args = copy.deepcopy(default_arguments) blob_name = vae.export_vae_model( - vae_model, - # This is a public model, so no auth required - "CompVis/stable-diffusion-v1-4", - current_args["batch_size"], - current_args["height"], - current_args["width"], - None, - "vmfb", - "safetensors", - "stable_diffusion_v1_4_vae.safetensors", - "cpu", - variant="decode", + hf_model_name=current_args["hf_model_name"], + batch_size=current_args["batch_size"], + height=current_args["height"], + width=current_args["width"], + precision=current_args["precision"], + compile_to="vmfb", + external_weights="safetensors", + external_weight_path="stable_diffusion_v1_4_vae.safetensors", + device="cpu", + target=current_args["iree_target_triple"], + decomp_attn=current_args["vae_decomp_attn"], upload_ir=UPLOAD_IR, ) current_args["external_weight_path"] = "stable_diffusion_v1_4_vae.safetensors" - current_args["vmfb_path"] = "stable_diffusion_v1_4_vae.vmfb" + current_args["vmfb_path"] = blob_name example_input = torch.rand( current_args["batch_size"], 4, @@ -288,14 +201,14 @@ def testExportVaeModelDecode(self): current_args["width"] // 8, dtype=torch.float32, ) - turbine = vae_runner.run_vae( + turbine = vae_runner.run_vae_decode( current_args["rt_device"], example_input, current_args["vmfb_path"], current_args["hf_model_name"], current_args["external_weight_path"], ) - torch_output = vae_runner.run_torch_vae( + torch_output = vae_runner.run_torch_vae_decode( current_args["hf_model_name"], "decode", example_input, @@ -310,112 +223,54 @@ def testExportVaeModelDecode(self): del torch_output del turbine os.remove("stable_diffusion_v1_4_vae.safetensors") - os.remove("stable_diffusion_v1_4_vae.vmfb") + os.remove(blob_name) - def testExportVaeModelEncode(self): - current_args = copy.deepcopy(default_arguments) - blob_name = vae.export_vae_model( - vae_model, - # This is a public model, so no auth required - 
"CompVis/stable-diffusion-v1-4", - current_args["batch_size"], - current_args["height"], - current_args["width"], - current_args["precision"], - "vmfb", - "safetensors", - "stable_diffusion_v1_4_vae.safetensors", - "cpu", - variant="encode", - upload_ir=UPLOAD_IR, + def testSDPipeline(self): + from turbine_models.custom_models.sd_inference.sd_pipeline import ( + SharkSDPipeline, ) - current_args["external_weight_path"] = "stable_diffusion_v1_4_vae.safetensors" - current_args["vmfb_path"] = "stable_diffusion_v1_4_vae.vmfb" - example_input = torch.rand( - current_args["batch_size"], - 3, - current_args["height"], - current_args["width"], - dtype=torch.float32, - ) - turbine = vae_runner.run_vae( - current_args["rt_device"], - example_input, - current_args["vmfb_path"], - current_args["hf_model_name"], - current_args["external_weight_path"], - ) - torch_output = vae_runner.run_torch_vae( - current_args["hf_model_name"], - "encode", - example_input, - ) - err = utils.largest_error(torch_output, turbine) - assert err < 3e-3 - if UPLOAD_IR: - new_blob_name = blob_name.split(".") - new_blob_name = new_blob_name[0] + "-pass.mlir" - turbine_tank.changeBlobName(blob_name, new_blob_name) - os.remove("stable_diffusion_v1_4_vae.safetensors") - os.remove("stable_diffusion_v1_4_vae.vmfb") - del current_args - del turbine - @unittest.expectedFailure - def testExportPNDMScheduler(self): current_args = copy.deepcopy(default_arguments) - safe_name = "stable_diffusion_v1_4_scheduler" - blob_name = schedulers.export_scheduler( - scheduler_module, - # This is a public model, so no auth required - "CompVis/stable-diffusion-v1-4", - current_args["batch_size"], + decomp_attn = { + "text_encoder": False, + "unet": False, + "vae": current_args["vae_decomp_attn"], + } + sd_pipe = SharkSDPipeline( + current_args["hf_model_name"], current_args["height"], current_args["width"], - None, - "vmfb", - "safetensors", - "stable_diffusion_v1_4_scheduler.safetensors", - "cpu", - upload_ir=UPLOAD_IR, - ) 
- current_args["external_weight_path"] = safe_name + ".safetensors" - current_args["vmfb_path"] = safe_name + ".vmfb" - sample = torch.rand( current_args["batch_size"], - 4, - current_args["height"] // 8, - current_args["width"] // 8, - dtype=torch.float32, - ) - encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32) - turbine = schedulers_runner.run_scheduler( - current_args["rt_device"], - sample, - encoder_hidden_states, - current_args["vmfb_path"], - current_args["hf_model_name"], - current_args["hf_auth_token"], - current_args["external_weight_path"], + current_args["max_length"], + current_args["precision"], + current_args["device"], + current_args["iree_target_triple"], + ireec_flags=None, # ireec_flags + attn_spec=current_args["attn_spec"], + decomp_attn=decomp_attn, + pipeline_dir="test_vmfbs", # pipeline_dir + external_weights_dir="test_weights", # external_weights_dir + external_weights=current_args["external_weights"], + num_inference_steps=current_args["num_inference_steps"], + cpu_scheduling=True, + scheduler_id=current_args["scheduler_id"], + shift=None, # shift + use_i8_punet=current_args["use_i8_punet"], ) - torch_output = schedulers_runner.run_torch_scheduler( - current_args["hf_model_name"], - scheduler, + sd_pipe.prepare_all() + sd_pipe.load_map() + output = sd_pipe.generate_images( + current_args["prompt"], + current_args["negative_prompt"], current_args["num_inference_steps"], - sample, - encoder_hidden_states, + 1, # batch count + current_args["guidance_scale"], + current_args["seed"], + current_args["cpu_scheduling"], + current_args["scheduler_id"], + True, # return_img ) - err = utils.largest_error(torch_output, turbine) - assert err < 9e-3 - if UPLOAD_IR: - new_blob_name = blob_name.split(".") - new_blob_name = new_blob_name[0] + "-pass.mlir" - turbine_tank.changeBlobName(blob_name, new_blob_name) - os.remove("stable_diffusion_v1_4_scheduler.safetensors") - os.remove("stable_diffusion_v1_4_scheduler.vmfb") - del current_args - 
del torch_output - del turbine + assert output is not None if __name__ == "__main__": diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index a45fd7ca4..216b6ff59 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -7,13 +7,17 @@ import logging import pytest import torch +import shutil +from transformers import CLIPTokenizer from turbine_models.custom_models.sd_inference.utils import create_safe_name +from turbine_models.custom_models.sd_inference import schedulers, vae from turbine_models.custom_models.sdxl_inference import ( - clip, - clip_runner, + sdxl_prompt_encoder, + sdxl_prompt_encoder_runner, unet, unet_runner, - vae, + sdxl_scheduled_unet, + sdxl_scheduled_unet_runner, vae_runner, sdxl_compiled_pipeline, ) @@ -77,151 +81,83 @@ def command_line_args(request): class StableDiffusionXLTest(unittest.TestCase): def setUp(self): self.safe_model_name = create_safe_name(arguments["hf_model_name"], "") - self.unet_model = unet.UnetModel( - # This is a public model, so no auth required - arguments["hf_model_name"], - precision=arguments["precision"], - ) - self.vae_model = vae.VaeModel( - # This is a public model, so no auth required - arguments["hf_model_name"], - custom_vae=( - "madebyollin/sdxl-vae-fp16-fix" - if arguments["precision"] == "fp16" - else None - ), - ) - def test01_ExportClipModels(self): + def test01_ExportPromptEncoder(self): if arguments["device"] in ["vulkan", "cuda", "rocm"]: self.skipTest( - "Compilation error on vulkan; Runtime error on rocm; To be tested on cuda." + "Compilation error on vulkan; recent numerics regression (nans) on hip driver, To be tested on cuda." 
) - clip.export_clip_model( - # This is a public model, so no auth required - hf_model_name=arguments["hf_model_name"], - hf_auth_token=None, - max_length=arguments["max_length"], - precision=arguments["precision"], - compile_to="vmfb", - external_weights=arguments["external_weights"], - external_weight_path=self.safe_model_name - + "_" - + arguments["precision"] - + "_clip", - device=arguments["device"], - target_triple=arguments["iree_target_triple"], - ireec_flags=arguments["ireec_flags"], - index=1, - exit_on_vmfb=True, + arguments["external_weight_path"] = ( + "prompt_encoder." + arguments["external_weights"] ) - clip.export_clip_model( - hf_model_name=arguments["hf_model_name"], - hf_auth_token=None, # This is a public model, so no auth required + prompt_encoder_vmfb = sdxl_prompt_encoder.export_prompt_encoder( + arguments["hf_model_name"], + hf_auth_token=None, max_length=arguments["max_length"], + batch_size=arguments["batch_size"], precision=arguments["precision"], compile_to="vmfb", - external_weights=arguments["external_weights"], - external_weight_path=self.safe_model_name - + "_" - + arguments["precision"] - + "_clip", + external_weights="safetensors", + external_weight_path=arguments["external_weight_path"], device=arguments["device"], - target_triple=arguments["iree_target_triple"], - ireec_flags=arguments["ireec_flags"], - index=2, - exit_on_vmfb=True, + target=arguments["iree_target_triple"], ) - arguments["external_weight_path_1"] = ( - self.safe_model_name - + "_" - + arguments["precision"] - + "_clip_1." - + arguments["external_weights"] - ) - arguments["external_weight_path_2"] = ( - self.safe_model_name - + "_" - + arguments["precision"] - + "_clip_2." 
- + arguments["external_weights"] - ) - arguments["vmfb_path_1"] = ( - self.safe_model_name - + "_" - + str(arguments["max_length"]) - + "_" - + arguments["precision"] - + "_clip_1_" - + arguments["device"] - + ".vmfb" - ) - arguments["vmfb_path_2"] = ( - self.safe_model_name - + "_" - + str(arguments["max_length"]) - + "_" - + arguments["precision"] - + "_clip_2_" - + arguments["device"] - + ".vmfb" + tokenizer_1 = CLIPTokenizer.from_pretrained( + arguments["hf_model_name"], + subfolder="tokenizer", + token=arguments["hf_auth_token"], ) - turbine_1 = clip_runner.run_clip( - arguments["rt_device"], - arguments["prompt"], - arguments["vmfb_path_1"], + tokenizer_2 = CLIPTokenizer.from_pretrained( arguments["hf_model_name"], - arguments["hf_auth_token"], - arguments["external_weight_path_1"], + subfolder="tokenizer_2", + token=arguments["hf_auth_token"], + ) + ( + text_input_ids_list, + uncond_input_ids_list, + ) = sdxl_prompt_encoder_runner.run_tokenize( + tokenizer_1, + tokenizer_2, + arguments["prompt"], + arguments["negative_prompt"], arguments["max_length"], - index=1, ) - turbine_2 = clip_runner.run_clip( + ( + turbine_output1, + turbine_output2, + ) = sdxl_prompt_encoder_runner.run_prompt_encoder( + prompt_encoder_vmfb, arguments["rt_device"], - arguments["prompt"], - arguments["vmfb_path_2"], - arguments["hf_model_name"], - arguments["hf_auth_token"], - arguments["external_weight_path_2"], - arguments["max_length"], - index=2, + arguments["external_weight_path"], + text_input_ids_list, + uncond_input_ids_list, ) - torch_output_1, torch_output_2 = clip_runner.run_torch_clip( + torch_model = sdxl_prompt_encoder.PromptEncoderModule( arguments["hf_model_name"], + arguments["precision"], arguments["hf_auth_token"], - arguments["prompt"], - arguments["max_length"], + ) + torch_output1, torch_output2 = torch_model.forward( + *text_input_ids_list, *uncond_input_ids_list ) if arguments["benchmark"] or arguments["tracy_profile"]: run_benchmark( - "clip_1", - 
arguments["vmfb_path_1"], - arguments["external_weight_path_1"], - arguments["rt_device"], - max_length=arguments["max_length"], - tracy_profile=arguments["tracy_profile"], - ) - run_benchmark( - "clip_2", - arguments["vmfb_path_2"], - arguments["external_weight_path_2"], + "prompt_encoder", + prompt_encoder_vmfb, + arguments["external_weight_path"], arguments["rt_device"], max_length=arguments["max_length"], tracy_profile=arguments["tracy_profile"], ) rtol = 4e-1 atol = 4e-1 - np.testing.assert_allclose(torch_output_1, turbine_1[0], rtol, atol) - np.testing.assert_allclose(torch_output_2, turbine_2[0], rtol, atol) + np.testing.assert_allclose(torch_output1, turbine_output1, rtol, atol) + np.testing.assert_allclose(torch_output2, turbine_output2, rtol, atol) def test02_ExportUnetModel(self): - if arguments["device"] in ["vulkan", "cuda", "rocm"]: - self.skipTest( - "Unknown error on vulkan; Runtime error on rocm; To be tested on cuda." - ) - unet.export_unet_model( - unet_model=self.unet_model, - # This is a public model, so no auth required + if arguments["device"] in ["vulkan", "cuda"]: + self.skipTest("Unknown error on vulkan; To be tested on cuda.") + unet_vmfb = unet.export_unet_model( hf_model_name=arguments["hf_model_name"], batch_size=arguments["batch_size"], height=arguments["height"], @@ -237,9 +173,11 @@ def test02_ExportUnetModel(self): + "_unet." + arguments["external_weights"], device=arguments["device"], - target_triple=arguments["iree_target_triple"], + target=arguments["iree_target_triple"], ireec_flags=arguments["ireec_flags"], decomp_attn=arguments["decomp_attn"], + attn_spec=arguments["attn_spec"], + exit_on_vmfb=False, ) arguments["external_weight_path"] = ( self.safe_model_name @@ -248,20 +186,7 @@ def test02_ExportUnetModel(self): + "_unet." 
+ arguments["external_weights"] ) - arguments["vmfb_path"] = ( - self.safe_model_name - + "_" - + str(arguments["max_length"]) - + "_" - + str(arguments["height"]) - + "x" - + str(arguments["width"]) - + "_" - + arguments["precision"] - + "_unet_" - + arguments["device"] - + ".vmfb" - ) + arguments["vmfb_path"] = unet_vmfb dtype = torch.float16 if arguments["precision"] == "fp16" else torch.float32 sample = torch.rand( ( @@ -272,7 +197,7 @@ def test02_ExportUnetModel(self): ), dtype=dtype, ) - timestep = torch.zeros(1, dtype=torch.int64) + timestep = torch.zeros(1, dtype=dtype) prompt_embeds = torch.rand( (2 * arguments["batch_size"], arguments["max_length"], 2048), dtype=dtype, @@ -320,18 +245,14 @@ def test02_ExportUnetModel(self): tracy_profile=arguments["tracy_profile"], ) rtol = 4e-2 - atol = 4e-2 + atol = 4e-1 np.testing.assert_allclose(torch_output, turbine, rtol, atol) def test03_ExportVaeModelDecode(self): - if arguments["device"] in ["vulkan", "cuda", "rocm"]: - self.skipTest( - "Compilation error on vulkan; Runtime error on rocm; To be tested on cuda." - ) - vae.export_vae_model( - vae_model=self.vae_model, - # This is a public model, so no auth required + if arguments["device"] in ["vulkan", "cuda"]: + self.skipTest("Compilation error on vulkan; To be tested on cuda.") + vae_vmfb = vae.export_vae_model( hf_model_name=arguments["hf_model_name"], batch_size=arguments["batch_size"], height=arguments["height"], @@ -345,11 +266,11 @@ def test03_ExportVaeModelDecode(self): + "_vae_decode." 
+ arguments["external_weights"], device=arguments["device"], - target_triple=arguments["iree_target_triple"], + target=arguments["iree_target_triple"], ireec_flags=arguments["ireec_flags"], - variant="decode", - decomp_attn=arguments["decomp_attn"], - exit_on_vmfb=True, + decomp_attn=True, + attn_spec=arguments["attn_spec"], + exit_on_vmfb=False, ) arguments["external_weight_path"] = ( self.safe_model_name @@ -358,18 +279,7 @@ def test03_ExportVaeModelDecode(self): + "_vae_decode." + arguments["external_weights"] ) - arguments["vmfb_path"] = ( - self.safe_model_name - + "_" - + str(arguments["height"]) - + "x" - + str(arguments["width"]) - + "_" - + arguments["precision"] - + "_vae_decode_" - + arguments["device"] - + ".vmfb" - ) + arguments["vmfb_path"] = vae_vmfb example_input = torch.ones( arguments["batch_size"], 4, @@ -409,7 +319,7 @@ def test03_ExportVaeModelDecode(self): tracy_profile=arguments["tracy_profile"], ) rtol = 4e-2 - atol = 4e-2 + atol = 4e-1 np.testing.assert_allclose(torch_output, turbine, rtol, atol) @@ -418,7 +328,7 @@ def test04_ExportVaeModelEncode(self): self.skipTest( "Compilation error on cpu, vulkan and rocm; To be tested on cuda." ) - vae.export_vae_model( + vae_vmfb = vae.export_vae_model( vae_model=self.vae_model, # This is a public model, so no auth required hf_model_name=arguments["hf_model_name"], @@ -434,10 +344,9 @@ def test04_ExportVaeModelEncode(self): + "_vae_encode." + arguments["external_weights"], device=arguments["device"], - target_triple=arguments["iree_target_triple"], + target=arguments["iree_target_triple"], ireec_flags=arguments["ireec_flags"], - variant="encode", - decomp_attn=arguments["decomp_attn"], + decomp_attn=True, exit_on_vmfb=True, ) arguments["external_weight_path"] = ( @@ -447,18 +356,7 @@ def test04_ExportVaeModelEncode(self): + "_vae_encode." 
+ arguments["external_weights"] ) - arguments["vmfb_path"] = ( - self.safe_model_name - + "_" - + str(arguments["height"]) - + "x" - + str(arguments["width"]) - + "_" - + arguments["precision"] - + "_vae_encode_" - + arguments["device"] - + ".vmfb" - ) + arguments["vmfb_path"] = vae_vmfb example_input = torch.ones( arguments["batch_size"], 3, @@ -502,98 +400,103 @@ def test04_ExportVaeModelEncode(self): np.testing.assert_allclose(torch_output, turbine, rtol, atol) def test05_t2i_generate_images(self): - if arguments["device"] in ["vulkan", "cuda", "rocm"]: - self.skipTest("Have issues with submodels on these backends") - mlirs = { - "vae_decode": None, - "prompt_encoder": None, - "scheduled_unet": None, - "pipeline": None, - "full_pipeline": None, - } - vmfbs = { - "vae_decode": None, - "prompt_encoder": None, - "scheduled_unet": None, - "pipeline": None, - "full_pipeline": None, - } - weights = { - "vae_decode": None, - "prompt_encoder": None, - "scheduled_unet": None, - "pipeline": None, - "full_pipeline": None, - } + if arguments["device"] in ["vulkan", "cuda"]: + self.skipTest("Have issues with submodels on vulkan, cuda") + from turbine_models.custom_models.sd_inference.sd_pipeline import ( + SharkSDPipeline, + ) - if not arguments["pipeline_dir"]: - pipe_id_list = [ - "sdxl_1_0", - str(arguments["height"]), - str(arguments["width"]), - str(arguments["max_length"]), - arguments["precision"], - arguments["device"], - ] - arguments["pipeline_dir"] = os.path.join( - ".", - "_".join(pipe_id_list), - ) - ireec_flags = { - "unet": arguments["ireec_flags"], - "vae": arguments["ireec_flags"], - "clip": arguments["ireec_flags"], - "pipeline": arguments["ireec_flags"], + decomp_attn = { + "text_encoder": False, + "unet": False, + "vae": True, } - user_mlir_list = [] - for submodel_id, mlir_path in zip(mlirs.keys(), user_mlir_list): - if submodel_id in mlir_path: - mlirs[submodel_id] = mlir_path - external_weights_dir = arguments["pipeline_dir"] - sdxl_pipe = 
sdxl_compiled_pipeline.SharkSDXLPipeline( + sd_pipe = SharkSDPipeline( arguments["hf_model_name"], - arguments["scheduler_id"], arguments["height"], arguments["width"], - arguments["precision"], - arguments["max_length"], arguments["batch_size"], - arguments["num_inference_steps"], + arguments["max_length"], + arguments["precision"], arguments["device"], arguments["iree_target_triple"], - ireec_flags, - arguments["attn_spec"], - arguments["decomp_attn"], - arguments["pipeline_dir"], - external_weights_dir, - arguments["external_weights"], - ) - vmfbs, weights = sdxl_pipe.check_prepared( - mlirs, vmfbs, weights, interactive=False - ) - sdxl_pipe.load_pipeline( - vmfbs, weights, arguments["rt_device"], arguments["compiled_pipeline"] - ) - sdxl_pipe.generate_images( + ireec_flags=None, # ireec_flags + attn_spec=arguments["attn_spec"], + decomp_attn=decomp_attn, + pipeline_dir="test_vmfbs", # pipeline_dir + external_weights_dir="test_weights", # external_weights_dir + external_weights=arguments["external_weights"], + num_inference_steps=arguments["num_inference_steps"], + cpu_scheduling=True, + scheduler_id=arguments["scheduler_id"], + shift=None, # shift + use_i8_punet=False, + ) + sd_pipe.prepare_all() + sd_pipe.load_map() + output = sd_pipe.generate_images( arguments["prompt"], arguments["negative_prompt"], - 1, + arguments["num_inference_steps"], + 1, # batch count arguments["guidance_scale"], arguments["seed"], + True, + arguments["scheduler_id"], + True, # return_img ) - print("Image generation complete.") - os.remove(os.path.join(arguments["pipeline_dir"], "prompt_encoder.vmfb")) - os.remove( - os.path.join( - arguments["pipeline_dir"], - arguments["scheduler_id"] - + "_unet_" - + str(arguments["num_inference_steps"]) - + ".vmfb", + assert output is not None + + @pytest.mark.skip(reason="Needs sdxl_quantized branch of IREE") + def test06_t2i_generate_images_punet(self): + if arguments["device"] in ["vulkan", "cuda", "rocm"]: + self.skipTest( + "Have issues with 
submodels on vulkan, cuda; ROCM hangs on mi250 despite submodels working." ) + from turbine_models.custom_models.sd_inference.sd_pipeline import ( + SharkSDPipeline, + ) + + decomp_attn = { + "text_encoder": False, + "unet": False, + "vae": True, + } + sd_pipe = SharkSDPipeline( + arguments["hf_model_name"], + arguments["height"], + arguments["width"], + arguments["batch_size"], + arguments["max_length"], + arguments["precision"], + arguments["device"], + arguments["iree_target_triple"], + ireec_flags=None, # ireec_flags + attn_spec=arguments["attn_spec"], + decomp_attn=decomp_attn, + pipeline_dir="test_vmfbs", # pipeline_dir + external_weights_dir="test_weights", # external_weights_dir + external_weights=arguments["external_weights"], + num_inference_steps=arguments["num_inference_steps"], + cpu_scheduling=True, + scheduler_id=arguments["scheduler_id"], + shift=None, # shift + use_i8_punet=True, + ) + sd_pipe.prepare_all() + sd_pipe.load_map() + output = sd_pipe.generate_images( + arguments["prompt"], + arguments["negative_prompt"], + arguments["num_inference_steps"], + 1, # batch count + arguments["guidance_scale"], + arguments["seed"], + True, + arguments["scheduler_id"], + True, # return_img ) - os.remove(os.path.join(arguments["pipeline_dir"], "vae_decode.vmfb")) - os.remove(os.path.join(arguments["pipeline_dir"], "full_pipeline.vmfb")) + assert output is not None if __name__ == "__main__": diff --git a/models/turbine_models/tests/stateless_llama_test.py b/models/turbine_models/tests/stateless_llama_test.py index 884caa575..4b1ffef73 100644 --- a/models/turbine_models/tests/stateless_llama_test.py +++ b/models/turbine_models/tests/stateless_llama_test.py @@ -139,6 +139,9 @@ def test_vmfb_comparison(self): new_blob_name = new_blob_name[0] + "-pass.mlir" turbine_tank.changeBlobName(blob_name, new_blob_name) + # See: https://github.com/nod-ai/SHARK-Turbine/issues/601 + # Developed issues related to the pytorch 2.3 upgrade. 
+ @unittest.expectedFailure def test_streaming_vmfb_comparison(self): """ Similar test to above but for streaming-LLM. diff --git a/models/turbine_models/utils/sdxl_benchmark.py b/models/turbine_models/utils/sdxl_benchmark.py index 1c37f93a1..decc2d940 100644 --- a/models/turbine_models/utils/sdxl_benchmark.py +++ b/models/turbine_models/utils/sdxl_benchmark.py @@ -41,6 +41,8 @@ def run_benchmark( inputs.append(f"1x{max_length}xi64") case "clip_2": inputs.append(f"1x{max_length}xi64") + case "prompt_encoder": + inputs.extend([f"1x{max_length}xi64"] * 4) case "unet": inputs.append( f"{batch_size}x{in_channels}x{height//8}x{width//8}x{DTYPE_MAP[precision]}"