From ade5b3a58520e10de83d4d313223a6d168c65473 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Mon, 1 Jul 2024 20:56:42 -0500 Subject: [PATCH 01/12] Adds one SD pipeline to rule them all --- .../custom_models/pipeline_base.py | 330 ++++-- .../custom_models/sd3_inference/sd3_mmdit.py | 20 +- .../sd3_inference/sd3_schedulers.py | 30 +- .../sd3_inference/sd3_text_encoders.py | 22 +- .../sd3_inference/sd3_vae_runner.py | 4 +- .../custom_models/sd_inference/clip.py | 118 +-- .../custom_models/sd_inference/schedulers.py | 57 +- .../custom_models/sd_inference/sd_cmd_opts.py | 19 +- .../custom_models/sd_inference/sd_pipeline.py | 967 +++++++++--------- .../sd_inference/tokenization.py | 552 +++------- .../custom_models/sd_inference/unet.py | 141 +-- .../custom_models/sd_inference/utils.py | 4 +- .../custom_models/sd_inference/vae.py | 163 ++- .../sdxl_inference/sdxl_prompt_encoder.py | 36 +- .../custom_models/sdxl_inference/unet.py | 56 +- .../sdxl_inference/vae_runner.py | 44 +- models/turbine_models/tests/pipeline_test.py | 11 +- 17 files changed, 1303 insertions(+), 1271 deletions(-) diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py index 24973e548..98601ba76 100644 --- a/models/turbine_models/custom_models/pipeline_base.py +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -6,6 +6,8 @@ import logging import torch +import ast +from collections.abc import Iterable import iree.runtime as ireert from turbine_models.custom_models.sd_inference import utils, schedulers @@ -23,10 +25,24 @@ import copy from datetime import datetime as dt +np_dtypes = { + "fp16": np.float16, + "fp32": np.float32, + "float16": np.float16, + "float32": np.float32, +} +torch_dtypes = { + "fp16": torch.float16, + "fp32": torch.float32, + "float16": torch.float16, + "float32": torch.float32, +} def merge_arg_into_map(model_map, arg, arg_name): if isinstance(arg, dict): for key in arg.keys(): + if key not in model_map.keys(): + continue if not model_map[key].get(arg_name): model_map[key][arg_name] = arg[key] else: @@ -35,6 +51,26 @@ def merge_arg_into_map(model_map, arg, arg_name): model_map[key][arg_name] = arg return model_map +def merge_export_arg(model_map, arg, arg_name): + if isinstance(arg, dict): + for key in arg.keys(): + if key not in model_map.keys(): + continue + if arg_name not in model_map[key].get("export_args", {}): + model_map[key]["export_args"][arg_name] = arg[key] + else: + for key in model_map.keys(): + if not model_map[key].get("export_args", {}).get(arg_name): + continue + model_map[key]["export_args"][arg_name] = arg + return model_map + + +# def str_to_list(string): +# out = string.strip("[]").replace(" ", "").split(";") +# for item in out: +# item = ast.literal_eval(item) +# return out class PipelineComponent: """ @@ -44,14 +80,14 @@ class PipelineComponent: This aims to make new pipelines and execution modes easier to write, manage, and debug. """ - def __init__(self, dest_type=ireert.DeviceArray, dest_dtype="float16"): + def __init__(self, dest_type="devicearray", dest_dtype="float16"): self.runner = None self.module_name = None self.device = None self.metadata = None self.benchmark = False - self.output_type = dest_type - self.output_dtype = dest_dtype + self.dest_type = dest_type + self.dest_dtype = dest_dtype def load( self, @@ -62,25 +98,104 @@ def load( extra_plugin=None, ): self.module_name = module_name + print(f"Loading {module_name} from {vmfb_path} with external weights: {external_weight_path}.") self.runner = vmfbRunner( rt_device, vmfb_path, external_weight_path, extra_plugin ) self.device = self.runner.config.device self.module = getattr(self.runner.ctx.modules, module_name) - self.metadata = None + self.get_metadata() def unload(self): self.device = None self.runner = None gc.collect() - def get_metadata(self, function_name): - if not self.metadata: - self.metadata = self.module[function_name].vm_function.reflection - return self.metadata - + def get_metadata(self): + self.metadata = {} + for function_name in self.module.vm_module.function_names: + if any(x in function_name for x in ["$async", "__init"]): + continue + try: + self.metadata[function_name] = self.module[function_name].vm_function.reflection + except: + logging.warning( + f"Could not get metadata for {self.module_name}['{function_name}']." + ) + self.metadata[function_name] = None + + def _validate_or_convert_inputs(self, function_name, inputs): + if self.metadata: + expected_input_shapes = self.metadata.get(function_name, {}).get("input_shapes") + if expected_input_shapes: + expected_input_shapes = ast.literal_eval(expected_input_shapes) + expected_input_dtypes = self.metadata.get(function_name, {}).get("input_dtypes", "") + if expected_input_dtypes: + expected_input_dtypes = ast.literal_eval(expected_input_dtypes) + if not isinstance(expected_input_shapes, list): + expected_input_shapes = [expected_input_shapes] + if not expected_input_dtypes: + pass + if not expected_input_shapes: + logging.warning( + f"No input shapes found for {self.module_name}['{function_name}']." + ) + for i in inputs: + if not isinstance(i, ireert.DeviceArray): + i = ireert.asdevicearray(self.device, i) + pass + for i, input_dtype in enumerate(expected_input_dtypes): + if not isinstance(inputs[i], ireert.DeviceArray): + if isinstance(inputs[i], torch.Tensor) or isinstance(inputs[i], torch.HalfTensor): + new_input = inputs[i].float().cpu().numpy() + else: + new_input = inputs[i] + + inputs[i] = ireert.asdevicearray( + self.device, new_input, input_dtype + ) + if str(inputs[i].dtype).split(".")[-1] != input_dtype: + logging.warning( + f"Converting input {i} to {input_dtype} for {self.module_name}['{function_name}']." + ) + inputs[i] = inputs[i].astype(input_dtype) + for i, input_shape in enumerate(expected_input_shapes): + if isinstance(input_shape, str): + input_shape = ast.literal_eval(input_shape) + elif not input_shape: + continue + if tuple(inputs[i].shape) != tuple(input_shape): + raise ValueError( + f"Expected input {i} to be of shape {input_shape} for {self.module_name}['{function_name}'], got {str(tuple(inputs[i].shape))}." + ) + else: + logging.warning( + f"No metadata found for {self.module_name}['{function_name}']." + ) + for i in inputs: + if not isinstance(i, ireert.DeviceArray): + i = ireert.asdevicearray(self.device, i) + + def _output_cast(self, output): + if isinstance(output, tuple): + out_tuple = () + for array in output: + array_out = self._output_cast(array) + out_tuple += (array_out,) + return out_tuple + match self.dest_type: + case "devicearray": + output = output.astype(self.dest_dtype) if output.dtype != self.dest_dtype else output + return output + case "torch": + output = torch.tensor(output.to_host(), dtype=torch_dtypes[self.dest_dtype]) + return output + case "numpy": + return output.to_host().astype(np_dtypes[self.dest_dtype]) + case _: + return output + def _run(self, function_name, inputs: list): - print(inputs) return self.module[function_name](*inputs) def _run_and_benchmark(self, function_name, inputs: list): @@ -92,26 +207,17 @@ def _run_and_benchmark(self, function_name, inputs: list): def __call__(self, function_name, inputs: list): casted_output = False + self._validate_or_convert_inputs(function_name, inputs) if not isinstance(inputs, list): inputs = [inputs] if self.benchmark: output = self._run_and_benchmark(function_name, inputs) else: output = self._run(function_name, inputs) - if output.dtype != self.output_dtype: - casted_output = True - output = output.astype(self.output_dtype) - match self.output_type: - case ireert.DeviceArray: - if casted_output: - output = ireert.asdevicearray( - self.device, output, self.output_dtype - ) - return output - case torch.Tensor: - return torch.tensor(output.to_host()) - case np.ndarray: - return output.to_host() + print("Output before cast: ", output) + output = self._output_cast(output) + print("Output after cast: ", output) + return output class TurbinePipelineBase: @@ -155,7 +261,7 @@ class TurbinePipelineBase: device: str | dict[str] Either a string i.e. "rocm://0", or a dictionary of such with keys matching the submodels of a given pipeline. If a string, a dictionary will be created based on the pipeline's model map and the same device will be used for all submodels. - iree_target_triple: str | dict[str] + target: str | dict[str] Either a string i.e. "gfx1100", or a dictionary with keys matching the submodels of a given pipeline. ireec_flags: str | dict[str] A comma-separated string of flags to pass to the IREE compiler, or a dict of them with keys matching submodels of a given pipeline. @@ -164,9 +270,8 @@ class TurbinePipelineBase: def __init__( self, model_map: dict, - batch_size: int, device: str | dict[str], - iree_target_triple: str | dict[str], + target: str | dict[str], ireec_flags: str | dict[str] = None, precision: str | dict[str] = "fp16", td_spec: str | dict[str] = None, @@ -174,55 +279,71 @@ def __init__( external_weights: str | dict[str] = None, pipeline_dir: str = "./shark_vmfbs", external_weights_dir: str = "./shark_weights", + hf_model_name: str | dict[str] = None, + common_export_args: dict = {}, ): self.map = model_map - self.batch_size = batch_size if isinstance(device, dict): assert isinstance( - iree_target_triple, dict + target, dict ), "Device and target triple must be both dicts or both strings." for submodel in self.map.keys(): assert submodel in device.keys(), f"Device for {submodel} not found." assert ( - submodel in iree_target_triple.keys() + submodel in target.keys() ), f"Target arch for {submodel} not found." self.map[submodel]["device"] = device[submodel] self.map[submodel]["driver"] = utils.iree_device_map(device[submodel]) - self.map[submodel]["target"] = iree_target_triple[submodel] + self.map[submodel]["target"] = target[submodel] else: assert isinstance( - iree_target_triple, str + target, str ), "Device and target triple must be both dicts or both strings." for submodel in self.map.keys(): self.map[submodel]["device"] = device self.map[submodel]["driver"] = utils.iree_device_map(device) - self.map[submodel]["target"] = iree_target_triple + self.map[submodel]["target"] = target map_arguments = { "ireec_flags": ireec_flags, "precision": precision, "td_spec": td_spec, "decomp_attn": decomp_attn, "external_weights": external_weights, + "hf_model_name": hf_model_name, } + print(map_arguments) for arg in map_arguments.keys(): self.map = merge_arg_into_map(self.map, map_arguments[arg], arg) - np_dtypes = { - "fp16": np.float16, - "fp32": np.float32, - } - torch_dtypes = { - "fp16": torch.float16, - "fp32": torch.float32, - } + + self.map = merge_arg_into_map( + self.map, np_dtypes[self.map[submodel]["precision"]], "np_dtype" + ) + self.map = merge_arg_into_map( + self.map, torch_dtypes[self.map[submodel]["precision"]], "torch_dtype" + ) + for arg in common_export_args.keys(): + for submodel in self.map.keys(): + self.map[submodel].get("export_args", {})[arg] = self.map[submodel].get(arg, common_export_args[arg]) for submodel in self.map.keys(): - self.map = merge_arg_into_map( - self.map, np_dtypes[self.map[submodel]["precision"]], "np_dtype" - ) - self.map = merge_arg_into_map( - self.map, torch_dtypes[self.map[submodel]["precision"]], "torch_dtype" - ) - print(self.map) - + for key, value in map_arguments.items(): + self.map = merge_export_arg(self.map, value, key) + for key, value in self.map[submodel].get("export_args", {}).items(): + if key == "hf_model_name": + self.map[submodel]["keywords"].append(utils.create_safe_name(value.split("/")[-1], "")) + if key == "decomp_attn": + if not value: + self.map[submodel]["keywords"].append("!decomp_attn") + else: + self.map[submodel]["keywords"].append("decomp_attn") + elif key == "batch_size": + self.map[submodel]["keywords"].append(f"bs{value}") + elif key in ["height"]: + dims = f"{self.map[submodel]['export_args']['width']}x{self.map[submodel]['export_args']['height']}" + self.map[submodel]["keywords"].append(dims) + elif key in ["max_length", "precision"]: + self.map[submodel]["keywords"].append(str(value)) + + self.pipeline_dir = pipeline_dir if not os.path.exists(self.pipeline_dir): os.makedirs(self.pipeline_dir) @@ -266,11 +387,13 @@ def prepare_all( for submodel in self.map.keys(): if not self.map[submodel].get("vmfb"): print("Fetching: ", submodel) - self.export_submodel(submodel, input_mlir=mlirs) - if not self.map[submodel]["external_weights"]: + self.export_submodel(submodel, input_mlir=self.map[submodel].get("mlir")) + if not self.map[submodel]["export_args"]["external_weights"]: assert not self.map[submodel].get( "weights" ), f"External weights should not be used for a model with inlined params." + if not self.map[submodel].get("weights") and self.map[submodel]["export_args"].get("external_weights"): + self.export_submodel(submodel, weights_only=True) return self.prepare_all(mlirs, vmfbs, weights, interactive) def is_prepared(self, vmfbs, weights): @@ -288,20 +411,33 @@ def is_prepared(self, vmfbs, weights): continue # search self.pipeline_dir for key-specific vmfb keywords = self.map[key].get("keywords", []) + mlir_keywords = copy.deepcopy(keywords) + mlir_keywords.extend( + [ + "mlir", + self.map[key]["precision"], + ] + ) keywords.extend( [ - self.map[key]["safe_name"], "vmfb", - "bs" + str(self.batch_size), self.map[key]["target"], self.map[key]["precision"], ] ) + print(keywords) + neg_keywords = [] + for kw in keywords: + if kw.startswith("!"): + neg_keywords.append(kw.strip("!")) + keywords.remove(kw) avail_files = os.listdir(pipeline_dir) candidates = [] for filename in avail_files: - if all(str(x) in filename for x in keywords): + if all(str(x) in filename for x in keywords) and not any(x in filename for x in neg_keywords): candidates.append(os.path.join(pipeline_dir, filename)) + if all(str(x) in filename for x in mlir_keywords) and not any(x in filename for x in neg_keywords): + self.map[key]["mlir"] = os.path.join(pipeline_dir, filename) if len(candidates) == 1: self.map[key]["vmfb"] = candidates[0] elif len(candidates) > 1: @@ -313,8 +449,8 @@ def is_prepared(self, vmfbs, weights): missing[key].append("vmfb") # Make sure vmfb needs external weights, as they may be inlined. - if self.map[key].get("external_weights"): - if self.map[key]["external_weights"]: + if self.map[key].get("export_args", {}).get("external_weights"): + if not self.map[key]["external_weights"]: continue if self.map[key].get("weights"): # weights already found in model map @@ -325,10 +461,9 @@ def is_prepared(self, vmfbs, weights): continue # search self.external_weights_dir for key-specific weights w_keywords = [ - self.map[key]["safe_name"], - self.map[key]["precision"], - self.map[key]["external_weights"], + self.map[key]["export_args"]["external_weight_path"], ] + avail_files = os.listdir(self.external_weights_dir) candidates = [] for filename in avail_files: @@ -338,17 +473,20 @@ def is_prepared(self, vmfbs, weights): ) if len(candidates) == 1: self.map[key]["weights"] = candidates[0] + self.map[key]["export_args"]["external_weight_path"] = None elif len(candidates) > 1: print(f"Multiple weight files found for {key}: {candidates}") print(f"Choosing {candidates[0]} for {key}.") self.map[key][weights] = candidates[0] - else: + self.map[key]["export_args"]["external_weight_path"] = None + elif self.map[key].get("external_weights"): # weights not found in external_weights_dir. Add to list of files to generate. missing[key].append("weights") if not any(x for x in missing.values()): ready = True else: print("Missing files: ", missing) + ready = False return ready def get_mlir_from_turbine_tank(self, submodel, container_name): @@ -379,9 +517,10 @@ def export_submodel( if not os.path.exists(self.external_weights_dir): os.makedirs(self.external_weights_dir, exist_ok=False) - self.map[submodel]["weights"] = os.path.join( + self.map[submodel]["export_args"]["external_weight_path"] = os.path.join( self.external_weights_dir, - f"{submodel}_{self.map[submodel]['precision']}." + utils.create_safe_name(self.map[submodel]["export_args"].get("hf_model_name", ""), "") + + f"_{submodel}_{self.map[submodel]['precision']}." + self.map[submodel]["external_weights"], ) @@ -404,31 +543,33 @@ def export_submodel( input_mlir = None else: input_mlir = None - self.map[submodel]["mlir"] = input_mlir + self.map[submodel]["export_args"]["input_mlir"] = input_mlir match submodel: case "unetloop": # SDXL ONLY FOR NOW pipeline_file = get_pipeline_ir( - self.width, - self.height, - self.precision, - self.batch_size, - self.max_length, + self.map[submodel]["export_args"]["width"], + self.map[submodel]["export_args"]["height"], + self.map[submodel]["export_args"]["precision"], + self.map[submodel]["export_args"]["batch_size"], + self.map[submodel]["export_args"]["max_length"], "unet_loop", ) + dims = [self.map[submodel]["export_args"]["width"], self.map[submodel]["export_args"]["height"]] + dims = "x".join([str(x) for x in dims]) pipeline_keys = [ - utils.create_safe_name(self.hf_model_name.split("/")[-1], ""), - "bs" + str(self.batch_size), - f"{str(self.width)}x{str(self.height)}", - self.precision, - str(self.max_length), + utils.create_safe_name(self.map[submodel]["export_args"]["hf_model_name"].split("/")[-1], ""), + "bs" + str(self.map[submodel]["export_args"]["batch_size"]), + dims, + self.map[submodel]["export_args"]["precision"], + str(self.map[submodel]["export_args"]["max_length"]), "unetloop", ] vmfb_path = utils.compile_to_vmfb( pipeline_file, self.map["unet"]["device"], self.map["unet"]["target"], - self.ireec_flags["pipeline"], + None, os.path.join(self.pipeline_dir, "_".join(pipeline_keys)), return_path=True, mlir_source="str", @@ -437,26 +578,28 @@ def export_submodel( self.map[submodel]["weights"] = None case "fullpipeline": # SDXL ONLY FOR NOW pipeline_file = get_pipeline_ir( - self.width, - self.height, - self.precision, - self.batch_size, - self.max_length, + self.map[submodel]["export_args"]["width"], + self.map[submodel]["export_args"]["height"], + self.map[submodel]["export_args"]["precision"], + self.map[submodel]["export_args"]["batch_size"], + self.map[submodel]["export_args"]["max_length"], "tokens_to_image", ) + dims = [self.map[submodel]["export_args"]["width"], self.map[submodel]["export_args"]["height"]] + dims = "x".join([str(x) for x in dims]) pipeline_keys = [ - utils.create_safe_name(self.hf_model_name.split("/")[-1], ""), - "bs" + str(self.batch_size), - f"{str(self.width)}x{str(self.height)}", - self.precision, - str(self.max_length), + utils.create_safe_name(self.map[submodel]["export_args"]["hf_model_name"].split("/")[-1], ""), + "bs" + str(self.map[submodel]["export_args"]["batch_size"]), + dims, + self.map[submodel]["export_args"]["precision"], + str(self.map[submodel]["export_args"]["max_length"]), "fullpipeline", ] vmfb_path = utils.compile_to_vmfb( pipeline_file, self.map["unet"]["device"], self.map["unet"]["target"], - self.ireec_flags["pipeline"], + None, os.path.join(self.pipeline_dir, "_".join(pipeline_keys)), return_path=True, mlir_source="str", @@ -467,14 +610,24 @@ def export_submodel( export_args = self.map[submodel].get("export_args", {}) if self.map[submodel].get("input_mlir"): export_args["input_mlir"] = self.map[submodel].get("mlir") + if weights_only: + export_args["weights_only"] = True if export_args: - vmfb_path = self.map[submodel]["export_fn"](**export_args) + exported = self.map[submodel]["export_fn"](**export_args) else: - vmfb_path = self.map[submodel]["export_fn"]() + exported = self.map[submodel]["export_fn"]() + if not self.map[submodel].get("weights") and os.path.exists(self.map[submodel]["export_args"].get("external_weight_path")): + self.map[submodel]["weights"] = self.map[submodel]["export_args"].get("external_weight_path", None) + if not weights_only: + self.map[submodel]["vmfb"] = exported + # LOAD def load_map(self): for submodel in self.map.keys(): + if not self.map[submodel]["load"]: + print("Skipping load for ", submodel) + continue self.load_submodel(submodel) def load_submodel(self, submodel): @@ -484,7 +637,8 @@ def load_submodel(self, submodel): "external_weights" ): raise ValueError(f"Weights not found for {submodel}.") - self.map[submodel]["runner"] = PipelineComponent() + dest_type = self.map[submodel].get("dest_type", "devicearray") + self.map[submodel]["runner"] = PipelineComponent(dest_type=dest_type) self.map[submodel]["runner"].load( self.map[submodel]["driver"], self.map[submodel]["vmfb"], diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py index 8b3176c8d..e19cac162 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py @@ -17,6 +17,7 @@ from shark_turbine.dynamo.passes import ( DEFAULT_DECOMPOSITIONS, ) +from shark_turbine.transforms.general.add_metadata import AddMetadataPass from turbine_models.custom_models.sd_inference import utils import torch import torch._dynamo as dynamo @@ -160,6 +161,7 @@ def export_mmdit_model( weights_only=False, ): dtype = torch.float16 if precision == "fp16" else torch.float32 + np_dtype = "float16" if precision == "fp16" else "float32" safe_name = utils.create_safe_name( hf_model_name, f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_mmdit", @@ -239,8 +241,22 @@ class CompiledMmdit(CompiledModule): inst = CompiledMmdit(context=Context(), import_to="IMPORT") - module_str = str(CompiledModule.get_mlir_module(inst)) - + module = CompiledModule.get_mlir_module(inst) + + model_metadata_run_forward = { + "model_name": "sd3_mmdit", + "input_shapes": [ + hidden_states_shape, + encoder_hidden_states_shape, + pooled_projections_shape, + init_batch_dim + ], + "input_dtypes": [np_dtype for x in range(4)], + "output_shapes": [hidden_states_shape], + "output_dtypes": [np_dtype], + } + module = AddMetadataPass(module, model_metadata_run_forward, "run_forward").run() + module_str = str(module) if compile_to != "vmfb": return module_str else: diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py b/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py index ea0213486..6b4fe135b 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py @@ -12,6 +12,7 @@ from typing import Any, Callable, Dict, List, Optional, Union from shark_turbine.aot import * import shark_turbine.ops.iree as ops +from shark_turbine.transforms.general.add_metadata import AddMetadataPass from iree.compiler.ir import Context import iree.runtime as ireert import numpy as np @@ -213,6 +214,7 @@ def export_scheduler_model( upload_ir=False, ): dtype = torch.float16 if precision == "fp16" else torch.float32 + np_dtype = "float16" if precision == "fp16" else "float32" scheduler_module = FlowSchedulingModel(hf_model_name, num_inference_steps, dtype) vmfb_names = [ "EulerFlowScheduler", @@ -317,8 +319,34 @@ class CompiledScheduler(CompiledModule): import_to = "INPUT" if compile_to == "linalg" else "IMPORT" inst = CompiledScheduler(context=Context(), import_to=import_to) - module_str = str(CompiledModule.get_mlir_module(inst)) + module = CompiledModule.get_mlir_module(inst) + + model_metadata_run_init = { + "model_name": "sd3_scheduler_FlowEulerDiscrete", + "input_shapes": [sample], + "input_dtypes": [np_dtype], + "output_shapes": [sample, "?", "?"], + "output_dtypes": [np_dtype, "int32", "float32"], + } + model_metadata_run_prep = { + "model_name": "sd3_scheduler_FlowEulerDiscrete", + "input_shapes": [sample, 1, [19]], + "input_dtypes": [np_dtype, "float32", "float32"], + "output_shapes": [noise_pred_shape, noise_pred_shape[0]], + "output_dtypes": [np_dtype, "float32"], + } + model_metadata_run_step = { + "model_name": "sd3_scheduler_FlowEulerDiscrete", + "input_shapes": [noise_pred_shape, 1, sample, 1, 1], + "input_dtypes": [np_dtype, np_dtype, np_dtype, np_dtype, "int64"], + "output_shapes": [sample], + "output_dtypes": [np_dtype], + } + module = AddMetadataPass(module, model_metadata_run_init, "run_init").run() + module = AddMetadataPass(module, model_metadata_run_prep, "run_prep").run() + module = AddMetadataPass(module, model_metadata_run_step, "run_step").run() + module_str = str(module) if compile_to != "vmfb": return module_str elif compile_to == "vmfb": diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py b/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py index 2e0a69445..2784e873e 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py @@ -13,6 +13,7 @@ from iree.compiler.ir import Context import numpy as np from shark_turbine.aot import * +from shark_turbine.transforms.general.add_metadata import AddMetadataPass from turbine_models.custom_models.sd_inference import utils import torch from turbine_models.custom_models.sd3_inference.text_encoder_impls import ( @@ -113,8 +114,8 @@ def forward(self, tokens_g, tokens_l, tokens_t5xxl, neg_g, neg_l, neg_t5): @torch.no_grad() def export_text_encoders( hf_model_name, - hf_auth_token=None, max_length=64, + batch_size=1, precision="fp16", compile_to="torch", external_weights=None, @@ -126,7 +127,6 @@ def export_text_encoders( pipeline_dir=None, input_mlir=None, attn_spec=None, - output_batchsize=1, decomp_attn=True, ): @@ -191,9 +191,18 @@ class CompiledTextEncoder(CompiledModule): save_module_parameters(external_weight_path, model) inst = CompiledTextEncoder(context=Context(), import_to="IMPORT") - - module_str = str(CompiledModule.get_mlir_module(inst)) - + + module = CompiledModule.get_mlir_module(inst) + + model_metadata_forward = { + "model_name": "sd3_clip_t5xxl_text_encoders", + "input_shapes": [(1, max_length, 2) for x in range(6)], + "input_dtypes": ["int64" for x in range(6)], + "output_shapes": [(2*output_batchsize,max_length*2,4096), (2*output_batchsize,2048)], + "output_dtypes": ["float32"], + } + module = AddMetadataPass(module, model_metadata_forward, "forward").run() + module_str = str(module) if compile_to != "vmfb": return module_str else: @@ -215,8 +224,8 @@ class CompiledTextEncoder(CompiledModule): mod_str, _ = export_text_encoders( args.hf_model_name, - args.hf_auth_token, args.max_length, + args.batch_size, args.precision, args.compile_to, args.external_weights, @@ -228,7 +237,6 @@ class CompiledTextEncoder(CompiledModule): pipeline_dir=args.pipeline_dir, input_mlir=args.input_mlir, attn_spec=args.attn_spec, - output_batchsize=args.batch_size, ) if args.input_mlir or args.weights_only or args.compile_to == "vmfb": exit() diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_vae_runner.py b/models/turbine_models/custom_models/sd3_inference/sd3_vae_runner.py index 1267bb862..521f90bb9 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_vae_runner.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_vae_runner.py @@ -21,9 +21,9 @@ def run_vae( def run_torch_vae(hf_model_name, variant, example_input): - from turbine_models.custom_models.sd3_inference.sd3_vae import VaeModel + from turbine_models.custom_models.sd_inference.vae import SD3VaeModel - vae_model = VaeModel( + vae_model = SD3VaeModel( hf_model_name, ) diff --git a/models/turbine_models/custom_models/sd_inference/clip.py b/models/turbine_models/custom_models/sd_inference/clip.py index 52c36a5c3..a4c177736 100644 --- a/models/turbine_models/custom_models/sd_inference/clip.py +++ b/models/turbine_models/custom_models/sd_inference/clip.py @@ -9,78 +9,49 @@ from iree.compiler.ir import Context from shark_turbine.aot import * +from shark_turbine.transforms.general.add_metadata import AddMetadataPass from turbine_models.custom_models.sd_inference import utils import torch from transformers import CLIPTextModel, CLIPTokenizer, CLIPProcessor from turbine_models.turbine_tank import turbine_tank -import argparse - -parser = argparse.ArgumentParser() -parser.add_argument( - "--hf_auth_token", type=str, help="The Hugging Face auth token, required" -) -parser.add_argument( - "--hf_model_name", - type=str, - help="HF model name", - default="CompVis/stable-diffusion-v1-4", -) -parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb") -parser.add_argument("--external_weight_path", type=str, default="") -parser.add_argument( - "--external_weights", - type=str, - default=None, - help="saves ir/vmfb without global weights for size and readability, options [safetensors]", -) -parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm") -# TODO: Bring in detection for target triple -parser.add_argument( - "--iree_target_triple", - type=str, - default="", - help="Specify vulkan target triple or rocm/cuda target device.", -) -parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") - - +@torch.no_grad() def export_clip_model( hf_model_name, - hf_auth_token: str = None, + batch_size: int = 1, max_length: int = 64, precision: str = "fp16", compile_to: str = "torch", external_weights: str = None, external_weight_path: str = None, device: str = "llvm-cpu", - target_triple: str = "x86_64-linux-gnu", + target: str = "x86_64-linux-gnu", ireec_flags: str = None, exit_on_vmfb: bool = False, pipeline_dir: str = None, input_mlir: str = None, - td_spec: str = None, + attn_spec: str = None, weights_only: bool = False, upload_ir: bool = False, + decomp_attn: bool = False, ): input_len = max_length + safe_name = utils.create_safe_name( + hf_model_name, f"_bs{batch_size}_{str(max_length)}-{precision}-clip" + ) if pipeline_dir not in [None, ""]: - safe_name = os.path.join(pipeline_dir, "clip") - else: - safe_name = utils.create_safe_name( - hf_model_name, f"_{str(max_length)}-{precision}-clip-{device}" - ) + safe_name = os.path.join(pipeline_dir, safe_name) if input_mlir: vmfb_path = utils.compile_to_vmfb( input_mlir, device, - target_triple, + target, ireec_flags, safe_name, mlir_source="file", return_path=not exit_on_vmfb, const_expr_hoisting=True, - attn_spec=td_spec, + attn_spec=attn_spec, ) return vmfb_path if "google/t5" in hf_model_name: @@ -101,27 +72,25 @@ def export_clip_model( tokenizer = CLIPTokenizer.from_pretrained( hf_model_name, subfolder="tokenizer", - token=hf_auth_token, ) hf_subfolder = "text_encoder" text_encoder_model = CLIPTextModel.from_pretrained( hf_model_name, subfolder=hf_subfolder, - token=hf_auth_token, ) - + if precision == "fp16": + text_encoder_model = text_encoder_model.half() mapper = {} utils.save_external_weights( mapper, text_encoder_model, external_weights, external_weight_path ) - if weights_only: return external_weight_path - + if "google/t5" in hf_model_name: - - class CompiledClip(CompiledModule): + input_shapes = [(batch_size, input_len), (batch_size, input_len)] + class CompiledTextEncoder(CompiledModule): if external_weights: params = export_parameters( text_encoder_model, @@ -132,7 +101,7 @@ class CompiledClip(CompiledModule): else: params = export_parameters(text_encoder_model) - def main( + def encode_tokens( self, inp=AbstractTensor(1, input_len, dtype=torch.int64), decoder_input_ids=AbstractTensor(1, input_len, dtype=torch.int64), @@ -140,10 +109,9 @@ def main( return jittable(text_encoder_model.forward)( input_ids=inp, decoder_input_ids=decoder_input_ids ) - else: - - class CompiledClip(CompiledModule): + input_shapes = [str((batch_size, input_len)), str((batch_size, input_len))] + class CompiledTextEncoder(CompiledModule): if external_weights: params = export_parameters( text_encoder_model, @@ -154,31 +122,57 @@ class CompiledClip(CompiledModule): else: params = export_parameters(text_encoder_model) - def main(self, inp=AbstractTensor(1, input_len, dtype=torch.int64)): + def encode_tokens_attn_mask( + self, + inp=AbstractTensor(1, input_len, dtype=torch.int64), + attn_mask=AbstractTensor(1, input_len, dtype=torch.int64), + ): + return jittable(text_encoder_model.forward)(input_ids=inp, attention_mask=attn_mask) + + def encode_tokens( + self, + inp=AbstractTensor(1, input_len, dtype=torch.int64), + ): return jittable(text_encoder_model.forward)(input_ids=inp) import_to = "INPUT" if compile_to == "linalg" else "IMPORT" - inst = CompiledClip(context=Context(), import_to=import_to) - - module_str = str(CompiledModule.get_mlir_module(inst)) + inst = CompiledTextEncoder(context=Context(), import_to=import_to) + module = CompiledModule.get_mlir_module(inst) + + model_metadata_attn_mask = { + "model_name": hf_model_name + "_text_encoder", + "input_shapes": input_shapes, + "input_dtypes": ['int64', 'int64'], + "use_attention_mask": True, + } + model_metadata_encode = { + "model_name": hf_model_name + "_text_encoder", + "input_shapes": input_shapes[0], + "input_dtypes": ['int64'], + "use_attention_mask": False, + } + module = AddMetadataPass(module, model_metadata_attn_mask, "encode_tokens_attn_mask").run() + module = AddMetadataPass(module, model_metadata_encode, "encode_tokens").run() + + module_str = str(module) if compile_to != "vmfb": - return module_str, tokenizer + return module_str else: vmfb_path = utils.compile_to_vmfb( module_str, device, - target_triple, + target, ireec_flags, safe_name, return_path=not exit_on_vmfb, const_expr_hoisting=True, - attn_spec=td_spec, + attn_spec=attn_spec, ) - return vmfb_path, None + return vmfb_path if __name__ == "__main__": - from .sd_cmd_opts import args + from turbine_models.custom_models.sd_inference.sd_cmd_opts import args mod_str, _ = export_clip_model( args.hf_model_name, @@ -193,7 +187,7 @@ def main(self, inp=AbstractTensor(1, input_len, dtype=torch.int64)): exit_on_vmfb=True, pipeline_dir=args.pipeline_dir, input_mlir=args.input_mlir, - td_spec=args.attn_spec, + attn_spec=args.attn_spec, weights_only=False, upload_ir=False, ) diff --git a/models/turbine_models/custom_models/sd_inference/schedulers.py b/models/turbine_models/custom_models/sd_inference/schedulers.py index 2c8d618c6..8d8b5e651 100644 --- a/models/turbine_models/custom_models/sd_inference/schedulers.py +++ b/models/turbine_models/custom_models/sd_inference/schedulers.py @@ -74,6 +74,9 @@ def __init__( self.model = scheduler self.height = height self.width = width + self.is_sd3 = False + if "stable-diffusion-3" in hf_model_name: + self.is_sd3 = True self.batch_size = batch_size self.do_classifier_free_guidance = True self.model.set_timesteps(num_inference_steps) @@ -129,19 +132,26 @@ def step(self, noise_pred, t, sample, guidance_scale, i): class SharkSchedulerCPUWrapper: @torch.no_grad() def __init__( - self, scheduler, batch_size, num_inference_steps, dest_device, latents_dtype + self, scheduler, batch_size, dest_device, latents_dtype, conditional_timesteps=False ): self.do_classifier_free_guidance = True self.module = scheduler self.dest = dest_device - self.dtype = latents_dtype self.batch_size = batch_size self.timesteps = None + + # Enable this on init for models that use a pair of timestep values per unet step. + # this includes sd3 and some others we don't support yet. + # It allows passage of 'uncond_t' to the scale_model_input function and repeats the + # default timestep value if no 'uncond_t' is passed. + self.conditional_timesteps = conditional_timesteps + + self.dtype = latents_dtype self.torch_dtype = ( torch.float32 if latents_dtype == "float32" else torch.float16 ) - def initialize(self, sample, num_inference_steps): + def initialize_sdxl(self, sample, num_inference_steps): if isinstance(sample, ireert.DeviceArray): sample = torch.tensor(sample.to_host(), dtype=torch.float32) @@ -162,24 +172,36 @@ def initialize(self, sample, num_inference_steps): step_indexes = torch.tensor(len(self.timesteps)) timesteps = self.timesteps sample = sample * self.module.init_noise_sigma - add_time_ids = ireert.asdevicearray(self.dest, add_time_ids, self.dtype) return sample, add_time_ids, step_indexes, timesteps - def scale_model_input(self, sample, t, timesteps): + def initialize_sd(self, sample, num_inference_steps): + if isinstance(sample, ireert.DeviceArray): + sample = torch.tensor(sample.to_host(), dtype=torch.float32) + self.module.set_timesteps(num_inference_steps) + timesteps = self.module.timesteps + sample = sample * self.module.init_noise_sigma + return sample, timesteps + + def scale_model_input(self, sample, t, t_uncond=None): if self.do_classifier_free_guidance: sample = torch.cat([sample] * 2) - t = timesteps[t] + if self.conditional_timesteps: + if t_uncond: + t = torch.tensor([t, t_uncond]) + else: + t = torch.tensor([t, t]) + else: + t = torch.tensor([t]) scaled = self.module.scale_model_input(sample, t) - t = ireert.asdevicearray(self.dest, [t], self.dtype) - scaled = ireert.asdevicearray(self.dest, scaled, self.dtype) return scaled, t - def step(self, noise_pred, t, latents, guidance_scale, i): + def step(self, noise_pred, t, latents, guidance_scale): if isinstance(t, ireert.DeviceArray): t = torch.tensor(t.to_host()) + if isinstance(noise_pred, ireert.DeviceArray): + noise_pred = torch.tensor(noise_pred.to_host()) if isinstance(guidance_scale, ireert.DeviceArray): guidance_scale = torch.tensor(guidance_scale.to_host()) - noise_pred = torch.tensor(noise_pred.to_host()) if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * ( @@ -204,11 +226,14 @@ def export_scheduler_model( precision: str = "fp16", compile_to: str = "torch", device: str = None, - target_triple: str = None, + target: str = None, ireec_flags: str = None, exit_on_vmfb: bool = False, pipeline_dir: str = None, input_mlir: str = None, + attn_spec: str = None, + external_weights: str = None, + external_weight_path: str = None, upload_ir=False, ): dtype = torch.float16 if precision == "fp16" else torch.float32 @@ -233,9 +258,9 @@ def export_scheduler_model( vmfb_path = utils.compile_to_vmfb( input_mlir, device, - target_triple, + target, ireec_flags, - safe_name + "_" + target_triple, + safe_name, mlir_source="file", return_path=not exit_on_vmfb, ) @@ -329,9 +354,9 @@ class CompiledScheduler(CompiledModule): vmfb = utils.compile_to_vmfb( module_str, device, - target_triple, + target, ireec_flags, - safe_name + "_" + target_triple, + safe_name, return_path=True, ) if exit_on_vmfb: @@ -350,6 +375,8 @@ def get_scheduler(model_id, scheduler_id): scheduler = DPMSolverMultistepScheduler.from_pretrained( model_id, subfolder="scheduler", algorithm_type="dpmsolver++" ) + else: + raise ValueError(f"Scheduler {scheduler_id} not found.") if "Karras" in scheduler_id: scheduler.config.use_karras_sigmas = True diff --git a/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py index e56737369..8d707582c 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py +++ b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py @@ -47,7 +47,7 @@ def is_valid_file(arg): "--scheduler_id", type=str, help="Scheduler ID", - default="Euler", + default="EulerDiscrete", ) ############################################################################## @@ -101,7 +101,7 @@ def is_valid_file(arg): p.add_argument( "--external_weights_dir", type=str, - default="", + default="./weights", help="Directory containing external weights for a job that requires more than one weights file. When importing, this is used to specify where to save the model weights, and at runtime, this is used to specify where to load the model weights from. Files will then be saved according to the parameters that make them unique, i.e. ___.", ) @@ -126,7 +126,7 @@ def is_valid_file(arg): p.add_argument( "--pipeline_dir", type=str, - default=None, + default="./vmfbs", help="Directory to save pipeline artifacts", ) @@ -137,6 +137,13 @@ def is_valid_file(arg): help="Do one-shot inference from tokens to image in a shrink-wrapped pipeline binary.", ) +p.add_argument( + "--cpu_scheduling", + default=True, + action="store_true", + help="Run scheduling on native pytorch CPU backend.", +) + ############################################################################## # SDXL Modelling Options # These options are used to control model defining parameters for SDXL. @@ -146,10 +153,10 @@ def is_valid_file(arg): p.add_argument("--batch_size", type=int, default=1, help="Batch size for inference") p.add_argument( - "--height", type=int, default=1024, help="Height of Stable Diffusion output image." + "--height", type=int, default=512, help="Height of Stable Diffusion output image." ) p.add_argument( - "--width", type=int, default=1024, help="Width of Stable Diffusion output image" + "--width", type=int, default=512, help="Width of Stable Diffusion output image" ) p.add_argument( "--precision", @@ -244,7 +251,7 @@ def is_valid_file(arg): p.add_argument( "--iree_target_triple", type=str, - default="", + default="x86_64-linux-gnu", help="Specify vulkan target triple or rocm/cuda target device.", ) diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 3975bfbbb..2e0bbcc4d 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -17,7 +17,16 @@ schedulers, utils, ) -from .tokenization import get_weighted_text_embeddings +from turbine_models.custom_models.sdxl_inference import ( + sdxl_prompt_encoder as sdxl_clip, + unet as sdxl_unet, +) +from turbine_models.custom_models.sd3_inference import ( + sd3_text_encoders, + sd3_mmdit, +) +from turbine_models.custom_models.pipeline_base import TurbinePipelineBase, merge_arg_into_map +from turbine_models.custom_models.sd_inference.tokenization import encode_prompt from turbine_models.model_runner import vmfbRunner from transformers import CLIPTokenizer from pathlib import Path @@ -28,419 +37,324 @@ import time from datetime import datetime as dt - -device_list = [ - "cpu", - "vulkan", - "cuda", - "rocm", -] - -rt_device_list = [ - "local-task", - "local-sync", - "vulkan", - "cuda", - "rocm", - "hip", -] - -SUBMODELS = { - "clip": None, - "scheduler": None, - "unet": None, - "vae_decode": None, +# These are arguments common among submodel exports. +# They are expected to be populated in two steps: +# First, by the child class, +# and second by the base class for inference task-agnostic args. + +sd1_sd2_model_map = { + "text_encoder": { + "module_name": "compiled_text_encoder", + "keywords": ["clip"], + "dest_type": "torch", + "export_fn": clip.export_clip_model, + "export_args": { + "batch_size": 1, + "max_length": 64, + }, + }, + "unet": { + "module_name": "compiled_unet", + "keywords": ["unet"], + "export_fn": unet.export_unet_model, + "export_args": { + "batch_size": 1, + "height": 512, + "width": 512, + "max_length": 64, + "decomp_attn": None, + }, + }, + "vae": { + "module_name": "compiled_vae", + "keywords": ["vae"], + "dest_type": "numpy", + "export_fn": vae.export_vae_model, + "export_args": { + "batch_size": 1, + "height": 512, + "width": 512, + "num_channels": 4, + "decomp_attn": None, + }, + }, +} +sdxl_model_map = { + "text_encoder": { + "module_name": "compiled_clip", + "keywords": ["prompt_encoder"], + "dest_type": "torch", + "export_fn": sdxl_clip.export_prompt_encoder, + "export_args": { + "batch_size": 1, + "max_length": 64, + }, + }, + "unet": { + "module_name": "compiled_unet", + "keywords": ["unet", "!loop"], + "export_fn": sdxl_unet.export_unet_model, + "export_args": { + "batch_size": 1, + "height": 1024, + "width": 1024, + "max_length": 64, + "decomp_attn": None, + }, + }, + "vae": { + "module_name": "compiled_vae", + "keywords": ["vae"], + "dest_type": "numpy", + "export_fn": vae.export_vae_model, + "export_args": { + "batch_size": 1, + "height": 1024, + "width": 1024, + "num_channels": 4, + "decomp_attn": None, + }, + }, + "unetloop": { + "module_name": "sdxl_compiled_pipeline", + "load": False, + "keywords": ["unetloop"], + "wraps": ["unet", "scheduler"], + "export_args": { + "batch_size": 1, + "height": 1024, + "width": 1024, + "max_length": 64, + }, + }, + "fullpipeline": { + "module_name": "sdxl_compiled_pipeline", + "load": False, + "keywords": ["fullpipeline"], + "wraps": ["text_encoder", "unet", "scheduler", "vae"], + "export_args": { + "batch_size": 1, + "height": 1024, + "width": 1024, + "max_length": 64, + }, + }, +} +sd3_model_map = { + "text_encoder": { + "module_name": "compiled_text_encoder", + "keywords": ["text_encoder"], + "export_fn": sd3_text_encoders.export_text_encoders, + "export_args": { + "batch_size": 1, + "max_length": 64, + }, + }, + "mmdit": { + "module_name": "compiled_mmdit", + "keywords": ["mmdit"], + "export_fn": sd3_mmdit.export_mmdit_model, + "export_args": { + "batch_size": 1, + "height": 1024, + "width": 1024, + "max_length": 64, + "decomp_attn": None, + }, + }, + "vae": { + "module_name": "compiled_vae", + "keywords": ["vae"], + "dest_type": "numpy", + "export_fn": vae.export_vae_model, + "export_args": { + "batch_size": 1, + "height": 1024, + "width": 1024, + "num_channels": 16, + "decomp_attn": None, + }, + } } - -class SharkSDPipeline: +def get_sd_model_map(hf_model_name): + if isinstance(hf_model_name, dict): + name = hf_model_name["text_encoder"] + else: + name = hf_model_name + if name in ["stabilityai/sdxl-turbo", "stabilityai/stable-diffusion-xl-base-1.0"]: + return sdxl_model_map + elif "stabilityai/stable-diffusion-3" in name: + return sd3_model_map + else: + return sd1_sd2_model_map + +torch_dtypes = { + "fp32": torch.float32, + "fp16": torch.float16, + "float32": torch.float32, + "float16": torch.float16, +} +class SharkSDPipeline(TurbinePipelineBase): def __init__( self, - hf_model_name: str, - scheduler_id: str, + hf_model_name: str | dict[str], height: int, width: int, - precision: str, - max_length: int, batch_size: int, - num_inference_steps: int, - device: str, - iree_target_triple: str, - ireec_flags: dict = copy.deepcopy(SUBMODELS), - attn_spec: str = None, - decomp_attn: bool = False, - pipeline_dir: str | Path = "./shark_vmfbs", - external_weights_dir: str | Path = "./shark_weights", - external_weights: str = "safetensors", - custom_vae: str = None, - vae_decomp_attn: bool = True, + max_length: int | dict[int], + precision: str | dict[str], + device: str | dict[str], + target: str | dict[str], + ireec_flags: str | dict[str] = None, + attn_spec: str | dict[str] = None, + decomp_attn: bool | dict[bool] = False, + pipeline_dir: str = "./shark_vmfbs", + external_weights_dir: str = "./shark_weights", + external_weights: str | dict[str] = "safetensors", + num_inference_steps: int = 30, + cpu_scheduling: bool = True, + scheduler_id: str = None, # compatibility only + shift: float = 1.0, # compatibility only ): - self.hf_model_name = hf_model_name - self.iree_dtype = "float32" if precision == "fp32" else "float16" - self.torch_dtype = torch.float32 if precision == "fp32" else torch.float16 - self.cpu_scheduling = True - self.scheduler_id = scheduler_id + common_export_args = { + "hf_model_name": None, + "precision": None, + "compile_to": "vmfb", + "device": None, + "target": None, + "exit_on_vmfb": False, + "pipeline_dir": pipeline_dir, + "input_mlir": None, + "attn_spec": None, + "external_weights": None, + "external_weight_path": None, + } + sd_model_map = get_sd_model_map(hf_model_name) + for submodel in sd_model_map: + if "load" not in sd_model_map[submodel]: + sd_model_map[submodel]["load"] = True + sd_model_map[submodel]["export_args"]["batch_size"] = batch_size + if "max_length" in sd_model_map[submodel]["export_args"]: + max_length_sub = max_length if isinstance(max_length, int) else max_length[submodel] + sd_model_map[submodel]["export_args"]["max_length"] = max_length_sub + if "height" in sd_model_map[submodel]["export_args"]: + sd_model_map[submodel]["export_args"]["height"] = height + sd_model_map[submodel]["export_args"]["width"] = width + super().__init__( + sd_model_map, + device, + target, + ireec_flags, + precision, + attn_spec, + decomp_attn, + external_weights, + pipeline_dir, + external_weights_dir, + hf_model_name, + common_export_args, + ) + for submodel in sd_model_map: + if self.map[submodel].get("external_weights"): + weights_filename = utils.create_safe_name( + self.map[submodel]["export_args"]["hf_model_name"], + f"_{submodel}_{self.map[submodel]['precision']}", + ) + weights_filename += "." + self.map[submodel]["export_args"]["external_weights"] + self.map[submodel]["export_args"]["external_weight_path"] = weights_filename + + self.batch_size = batch_size + self.model_max_length = max_length self.height = height self.width = width - self.precision = precision - self.max_length = max_length - self.model_max_length = max_length - self.batch_size = batch_size + self.latents_dtype = torch_dtypes[self.map["unet"]["precision"]] + self.cpu_scheduling = cpu_scheduling + self.scheduler_id = scheduler_id self.num_inference_steps = num_inference_steps - self.device = device - self.iree_target_triple = iree_target_triple - self.ireec_flags = ireec_flags if ireec_flags else copy.deepcopy(SUBMODELS) - self.attn_spec = attn_spec - self.decomp_attn = decomp_attn - self.pipeline_dir = pipeline_dir - self.external_weights_dir = external_weights_dir - self.external_weights = external_weights - self.custom_vae = custom_vae - self.vae_decomp_attn = vae_decomp_attn - self.is_sdxl = "xl" in self.hf_model_name - - # FILE MANAGEMENT AND PIPELINE SETUP - - def check_prepared( - self, - mlirs: dict, - vmfbs: dict, - weights: dict, - interactive: bool = True, - ): - ready, vmfbs, weights = self.is_prepared(vmfbs, weights) - if not ready: - if interactive: - do_continue = input( - f"\nIt seems you are missing some necessary files. Would you like to generate them now? (y/n)" - ) - if do_continue.lower() != "y": - exit() - else: - do_continue = "y" - if do_continue.lower() == "y": - for submodel in vmfbs.keys(): - if vmfbs[submodel] == None: - vmfb, weight = self.export_submodel(submodel, input_mlir=mlirs) - vmfbs[submodel] = vmfb - if weights[submodel] is None: - weights[submodel] = weight - elif weights[submodel] is None and "scheduler" not in submodel: - _, weight = self.export_submodel(submodel, weights_only=True) - weights[submodel] = weight - ready, vmfbs, weights = self.is_prepared(vmfbs, weights) - if ready: - print("All necessary files found.") - return vmfbs, weights - else: - print("There was an error generating the necessary files.") - exit() - else: - print("All necessary files found. Loading pipeline.") - return vmfbs, weights - - def is_prepared(self, vmfbs, weights): - missing = [] - for key in vmfbs: - if "scheduler" in key and self.cpu_scheduling: - continue - default_filepath = os.path.join(self.pipeline_dir, key + ".vmfb") - if vmfbs[key] is not None and os.path.exists(vmfbs[key]): - continue - elif vmfbs[key] == None and os.path.exists(default_filepath): - vmfbs[key] = default_filepath - else: - missing.append(key + ".vmfb") - for w_key in weights: - if "scheduler" in w_key: - continue - if weights[w_key] is not None and os.path.exists(weights[w_key]): - continue - if self.external_weights is None: - weights[w_key] = None - continue - default_name = os.path.join( - self.external_weights_dir, w_key + "." + self.external_weights - ) - if weights[w_key] is None and os.path.exists(default_name): - weights[w_key] = os.path.join(default_name) - else: - missing.append(w_key + "." + self.external_weights) - if len(missing) > 0: - print(f"Missing files: " + ", ".join(missing)) - return False, vmfbs, weights - else: - return True, vmfbs, weights - def get_mlir_from_turbine_tank(self, submodel, container_name): - from turbine_models.turbine_tank import downloadModelArtifacts + self.text_encoder = None + self.unet = None + self.mmdit = None + self.vae = None + self.scheduler = None - safe_name = utils.create_safe_name( - self.hf_model_name, - f"_{self.max_length}_{self.height}x{self.width}_{self.precision}_{submodel}.mlir", - ) - mlir_path = downloadModelArtifacts( - safe_name, - container_name, - ) - return mlir_path + self.split_scheduler = True + if self.split_scheduler: + self.map.pop("unetloop") + self.map.pop("fullpipeline") - # IMPORT / COMPILE PHASE + self.base_model_name = hf_model_name if isinstance(hf_model_name, str) else hf_model_name.get("unet", hf_model_name.get("mmdit")) + self.is_img2img = False + self.is_sdxl = "xl" in self.base_model_name + self.is_sd3 = "stable-diffusion-3" in self.base_model_name + if self.is_sdxl: + self.tokenizers = [ + CLIPTokenizer.from_pretrained(self.base_model_name, subfolder="tokenizer"), + CLIPTokenizer.from_pretrained(self.base_model_name, subfolder="tokenizer_2"), + ] + elif not self.is_sd3: + self.tokenizer = CLIPTokenizer.from_pretrained(self.base_model_name, subfolder="tokenizer") - def get_torch_models(self, submodel): - match submodel: - case "unet": - unet_torch = unet.UnetModel( - self.hf_model_name, - ) - return unet_torch - case "vae_decode": - vae_torch = vae.VaeModel( - self.hf_model_name, - self.custom_vae, - ) - return vae_torch - def export_submodel( - self, - submodel: str, - input_mlir: str = None, - weights_only: bool = False, - ): - if not os.path.exists(self.pipeline_dir): - os.makedirs(self.pipeline_dir) - if self.external_weights_dir: - if not os.path.exists(self.external_weights_dir): - os.makedirs(external_weights_dir, exist_ok=True) - vae_external_weight_path = os.path.join( - self.external_weights_dir, "vae_decode." + self.external_weights - ) - unet_external_weight_path = os.path.join( - self.external_weights_dir, "unet." + self.external_weights - ) - clip_external_weight_path = os.path.join( - self.external_weights_dir, "clip." + self.external_weights - ) - elif self.external_weights is None: - print( - "No external weights type specified using --external_weights, weights for imported .mlir files will not be externalized." - ) - vae_external_weight_path = None - unet_external_weight_path = None - clip_external_weight_path = None - else: - print( - f"No external weights directory specified using --external_weights_dir, we assume you have your own weights in {self.pipeline_dir}." - ) - external_weights_dir = self.pipeline_dir - if not os.path.exists(self.pipeline_dir): - os.makedirs(self.pipeline_dir, exist_ok=True) - vae_external_weight_path = os.path.join( - self.pipeline_dir, "vae_decode." + self.external_weights - ) - unet_external_weight_path = os.path.join( - self.pipeline_dir, "unet." + self.external_weights + # RUN + def encode_prompts_sdxl(self, prompt, negative_prompt): + # Tokenize prompt and negative prompt. + text_input_ids_list = [] + uncond_input_ids_list = [] + + for tokenizer in self.tokenizers: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.model_max_length, + truncation=True, + return_tensors="pt", ) - clip_external_weight_path = os.path.join( - self.pipeline_dir, "clip." + self.external_weights + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=self.model_max_length, + truncation=True, + return_tensors="pt", ) - if weights_only: - input_mlir = copy.deepcopy(SUBMODELS) - match submodel: - case "clip": - _, clip_vmfb = clip.export_clip_model( - self.hf_model_name, - None, - self.max_length, - self.precision, - "vmfb", - self.external_weights, - clip_external_weight_path, - self.device, - self.iree_target_triple, - self.ireec_flags["clip"], - exit_on_vmfb=False, - pipeline_dir=self.pipeline_dir, - input_mlir=input_mlir["clip"], - td_spec=self.attn_spec, - weights_only=weights_only, - ) - return clip_vmfb, clip_external_weight_path - case "scheduler": - if self.cpu_scheduling: - return (None, None) - scheduler = schedulers.export_scheduler_model( - self.hf_model_name, - self.scheduler_id, - self.batch_size, - self.height, - self.width, - self.num_inference_steps, - self.precision, - "vmfb", - self.device, - self.iree_target_triple, - self.ireec_flags["scheduler"], - exit_on_vmfb=False, - pipeline_dir=self.pipeline_dir, - input_mlir=input_mlir["scheduler"], - ) - return scheduler, None - case "unet": - if input_mlir[submodel]: - unet_torch = None - else: - unet_torch = self.get_torch_models("unet") - - unet_vmfb = unet.export_unet_model( - unet_torch, - self.hf_model_name, - self.batch_size, - self.height, - self.width, - self.precision, - self.max_length, - None, - "vmfb", - self.external_weights, - unet_external_weight_path, - self.device, - self.iree_target_triple, - self.ireec_flags["unet"], - self.decomp_attn, - exit_on_vmfb=False, - pipeline_dir=self.pipeline_dir, - attn_spec=self.attn_spec, - input_mlir=input_mlir["unet"], - weights_only=weights_only, - ) - return unet_vmfb, unet_external_weight_path - case "vae_decode": - if not input_mlir[submodel]: - vae_torch = self.get_torch_models("vae_decode") - else: - vae_torch = None - vae_decode_vmfb = vae.export_vae_model( - vae_torch, - self.hf_model_name, - self.batch_size, - self.height, - self.width, - self.precision, - "vmfb", - self.external_weights, - vae_external_weight_path, - self.device, - self.iree_target_triple, - self.ireec_flags["vae"], - "decode", - self.vae_decomp_attn, - exit_on_vmfb=False, - pipeline_dir=self.pipeline_dir, - attn_spec=self.attn_spec, - input_mlir=input_mlir["vae_decode"], - weights_only=weights_only, - ) - return vae_decode_vmfb, vae_external_weight_path - - # LOAD + text_input_ids_list += text_inputs.input_ids.unsqueeze(0) + uncond_input_ids_list += uncond_input.input_ids.unsqueeze(0) - def load_pipeline( - self, - vmfbs: dict, - weights: dict, - rt_device: str = "local-task", - compiled_pipeline: bool = False, - ): - self.is_img2img = False - self.runners = {} - runners = {} - self.tokenizers = [] - self.tokenizers.append( - CLIPTokenizer.from_pretrained( - self.hf_model_name, - subfolder="tokenizer", - ) - ) - if self.is_sdxl: - self.tokenizers.append( - CLIPTokenizer.from_pretrained( - self.hf_model_name, - subfolder="tokenizer_2", - ) - ) - runners["clip"] = vmfbRunner(rt_device, vmfbs["clip"], weights["clip"]) - runners["unet"] = vmfbRunner(rt_device, vmfbs["unet"], weights["unet"]) - runners["vae_decode"] = vmfbRunner( - rt_device, vmfbs["vae_decode"], weights["vae_decode"] - ) - self.runners = runners - self.compiled_pipeline = False - if self.cpu_scheduling: - # torch_scheduler = schedulers.SchedulingModel( - # schedulers.get_scheduler(self.hf_model_name, self.scheduler_id), - # self.height, - # self.width, - # self.num_inference_steps, - # self.torch_dtype, - # ) - # self.scheduler = schedulers.SharkSchedulerCPUWrapper( - # self, torch_scheduler - # ) - self.scheduler = schedulers.get_scheduler( - self.hf_model_name, self.scheduler_id - ) + if self.compiled_pipeline: + return text_input_ids_list, uncond_input_ids_list else: - self.scheduler = schedulers.SharkSchedulerWrapper( - rt_device, vmfbs["scheduler"], weights["scheduler"] + prompt_embeds, add_text_embeds = self.text_encoder( + "encode_prompts", [*text_input_ids_list, *uncond_input_ids_list] ) - print("Successfully loaded pipeline.") - - # RUN + return prompt_embeds, add_text_embeds def prepare_latents( self, noise, num_inference_steps, - image, - strength, + image = None, + strength = None, ): - self.scheduler.set_timesteps(num_inference_steps) if self.is_img2img: - init_timestep = min( - int(num_inference_steps * strength), num_inference_steps - ) - t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start:] - latents = self.encode_image(image) - latents = self.scheduler.add_noise(latents, noise, timesteps[0].repeat(1)) - return latents, [timesteps] + raise NotImplementedError("Image-to-image not supported yet.") + elif self.is_sdxl: + sample, add_time_ids, step_indexes, timesteps = self.scheduler.initialize_sdxl(noise, num_inference_steps) + return sample, add_time_ids, step_indexes, timesteps + elif self.is_sd3: + raise NotImplementedError("Stable Diffusion 3 not supported yet.") else: - self.scheduler.is_scale_input_called = True - latents = noise * self.scheduler.init_noise_sigma - return latents, self.scheduler.timesteps + sample, timesteps = self.scheduler.initialize_sd(noise, num_inference_steps) + return sample, timesteps - def generate_images( - self, - prompt: str, - negative_prompt: str = "", - batch_count: int = 1, - guidance_scale: float = 7.5, - seed: float = -1, - return_imgs: bool = False, - ): - pipe_start = time.time() + def get_rand_latents(self, seed, batch_count): samples = [] - numpy_images = [] - uint32_info = np.iinfo(np.uint32) uint32_min, uint32_max = uint32_info.min, uint32_info.max if seed < uint32_min or seed >= uint32_max: seed = randint(uint32_min, uint32_max) - - generator = torch.manual_seed(seed) for i in range(batch_count): - generator = torch.random.manual_seed(seed + i) + generator = torch.manual_seed(seed + i) rand_sample = torch.randn( ( self.batch_size, @@ -449,110 +363,187 @@ def generate_images( self.width // 8, ), generator=generator, - dtype=self.torch_dtype, + dtype=self.latents_dtype, ) samples.append(rand_sample) - # samples.append( - # ireert.asdevicearray( - # self.runners["unet"].config.device, - # rand_sample, - # dtype=self.iree_dtype, - # ) - # ) - - guidance_scale = ireert.asdevicearray( - self.runners["unet"].config.device, - np.asarray([guidance_scale]), - dtype=self.iree_dtype, - ) + return samples - tokenize_start = time.time() + def load_scheduler( + self, + scheduler_id: str, + steps: int = 30, + ): + self.scheduler = schedulers.get_scheduler(self.base_model_name, self.scheduler_id) + if self.is_sd3: + scheduler_device = self.mmdit.device + else: + scheduler_device = self.unet.device + if not self.cpu_scheduling: + self.scheduler = None + self.num_inference_steps = steps + self.scheduler_id = scheduler_id + scheduler_path = f"{scheduler_id}Scheduler_{self.num_inference_steps}" + if not os.path.exists(scheduler_path): + scheduler_path, _ = self.export_submodel("scheduler") + try: + self.scheduler = schedulers.SharkSchedulerWrapper( + scheduler_device, + scheduler_path, + ) + except: + print("JIT export of scheduler failed. Loading CPU scheduler.") + self.cpu_scheduling = True + if self.cpu_scheduling: + scheduler = schedulers.get_scheduler(self.base_model_name, scheduler_id) + self.scheduler = schedulers.SharkSchedulerCPUWrapper( + scheduler, + self.batch_size, + scheduler_device, + latents_dtype=self.latents_dtype, + ) - # Tokenize prompt and negative prompt. + def _produce_latents_sd( + self, + sample, + prompt_embeds, + negative_prompt_embeds, + steps, + guidance_scale, + ): + image = None + strength = 0 + sample, timesteps = self.prepare_latents( + sample, self.num_inference_steps, image, strength + ) + text_embeddings = torch.cat((negative_prompt_embeds, prompt_embeds), dim=0) + + for i, t in tqdm(enumerate(timesteps)): + latent_model_input, _ = self.scheduler.scale_model_input(sample, t) + timestep = torch.tensor([t]) + unet_inputs = [ + latent_model_input, + timestep, + ] + unet_inputs.extend([text_embeddings, guidance_scale]) + latents = self.unet( + "run_forward", + unet_inputs + ) + sample = self.scheduler.step( + torch.tensor(latents.to_host(), dtype=self.torch_dtype), t, sample + ).prev_sample + return sample - prompt_embeds, negative_embeds = get_weighted_text_embeddings( - self, prompt, negative_prompt + def _produce_latents_sdxl( + self, + sample, + prompt_embeds, + add_text_embeds, + steps, + guidance_scale, + ): + image = None + strength = 0 + latents, add_time_ids, step_indexes, timesteps = self.prepare_latents( + sample, self.num_inference_steps, image, strength ) + iree_inputs = [ + sample, + prompt_embeds, + add_text_embeds, + add_time_ids, + None, + ] + for i, t in tqdm(enumerate(timesteps)): + if self.cpu_scheduling: + step_index = i + else: + step_index = torch.tensor([i]) + latent_model_input, t = self.scheduler.scale_model_input( + latents, + t, + ) + noise_pred = self.unet( + "run_forward", + [ + latent_model_input, + t, + iree_inputs[1], + iree_inputs[2], + iree_inputs[3], + ], + ) + latents = self.scheduler.step( + noise_pred, + t, + latents, + guidance_scale, + ) + return latents - text_embeddings = torch.cat((negative_embeds, prompt_embeds), dim=0) - text_embeddings = ireert.asdevicearray( - self.runners["unet"].config.device, - text_embeddings, - dtype=self.iree_dtype, + def generate_images( + self, + prompt: str, + negative_prompt: str = "", + steps: int = 30, + batch_count: int = 1, + guidance_scale: float = 7.5, + seed: float = -1, + cpu_scheduling: bool = True, + scheduler_id: str = "EulerDiscrete", + return_imgs: bool = False, + ): + needs_new_scheduler = ( + (steps and steps != self.num_inference_steps) + or (cpu_scheduling != self.cpu_scheduling) + and self.split_scheduler ) - encode_prompts_end = time.time() + if not self.scheduler and not self.compiled_pipeline: + needs_new_scheduler = True + + if guidance_scale == 0: + negative_prompt = prompt + prompt = "" + + self.cpu_scheduling = cpu_scheduling + if steps and needs_new_scheduler: + self.num_inference_steps = steps + self.load_scheduler(scheduler_id, steps) + + pipe_start = time.time() + numpy_images = [] - for i in range(batch_count): - unet_start = time.time() - image = None - strength = 0 - sample, timesteps = self.prepare_latents( - samples[i], self.num_inference_steps, image, strength + samples = self.get_rand_latents(seed, batch_count) + + # Tokenize prompt and negative prompt. + if self.is_sdxl: + prompt_embeds, negative_embeds = self.encode_prompts_sdxl(prompt, negative_prompt) + else: + prompt_embeds, negative_embeds = encode_prompt( + self, prompt, negative_prompt ) - for i, t in tqdm(enumerate(timesteps)): - latents = self.scheduler.scale_model_input(sample, t).to( - self.torch_dtype + for i in range(batch_count): + produce_latents_input = [ + samples[i], + prompt_embeds, + negative_embeds, + steps, + guidance_scale, + ] + if self.is_sdxl: + latents = self._produce_latents_sdxl( + *produce_latents_input ) - timestep = torch.tensor([t]).to(self.torch_dtype).detach().numpy() - unet_inputs = [ - latents, - timestep, - ] - if self.cpu_scheduling: - for inp in unet_inputs: - inp = ireert.asdevicearray( - self.runners["unet"].config.device, - inp, - dtype=self.iree_dtype, - ) - unet_inputs.extend([text_embeddings, guidance_scale]) - latents = self.runners["unet"].ctx.modules.compiled_unet["main"]( - *unet_inputs + else: + latents = self._produce_latents_sd( + *produce_latents_input ) - sample = self.scheduler.step( - torch.tensor(latents.to_host(), dtype=self.torch_dtype), t, sample - ).prev_sample - - vae_start = time.time() - vae_out = self.runners["vae_decode"].ctx.modules.compiled_vae["main"]( - sample - ) - + image = self.vae("decode", [latents]) + numpy_images.append(image) pipe_end = time.time() - image = vae_out.to_host() - - numpy_images.append(image) - print("Batch #", i + 1, "\n") - print( - "UNet time(", - self.num_inference_steps, - "): ", - vae_start - unet_start, - "sec,", - ) - print( - "Unet average step latency: ", - (vae_start - unet_start) / self.num_inference_steps, - "sec", - ) - print("VAE time: ", pipe_end - vae_start, "sec") - print( - f"\nTotal time (txt2img, batch #{str(i+1)}): ", - (encode_prompts_end - tokenize_start) + (pipe_end - unet_start), - "sec\n", - ) - end = time.time() - print("Total CLIP time:", encode_prompts_end - tokenize_start, "sec") - print("Total tokenize time:", tokenize_start - tokenize_start, "sec") - print("Loading time: ", tokenize_start - pipe_start, "sec") - if batch_count > 1: - print( - f"Total inference time ({batch_count} batch(es)):", - end - tokenize_start, - "sec", - ) + logging.info(f"Total inference time: {pipe_end - pipe_start:.2f}s") timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") images = [] for idx, image in enumerate(numpy_images): @@ -587,9 +578,6 @@ def numpy_to_pil_image(images): if __name__ == "__main__": from turbine_models.custom_models.sd_inference.sd_cmd_opts import args - mlirs = copy.deepcopy(SUBMODELS) - vmfbs = copy.deepcopy(SUBMODELS) - weights = copy.deepcopy(SUBMODELS) ireec_flags = { "clip": args.ireec_flags + args.clip_flags, "scheduler": args.ireec_flags, @@ -597,37 +585,15 @@ def numpy_to_pil_image(images): "vae_decode": args.ireec_flags + args.vae_flags, } if not args.pipeline_dir: - pipe_id_list = [ - utils.create_safe_name(args.hf_model_name, args.iree_target_triple), - str(args.height), - str(args.width), - str(args.max_length), - args.precision, - args.device, - ] - args.pipeline_dir = os.path.join( - ".", - "_".join(pipe_id_list), - ) - if args.input_mlir: - user_mlir_list = args.input_mlir.split(",") - else: - user_mlir_list = [] - for submodel_id, mlir_path in zip(mlirs.keys(), user_mlir_list): - if submodel_id in mlir_path: - mlirs[submodel_id] = mlir_path - if not args.external_weights_dir and args.external_weights: - args.external_weights_dir = args.pipeline_dir + args.pipeline_dir = utils.create_safe_name(args.hf_model_name, "") sd_pipe = SharkSDPipeline( args.hf_model_name, - args.scheduler_id, args.height, args.width, - args.precision, - args.max_length, args.batch_size, - args.num_inference_steps, + args.max_length, + args.precision, args.device, args.iree_target_triple, ireec_flags, @@ -636,16 +602,21 @@ def numpy_to_pil_image(images): args.pipeline_dir, args.external_weights_dir, args.external_weights, - args.vae_decomp_attn, + args.num_inference_steps, + args.cpu_scheduling, + args.scheduler_id, ) - vmfbs, weights = sd_pipe.check_prepared(mlirs, vmfbs, weights) - sd_pipe.load_pipeline(vmfbs, weights, args.rt_device, args.compiled_pipeline) + sd_pipe.prepare_all() + sd_pipe.load_map() sd_pipe.generate_images( args.prompt, args.negative_prompt, + args.num_inference_steps, args.batch_count, args.guidance_scale, args.seed, + args.cpu_scheduling, + args.scheduler_id, False, ) print("Image generation complete.") diff --git a/models/turbine_models/custom_models/sd_inference/tokenization.py b/models/turbine_models/custom_models/sd_inference/tokenization.py index cfc140c57..d7c2bbf54 100644 --- a/models/turbine_models/custom_models/sd_inference/tokenization.py +++ b/models/turbine_models/custom_models/sd_inference/tokenization.py @@ -4,415 +4,163 @@ import torch import numpy as np -re_attention = re.compile( - r""" -\\\(| -\\\)| -\\\[| -\\]| -\\\\| -\\| -\(| -\[| -:([+-]?[.\d]+)\)| -\)| -]| -[^\\()\[\]:]+| -: -""", - re.X, -) - - -def parse_prompt_attention(text): - """ - Parses a string with attention tokens and returns a list of pairs: - text and its associated weight. - Accepted tokens are: - (abc) - increases attention to abc by a multiplier of 1.1 - (abc:3.12) - increases attention to abc by a multiplier of 3.12 - [abc] - decreases attention to abc by a multiplier of 1.1 - \( - literal character '(' - \[ - literal character '[' - \) - literal character ')' - \] - literal character ']' - \\ - literal character '\' - anything else - just text - >>> parse_prompt_attention('normal text') - [['normal text', 1.0]] - >>> parse_prompt_attention('an (important) word') - [['an ', 1.0], ['important', 1.1], [' word', 1.0]] - >>> parse_prompt_attention('(unbalanced') - [['unbalanced', 1.1]] - >>> parse_prompt_attention('\(literal\]') - [['(literal]', 1.0]] - >>> parse_prompt_attention('(unnecessary)(parens)') - [['unnecessaryparens', 1.1]] - >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') - [['a ', 1.0], - ['house', 1.5730000000000004], - [' ', 1.1], - ['on', 1.0], - [' a ', 1.1], - ['hill', 0.55], - [', sun, ', 1.1], - ['sky', 1.4641000000000006], - ['.', 1.1]] - """ - - res = [] - round_brackets = [] - square_brackets = [] - - round_bracket_multiplier = 1.1 - square_bracket_multiplier = 1 / 1.1 - - def multiply_range(start_position, multiplier): - for p in range(start_position, len(res)): - res[p][1] *= multiplier - - for m in re_attention.finditer(text): - text = m.group(0) - weight = m.group(1) - - if text.startswith("\\"): - res.append([text[1:], 1.0]) - elif text == "(": - round_brackets.append(len(res)) - elif text == "[": - square_brackets.append(len(res)) - elif weight is not None and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), float(weight)) - elif text == ")" and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), round_bracket_multiplier) - elif text == "]" and len(square_brackets) > 0: - multiply_range(square_brackets.pop(), square_bracket_multiplier) - else: - res.append([text, 1.0]) - - for pos in round_brackets: - multiply_range(pos, round_bracket_multiplier) - - for pos in square_brackets: - multiply_range(pos, square_bracket_multiplier) - - if len(res) == 0: - res = [["", 1.0]] - - # merge runs of identical weights - i = 0 - while i + 1 < len(res): - if res[i][1] == res[i + 1][1]: - res[i][0] += res[i + 1][0] - res.pop(i + 1) - else: - i += 1 - - return res - - -def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int): - r""" - Tokenize a list of prompts and return its tokens with weights of each token. - No padding, starting or ending token is included. - """ - tokens = [] - weights = [] - truncated = False - for text in prompt: - texts_and_weights = parse_prompt_attention(text) - text_token = [] - text_weight = [] - for word, weight in texts_and_weights: - # tokenize and discard the starting and the ending token - token = tokenizer(word).input_ids[1:-1] - text_token += token - # copy the weight by length of token - text_weight += [weight] * len(token) - # stop if the text is too long (longer than truncation limit) - if len(text_token) > max_length: - truncated = True - break - # truncate - if len(text_token) > max_length: - truncated = True - text_token = text_token[:max_length] - text_weight = text_weight[:max_length] - tokens.append(text_token) - weights.append(text_weight) - if truncated: - print( - "Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples" - ) - return tokens, weights - - -def pad_tokens_and_weights( - tokens, - weights, - max_length, - bos, - eos, - no_boseos_middle=True, - chunk_length=77, -): - r""" - Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. - """ - max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) - weights_length = ( - max_length if no_boseos_middle else max_embeddings_multiples * chunk_length - ) - for i in range(len(tokens)): - tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i])) - if no_boseos_middle: - weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) +# The following is copied from Diffusers' "encode_prompt" function in the StableDiffusion pipeline. +# It has been lightly augmented to work with the SHARK-Turbine pipeline. +def encode_prompt( + pipe, + prompt, + negative_prompt=None, + num_images_per_prompt = 1, + do_classifier_free_guidance = True, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + # if lora_scale is not None and pipe.use_lora: + # pipe._lora_scale = lora_scale + + # # dynamically adjust the LoRA scale + # if not USE_PEFT_BACKEND: + # adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale) + # else: + # scale_lora_layers(pipe.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) else: - w = [] - if len(weights[i]) == 0: - w = [1.0] * weights_length - else: - for j in range(max_embeddings_multiples): - w.append(1.0) # weight for starting token in this chunk - w += weights[i][ - j - * (chunk_length - 2) : min( - len(weights[i]), (j + 1) * (chunk_length - 2) - ) - ] - w.append(1.0) # weight for ending token in this chunk - w += [1.0] * (weights_length - len(w)) - weights[i] = w[:] - - return tokens, weights - - -def get_unweighted_text_embeddings( - pipe, - text_input, - chunk_length: int, - no_boseos_middle: Optional[bool] = True, -): - """ - When the length of tokens is a multiple of the capacity of the text encoder, - it should be split into chunks and sent to the text encoder individually. - """ - max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) - if max_embeddings_multiples > 1: - text_embeddings = [] - for i in range(max_embeddings_multiples): - # extract the i-th chunk - text_input_chunk = text_input[ - :, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2 - ].clone() - - # cover the head and the tail by the starting and the ending tokens - text_input_chunk[:, 0] = text_input[0, 0] - text_input_chunk[:, -1] = text_input[0, -1] - - text_input_chunk = ireert.asdevicearray( - pipe.runners["clip"].config.device, text_input_chunk, "int64" + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + # if pipe.use_textual_inversion: + # prompt = pipe.maybe_convert_prompt(prompt, pipe.tokenizer) + + text_inputs = pipe.tokenizer( + prompt, + padding="max_length", + max_length=pipe.model_max_length, + truncation=True, + return_tensors="pt", ) - text_embedding = ( - pipe.runners["clip"].ctx.modules.compiled_clip["main"](text_input_chunk) - )[0].to_host() - if no_boseos_middle: - if i == 0: - # discard the ending token - text_embedding = text_embedding[:, :-1] - elif i == max_embeddings_multiples - 1: - # discard the starting token - text_embedding = text_embedding[:, 1:] - else: - # discard both starting and ending tokens - text_embedding = text_embedding[:, 1:-1] - - text_embeddings.append(text_embedding) - # SHARK: Convert the result to tensor - # text_embeddings = torch.concat(text_embeddings, axis=1) - text_embeddings_np = np.concatenate(np.array(text_embeddings)) - text_embeddings = torch.from_numpy(text_embeddings_np) - else: - text_input = ireert.asdevicearray( - pipe.runners["clip"].config.device, text_input, "int64" - ) - text_embeddings = ( - pipe.runners["clip"].ctx.modules.compiled_clip["main"](text_input) - )[0].to_host() - text_embeddings = torch.from_numpy(text_embeddings) - return text_embeddings - - -# This function deals with NoneType values occuring in tokens after padding -# It switches out None with 49407 as truncating None values causes matrix dimension errors, -def filter_nonetype_tokens(tokens: List[List]): - return [[49407 if token is None else token for token in tokens[0]]] - - -def get_tokenized_inputs( - pipe, - tokenizer, - prompt, - uncond_prompt, - max_length, - max_embeddings_multiples: Optional[int] = 8, - no_boseos_middle: Optional[bool] = True, - skip_parsing: Optional[bool] = False, - skip_weighting: Optional[bool] = False, -): - if not skip_parsing: - prompt_tokens, prompt_weights = get_prompts_with_weights( - tokenizer, prompt, max_length - 2 - ) - if uncond_prompt is not None: - uncond_tokens, uncond_weights = get_prompts_with_weights( - tokenizer, uncond_prompt, max_length - 2 + text_input_ids = text_inputs.input_ids + untruncated_ids = pipe.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = pipe.tokenizer.batch_decode( + untruncated_ids[:, pipe.model_max_length - 1 : -1] + ) + print("The following text was removed due to truncation:", removed_text) + if pipe.text_encoder.metadata.get("use_attention_mask"): + attention_mask = text_inputs.attention_mask + prompt_embeds = pipe.text_encoder("encode_tokens_attn_mask", [text_input_ids, attention_mask]) + else: + attention_mask = None + prompt_embeds = pipe.text_encoder("encode_tokens", [text_input_ids]) + prompt_embeds = prompt_embeds[0] + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + # if pipe.use_textual_inversion: + # uncond_tokens = pipe.maybe_convert_prompt(uncond_tokens, pipe.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = pipe.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", ) - else: - prompt_tokens = [ - token[1:-1] - for token in tokenizer( - prompt, max_length=max_length, truncation=True - ).input_ids - ] - prompt_weights = [[1.0] * len(token) for token in prompt_tokens] - if uncond_prompt is not None: - if isinstance(uncond_prompt, str): - uncond_prompt = [uncond_prompt] - uncond_tokens = [ - token[1:-1] - for token in tokenizer( - uncond_prompt, max_length=max_length, truncation=True - ).input_ids - ] - uncond_weights = [[1.0] * len(token) for token in uncond_tokens] - # round up the longest length of tokens to a multiple of (model_max_length - 2) - max_length = max([len(token) for token in prompt_tokens]) - if uncond_prompt is not None: - max_length = max(max_length, max([len(token) for token in uncond_tokens])) - max_embeddings_multiples = min( - max_embeddings_multiples, - (max_length - 1) // (pipe.model_max_length - 2) + 1, - ) - max_embeddings_multiples = max(1, max_embeddings_multiples) - - max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2 - - # pad the length of tokens and weights - bos = tokenizer.bos_token_id - eos = tokenizer.eos_token_id - prompt_tokens, prompt_weights = pad_tokens_and_weights( - prompt_tokens, - prompt_weights, - max_length, - bos, - eos, - no_boseos_middle=no_boseos_middle, - chunk_length=pipe.model_max_length, - ) - - # FIXME: This is a hacky fix caused by tokenizer padding with None values - prompt_tokens = filter_nonetype_tokens(prompt_tokens) - - # prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device) - prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu") - if uncond_prompt is not None: - uncond_tokens, uncond_weights = pad_tokens_and_weights( - uncond_tokens, - uncond_weights, - max_length, - bos, - eos, - no_boseos_middle=no_boseos_middle, - chunk_length=pipe.model_max_length, - ) - - # FIXME: This is a hacky fix caused by tokenizer padding with None values - uncond_tokens = filter_nonetype_tokens(uncond_tokens) + if pipe.text_encoder.metadata.get("use_attention_mask"): + attention_mask = uncond_input.attention_mask + negative_prompt_embeds = pipe.text_encoder( + "encode_tokens_attn_mask", + [ + uncond_input.input_ids, + attention_mask, + ], + ) + else: + attention_mask = None + negative_prompt_embeds = pipe.text_encoder( + "encode_tokens", + [ + uncond_input.input_ids, + ], + ) - # uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device) - uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device="cpu") - if uncond_prompt is not None: - return prompt_tokens, prompt_weights, uncond_tokens, uncond_weights - else: - return prompt_tokens, prompt_weights, None, None + negative_prompt_embeds = negative_prompt_embeds[0] -def get_weighted_text_embeddings( - pipe, - prompt: List[str], - uncond_prompt: List[str] = None, - max_embeddings_multiples: Optional[int] = 8, - no_boseos_middle: Optional[bool] = True, - skip_parsing: Optional[bool] = False, - skip_weighting: Optional[bool] = False, -): - max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2 - for tokenizer in pipe.tokenizers: - ( - prompt_tokens, - prompt_weights, - uncond_tokens, - uncond_weights, - ) = get_tokenized_inputs( - pipe, - tokenizer, - prompt, - uncond_prompt, - max_length, - max_embeddings_multiples, - no_boseos_middle, - skip_parsing, - skip_weighting, - ) + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] - # get the embeddings - text_embeddings = get_unweighted_text_embeddings( - pipe, - prompt_tokens, - pipe.model_max_length, - no_boseos_middle=no_boseos_middle, - ) - # prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) - prompt_weights = torch.tensor(prompt_weights, dtype=torch.float, device="cpu") - if uncond_prompt is not None: - uncond_embeddings = get_unweighted_text_embeddings( - pipe, - uncond_tokens, - pipe.model_max_length, - no_boseos_middle=no_boseos_middle, - ) - # uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device) - uncond_weights = torch.tensor(uncond_weights, dtype=torch.float, device="cpu") + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # assign weights to the prompts and normalize in the sense of mean - # TODO: should we normalize by chunk or in a whole (current implementation)? - if (not skip_parsing) and (not skip_weighting): - previous_mean = ( - text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - ) - text_embeddings *= prompt_weights.unsqueeze(-1) - current_mean = ( - text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - ) - text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) - if uncond_prompt is not None: - previous_mean = ( - uncond_embeddings.float() - .mean(axis=[-2, -1]) - .to(uncond_embeddings.dtype) - ) - uncond_embeddings *= uncond_weights.unsqueeze(-1) - current_mean = ( - uncond_embeddings.float() - .mean(axis=[-2, -1]) - .to(uncond_embeddings.dtype) - ) - uncond_embeddings *= ( - (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) - ) + #if pipe.use_lora: + # Retrieve the original scale by scaling back the LoRA layers + # unimplemented + # unscale_lora_layers(pipe.text_encoder, lora_scale) - if uncond_prompt is not None: - return text_embeddings, uncond_embeddings - return text_embeddings, None + return prompt_embeds, negative_prompt_embeds diff --git a/models/turbine_models/custom_models/sd_inference/unet.py b/models/turbine_models/custom_models/sd_inference/unet.py index ac66d3108..4b37178d4 100644 --- a/models/turbine_models/custom_models/sd_inference/unet.py +++ b/models/turbine_models/custom_models/sd_inference/unet.py @@ -15,6 +15,7 @@ from shark_turbine.dynamo.passes import ( DEFAULT_DECOMPOSITIONS, ) +from shark_turbine.transforms.general.add_metadata import AddMetadataPass from turbine_models.custom_models.sd_inference import utils import torch import torch._dynamo as dynamo @@ -33,32 +34,25 @@ def __init__(self, hf_model_name): subfolder="unet", ) - def forward(self, sample, timestep, encoder_hidden_states, guidance_scale): - samples = torch.cat([sample] * 2) + def forward(self, latent_model_input, timestep, encoder_hidden_states): unet_out = self.unet.forward( - samples, timestep, encoder_hidden_states, return_dict=False + latent_model_input, timestep, encoder_hidden_states, return_dict=False )[0] - noise_pred_uncond, noise_pred_text = unet_out.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - return noise_pred + return unet_out def export_unet_model( - unet_model, hf_model_name, batch_size, height, width, precision="fp32", max_length=77, - hf_auth_token=None, compile_to="torch", external_weights=None, external_weight_path=None, device=None, - target_triple=None, + target=None, ireec_flags=None, decomp_attn=False, exit_on_vmfb=False, @@ -68,22 +62,28 @@ def export_unet_model( weights_only=False, upload_ir=False, ): - if "turbo" in hf_model_name: - do_classifier_free_guidance = False - else: - do_classifier_free_guidance = True - if pipeline_dir: - safe_name = os.path.join(pipeline_dir, f"unet") + if input_mlir: + unet_model = None else: - safe_name = utils.create_safe_name( + unet_model = UnetModel( hf_model_name, - f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_unet_{device}", ) + dtype = torch.float16 if precision == "fp16" else torch.float32 + np_dtype = "float16" if precision == "fp16" else "float32" + safe_name = utils.create_safe_name( + hf_model_name, + f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_unet", + ) + if decomp_attn: + safe_name += "_decomp_attn" + if pipeline_dir: + safe_name = os.path.join(pipeline_dir, safe_name) + if input_mlir: vmfb_path = utils.compile_to_vmfb( input_mlir, device, - target_triple, + target, ireec_flags, safe_name, mlir_source="file", @@ -93,15 +93,6 @@ def export_unet_model( return vmfb_path mapper = {} - decomp_list = copy.deepcopy(DEFAULT_DECOMPOSITIONS) - if decomp_attn == True: - decomp_list.extend( - [ - torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, - torch.ops.aten._scaled_dot_product_flash_attention.default, - ] - ) - dtype = torch.float16 if precision == "fp16" else torch.float32 if precision == "fp16": unet_model = unet_model.half() @@ -114,76 +105,94 @@ def export_unet_model( return external_weight_path sample = ( - batch_size, + batch_size * 2, unet_model.unet.config.in_channels, height // 8, width // 8, ) - encoder_hidden_states_sizes = ( unet_model.unet.config.layers_per_block, max_length, unet_model.unet.config.cross_attention_dim, ) - - class CompiledUnet(CompiledModule): - if external_weights: - params = export_parameters( - unet_model, external=True, external_scope="", name_mapper=mapper.get - ) - else: - params = export_parameters(unet_model) - - def main( - self, - sample=AbstractTensor(*sample, dtype=dtype), - timestep=AbstractTensor(1, dtype=dtype), - encoder_hidden_states=AbstractTensor( - *encoder_hidden_states_sizes, dtype=dtype - ), - guidance_scale=AbstractTensor(1, dtype=dtype), + example_forward_args = [ + torch.empty(sample, dtype=dtype), + torch.empty(1, dtype=dtype), + torch.empty(encoder_hidden_states_sizes, dtype=dtype), + ] + decomp_list = [] + if decomp_attn == True: + decomp_list = [ + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten.scaled_dot_product_attention, + ] + with decompositions.extend_aot_decompositions( + from_current=True, + add_ops=decomp_list, + ): + fxb = FxProgramsBuilder(unet_model) + + @fxb.export_program( + args=(example_forward_args,), + ) + def _forward( + module, + inputs, ): - return jittable(unet_model.forward, decompose_ops=decomp_list)( - sample, timestep, encoder_hidden_states, guidance_scale - ) - - import_to = "INPUT" if compile_to == "linalg" else "IMPORT" - inst = CompiledUnet(context=Context(), import_to=import_to) + return module.forward(*inputs) - module_str = str(CompiledModule.get_mlir_module(inst)) + class CompiledUnet(CompiledModule): + run_forward = _forward + if external_weights: + externalize_module_parameters(unet_model) + + inst = CompiledUnet(context=Context(), import_to="IMPORT") + + module = CompiledModule.get_mlir_module(inst) + + model_metadata_run_forward = { + "model_name": "sd_unet", + "input_shapes": [ + sample, + (1,), + encoder_hidden_states_sizes, + ], + "input_dtypes": [np_dtype for x in range(3)], + "output_shapes": [sample], + "output_dtypes": [np_dtype], + } + + module = AddMetadataPass(module, model_metadata_run_forward, "run_forward").run() + module_str = str(module) if compile_to != "vmfb": return module_str else: - utils.compile_to_vmfb( + vmfb_path = utils.compile_to_vmfb( module_str, device, - target_triple, + target, ireec_flags, safe_name, - return_path=False, + return_path=True, attn_spec=attn_spec, ) + if exit_on_vmfb: + exit() + return vmfb_path if __name__ == "__main__": from turbine_models.custom_models.sd_inference.sd_cmd_opts import args - if args.input_mlir: - unet_model = None - else: - unet_model = UnetModel( - args.hf_model_name, - ) mod_str = export_unet_model( - unet_model, args.hf_model_name, args.batch_size, args.height, args.width, args.precision, args.max_length, - args.hf_auth_token, args.compile_to, args.external_weights, args.external_weight_path, diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index cf6b5946a..2099c4550 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -131,7 +131,8 @@ def compile_to_vmfb( save_mlir=True, attn_spec=None, winograd=False, - masked_attention=False, + masked_attention=True, + debug=False, ): flags = [] if mlir_source == "file" and not isinstance(module_str, str): @@ -199,7 +200,6 @@ def compile_to_vmfb( elif ireec_flags == None: ireec_flags = [] - debug = False if debug: flags.extend( ["--iree-hal-dump-executable-files-to=" + safe_name + "_dispatches"] diff --git a/models/turbine_models/custom_models/sd_inference/vae.py b/models/turbine_models/custom_models/sd_inference/vae.py index 475cf1d1d..f76b2710c 100644 --- a/models/turbine_models/custom_models/sd_inference/vae.py +++ b/models/turbine_models/custom_models/sd_inference/vae.py @@ -13,6 +13,7 @@ from shark_turbine.dynamo.passes import ( DEFAULT_DECOMPOSITIONS, ) +from shark_turbine.transforms.general.add_metadata import AddMetadataPass from turbine_models.custom_models.sd_inference import utils import torch import torch._dynamo as dynamo @@ -54,50 +55,79 @@ def __init__( ) self.vae.load_state_dict(custom_vae) - def decode_inp(self, inp): - inp = 1 / 0.18215 * inp + def decode(self, inp): + inp = 1 / self.vae.config.scaling_factor * inp x = self.vae.decode(inp, return_dict=False)[0] return (x / 2 + 0.5).clamp(0, 1) - def encode_inp(self, inp): + def encode(self, inp): latents = self.vae.encode(inp).latent_dist.sample() - return 0.18215 * latents + return self.vae.config.scaling_factor * latents + +class SD3VaeModel(torch.nn.Module): + def __init__( + self, + hf_model_name, + ): + super().__init__() + self.vae = AutoencoderKL.from_pretrained( + hf_model_name, + subfolder="vae", + ) + + def decode(self, inp): + inp = (inp / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(inp, return_dict=False)[0] + image = image.float() + image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] + return image + def encode(self, inp): + image_np = inp / 255.0 + image_np = np.moveaxis(image_np, 2, 0) + batch_images = np.expand_dims(image_np, axis=0).repeat(1, axis=0) + image_torch = torch.from_numpy(batch_images) + image_torch = 2.0 * image_torch - 1.0 + image_torch = image_torch + latent = self.vae.encode(image_torch) + return latent def export_vae_model( - vae_model, hf_model_name, batch_size, height, width, precision, compile_to="torch", + num_channels=4, external_weights=None, external_weight_path=None, device=None, - target_triple=None, + target=None, ireec_flags=None, - variant="decode", decomp_attn=False, exit_on_vmfb=False, pipeline_dir=None, attn_spec=None, input_mlir=None, weights_only=False, - upload_ir=False, ): + dtype = torch.float16 if precision == "fp16" else torch.float32 + np_dtype = "float16" if precision == "fp16" else "float32" + safe_name = utils.create_safe_name( + hf_model_name, + f"_bs{batch_size}_{height}x{width}_{precision}_vae", + ) + if decomp_attn: + safe_name += "_decomp_attn" if pipeline_dir: - safe_name = os.path.join(pipeline_dir, "vae_" + variant) - else: - safe_name = utils.create_safe_name( - hf_model_name, - f"_bs{batch_size}_{height}x{width}_{precision}_vae_{variant}_{device}", - ) + safe_name = os.path.join(pipeline_dir, safe_name) + if input_mlir: vmfb_path = utils.compile_to_vmfb( input_mlir, device, - target_triple, + target, ireec_flags, safe_name, mlir_source="file", @@ -105,46 +135,93 @@ def export_vae_model( attn_spec=attn_spec, ) return vmfb_path + + if "stable-diffusion-3" in hf_model_name: + vae_model = SD3VaeModel(hf_model_name) + else: + if "xl" in hf_model_name and precision == "fp16": + custom_vae = "madebyollin/sdxl-vae-fp16-fix" + else: + custom_vae = None + vae_model = VaeModel(hf_model_name, custom_vae=custom_vae) + + if dtype == torch.float16: + vae_model = vae_model.half() mapper = {} - decomp_list = DEFAULT_DECOMPOSITIONS - if decomp_attn: - decomp_list.extend( - [ - torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, - torch.ops.aten._scaled_dot_product_flash_attention.default, - ] - ) - dtype = torch.float16 if precision == "fp16" else torch.float32 - vae_model = vae_model.to(dtype) utils.save_external_weights( mapper, vae_model, external_weights, external_weight_path ) if weights_only: return external_weight_path - sample = (batch_size, 4, height // 8, width // 8) - if variant == "encode": - sample = (batch_size, 3, height, width) - class CompiledVae(CompiledModule): - params = export_parameters(vae_model) + input_image_shape = (height, width, 3) + input_latents_shape = (batch_size, num_channels, height // 8, width // 8) + encode_args = [ + torch.empty( + input_image_shape, + dtype=torch.float32, + ) + ] + decode_args = [ + torch.empty( + input_latents_shape, + dtype=dtype, + ) + ] + decomp_list = [] + if decomp_attn == True: + decomp_list = [ + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten.scaled_dot_product_attention, + ] + with decompositions.extend_aot_decompositions( + from_current=True, + add_ops=decomp_list, + ): + fxb = FxProgramsBuilder(vae_model) + + # @fxb.export_program(args=(encode_args,)) + # def _encode(module, inputs,): + # return module.encode(*inputs) + + @fxb.export_program(args=(decode_args,)) + def _decode(module, inputs): + return module.decode(*inputs) + + class CompiledVae(CompiledModule): + decode = _decode + + if external_weights: + externalize_module_parameters(vae_model) - def main(self, inp=AbstractTensor(*sample, dtype=dtype)): - if variant == "decode": - return jittable(vae_model.decode_inp, decompose_ops=decomp_list)(inp) - elif variant == "encode": - return jittable(vae_model.encode_inp, decompose_ops=decomp_list)(inp) + inst = CompiledVae(context=Context(), import_to="IMPORT") - import_to = "INPUT" if compile_to == "linalg" else "IMPORT" - inst = CompiledVae(context=Context(), import_to=import_to) + module = CompiledModule.get_mlir_module(inst) + + model_metadata_decode = { + "model_name": "vae_decode", + "input_shapes": [input_latents_shape], + "input_dtypes": [np_dtype], + "output_shapes": [(3, width, height) * batch_size], + "output_dtypes": ["float32"], + } + model_metadata_encode = { + "model_name": "vae_encode", + "input_shapes": [input_image_shape], + "input_dtypes": [np_dtype], + "output_shapes": [input_latents_shape], + "output_dtypes": [np_dtype], + } + module = AddMetadataPass(module, model_metadata_decode, "decode").run() - module_str = str(CompiledModule.get_mlir_module(inst)) if compile_to != "vmfb": - return module_str + return str(module) else: vmfb_path = utils.compile_to_vmfb( - module_str, + str(module), device, - target_triple, + target, ireec_flags, safe_name, return_path=not exit_on_vmfb, @@ -161,7 +238,7 @@ def main(self, inp=AbstractTensor(*sample, dtype=dtype)): else: vae_model = VaeModel( args.hf_model_name, - custom_vae=custom_vae, + custom_vae=None, ) mod_str = export_vae_model( vae_model, @@ -174,7 +251,7 @@ def main(self, inp=AbstractTensor(*sample, dtype=dtype)): external_weights=args.external_weights, external_weight_path=args.external_weight_path, device=args.device, - target_triple=args.iree_target_triple, + target=args.iree_target_triple, ireec_flags=args.ireec_flags + args.attn_flags + args.vae_flags, variant=args.vae_variant, decomp_attn=args.decomp_attn, diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py index 224e63233..2b2ea00ba 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py @@ -12,6 +12,8 @@ from iree.compiler.ir import Context import numpy as np from shark_turbine.aot import * +from shark_turbine.transforms.general.add_metadata import AddMetadataPass + from turbine_models.custom_models.sd_inference import utils import torch from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer @@ -155,29 +157,26 @@ def export_prompt_encoder( hf_model_name, hf_auth_token=None, max_length=64, + batch_size=1, precision="fp16", compile_to="torch", external_weights=None, external_weight_path=None, device=None, - target_triple=None, + target=None, ireec_flags=None, exit_on_vmfb=True, pipeline_dir=None, input_mlir=None, attn_spec=None, weights_only=False, - batchsize=1, batch_input=False, ): - if "turbo" in hf_model_name: - do_classifier_free_guidance = False - else: - do_classifier_free_guidance = True + do_classifier_free_guidance = True safe_name = utils.create_safe_name( hf_model_name, - f"_bs{batchsize}_{str(max_length)}-{precision}-prompt-encoder-{device}", + f"_bs{batch_size}_{str(max_length)}-{precision}-prompt-encoder-{device}", ) if pipeline_dir not in [None, ""]: safe_name = os.path.join(pipeline_dir, safe_name) @@ -186,9 +185,9 @@ def export_prompt_encoder( vmfb_path = utils.compile_to_vmfb( input_mlir, device, - target_triple, + target, ireec_flags, - safe_name + "_" + target_triple, + safe_name, mlir_source="file", return_path=not exit_on_vmfb, const_expr_hoisting=True, @@ -214,7 +213,7 @@ def export_prompt_encoder( precision, hf_auth_token, do_classifier_free_guidance, - batch_size=batchsize, + batch_size=batch_size, batch_input=batch_input, ) @@ -265,7 +264,16 @@ def encode_prompts_turbo( import_to = "INPUT" if compile_to == "linalg" else "IMPORT" inst = CompiledClip(context=Context(), import_to=import_to) - module_str = str(CompiledModule.get_mlir_module(inst)) + module = CompiledModule.get_mlir_module(inst) + + model_metadata_encode = { + "model_name": hf_model_name + "_text_encoder", + "input_shapes": [str((1, max_length)) for i in range(4)], + "input_dtypes": ['int64' for i in range(4)], + "use_attention_mask": False, + } + module = AddMetadataPass(module, model_metadata_encode, "encode_prompts").run() + module_str = str(module) if compile_to != "vmfb": return module_str, tokenizers @@ -273,9 +281,9 @@ def encode_prompts_turbo( vmfb_path = utils.compile_to_vmfb( module_str, device, - target_triple, + target, ireec_flags, - safe_name + "_" + target_triple, + safe_name, return_path=not exit_on_vmfb, const_expr_hoisting=True, attn_spec=attn_spec, @@ -290,6 +298,7 @@ def encode_prompts_turbo( args.hf_model_name, args.hf_auth_token, args.max_length, + args.batch_size, args.precision, args.compile_to, args.external_weights, @@ -301,7 +310,6 @@ def encode_prompts_turbo( pipeline_dir=args.pipeline_dir, input_mlir=args.input_mlir, attn_spec=args.attn_spec, - batchsize=args.batch_size, ) if args.input_mlir: exit() diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index 6b45ab799..5410c57b1 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -12,9 +12,8 @@ from iree.compiler.ir import Context import numpy as np from shark_turbine.aot import * -from shark_turbine.dynamo.passes import ( - DEFAULT_DECOMPOSITIONS, -) +from shark_turbine.transforms.general.add_metadata import AddMetadataPass + from turbine_models.custom_models.sd_inference import utils import torch import torch._dynamo as dynamo @@ -72,7 +71,6 @@ def forward( @torch.no_grad() def export_unet_model( - unet_model, hf_model_name, batch_size, height, @@ -84,7 +82,7 @@ def export_unet_model( external_weights=None, external_weight_path=None, device=None, - target_triple=None, + target=None, ireec_flags=None, decomp_attn=False, exit_on_vmfb=False, @@ -93,6 +91,7 @@ def export_unet_model( input_mlir=None, weights_only=False, ): + unet_model = UnetModel(hf_model_name, hf_auth_token, precision) safe_name = utils.create_safe_name( hf_model_name, f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_unet", @@ -107,9 +106,9 @@ def export_unet_model( vmfb_path = utils.compile_to_vmfb( input_mlir, device, - target_triple, + target, ireec_flags, - safe_name + "_" + target_triple, + safe_name, mlir_source="file", return_path=not exit_on_vmfb, attn_spec=attn_spec, @@ -117,8 +116,18 @@ def export_unet_model( return vmfb_path mapper = {} - dtype = torch.float16 if precision == "fp16" else torch.float32 - + np_dtypes = { + "fp16": "float16", + "fp32": "float32", + "i8": "int8", + } + torch_dtypes = { + "fp16": torch.float16, + "fp32": torch.float32, + "i8": torch.int8, + } + dtype = torch_dtypes[precision] + np_dtype = np_dtypes[precision] if precision == "fp16": unet_model = unet_model.half() @@ -132,6 +141,12 @@ def export_unet_model( do_classifier_free_guidance = True init_batch_dim = 2 if do_classifier_free_guidance else 1 + sample = [ + batch_size, + unet_model.unet.config.in_channels, + height // 8, + width // 8, + ] prepared_latents = ( batch_size * init_batch_dim, unet_model.unet.config.in_channels, @@ -180,17 +195,34 @@ class CompiledUnet(CompiledModule): inst = CompiledUnet(context=Context(), import_to="IMPORT") - module_str = str(CompiledModule.get_mlir_module(inst)) + module = CompiledModule.get_mlir_module(inst) + + model_metadata_run_forward = { + "model_name": "sd_unet", + "input_shapes": [ + prepared_latents, + (1,), + prompt_embeds_shape, + text_embeds_shape, + time_ids_shape, + ], + "input_dtypes": [np_dtype for x in range(5)], + "output_shapes": [sample], + "output_dtypes": [np_dtype], + } + + module = AddMetadataPass(module, model_metadata_run_forward, "run_forward").run() + module_str = str(module) if compile_to != "vmfb": return module_str else: vmfb_path = utils.compile_to_vmfb( module_str, device, - target_triple, + target, ireec_flags, - safe_name + "_" + target_triple, + safe_name, return_path=True, attn_spec=attn_spec, ) diff --git a/models/turbine_models/custom_models/sdxl_inference/vae_runner.py b/models/turbine_models/custom_models/sdxl_inference/vae_runner.py index 539c99868..98aae9a28 100644 --- a/models/turbine_models/custom_models/sdxl_inference/vae_runner.py +++ b/models/turbine_models/custom_models/sdxl_inference/vae_runner.py @@ -21,49 +21,7 @@ def run_vae( def run_torch_vae(hf_model_name, custom_vae, variant, example_input): - from diffusers import AutoencoderKL - - class VaeModel(torch.nn.Module): - def __init__( - self, - hf_model_name, - custom_vae=custom_vae, - ): - super().__init__() - self.vae = None - if custom_vae in ["", None]: - self.vae = AutoencoderKL.from_pretrained( - hf_model_name, - subfolder="vae", - ) - elif not isinstance(custom_vae, dict): - try: - # custom HF repo with no vae subfolder - self.vae = AutoencoderKL.from_pretrained( - custom_vae, - ) - except: - # some larger repo with vae subfolder - self.vae = AutoencoderKL.from_pretrained( - custom_vae, - subfolder="vae", - ) - else: - # custom vae as a HF state dict - self.vae = AutoencoderKL.from_pretrained( - hf_model_name, - subfolder="vae", - ) - self.vae.load_state_dict(custom_vae) - - def decode_inp(self, inp): - inp = inp / 0.13025 - x = self.vae.decode(inp, return_dict=False)[0] - return (x / 2 + 0.5).clamp(0, 1) - - def encode_inp(self, inp): - latents = self.vae.encode(inp).latent_dist.sample() - return 0.13025 * latents + from turbine_models.custom_models.sd_inference.vae import VaeModel vae_model = VaeModel( hf_model_name, diff --git a/models/turbine_models/tests/pipeline_test.py b/models/turbine_models/tests/pipeline_test.py index 76e33c96a..4290e1446 100644 --- a/models/turbine_models/tests/pipeline_test.py +++ b/models/turbine_models/tests/pipeline_test.py @@ -43,8 +43,6 @@ def forward(self, x): torch.no_grad() - - def export_dummy_model(): model = TestModule() target = "x86_64-unknown-linux-gnu" @@ -85,9 +83,9 @@ class CompiledTester(CompiledModule): class TestPipeline(TurbinePipelineBase): def __init__( self, - **kwargs, + **base_args, ): - super().__init__(**kwargs) + super().__init__(**base_args) def run(self, inputs: list): return self.test_model_1("forward", *inputs) @@ -103,14 +101,12 @@ def setUp(self): "safe_name": "TestModel2xLinear", "keywords": ["Test", "Model", "2x", "Linear"], "export_fn": export_dummy_model, - "export_args": None, } } self.pipe = TestPipeline( model_map=model_map, - batch_size=1, device="cpu", - iree_target_triple="x86_64-unknown-linux-gnu", + target="x86_64-unknown-linux-gnu", pipeline_dir="./", precision="fp32", ) @@ -136,6 +132,5 @@ def test_pipeline_metadata(self): expected, metadata ) - if __name__ == "__main__": unittest.main() From c2d8d5fa46656c78c13ba7bcbe46f757de156b53 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 2 Jul 2024 02:05:49 -0500 Subject: [PATCH 02/12] Initial punet integration. --- .../custom_models/sd_inference/schedulers.py | 5 +- .../custom_models/sd_inference/sd_cmd_opts.py | 6 + .../custom_models/sd_inference/sd_pipeline.py | 27 +++- .../custom_models/sd_inference/utils.py | 2 +- .../custom_models/sdxl_inference/unet.py | 147 +++++++++++++++--- 5 files changed, 155 insertions(+), 32 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/schedulers.py b/models/turbine_models/custom_models/sd_inference/schedulers.py index 8d8b5e651..4cab7e510 100644 --- a/models/turbine_models/custom_models/sd_inference/schedulers.py +++ b/models/turbine_models/custom_models/sd_inference/schedulers.py @@ -147,6 +147,7 @@ def __init__( self.conditional_timesteps = conditional_timesteps self.dtype = latents_dtype + self.use_punet = False self.torch_dtype = ( torch.float32 if latents_dtype == "float32" else torch.float16 ) @@ -183,7 +184,7 @@ def initialize_sd(self, sample, num_inference_steps): return sample, timesteps def scale_model_input(self, sample, t, t_uncond=None): - if self.do_classifier_free_guidance: + if self.do_classifier_free_guidance and not self.use_punet: sample = torch.cat([sample] * 2) if self.conditional_timesteps: if t_uncond: @@ -202,7 +203,7 @@ def step(self, noise_pred, t, latents, guidance_scale): noise_pred = torch.tensor(noise_pred.to_host()) if isinstance(guidance_scale, ireert.DeviceArray): guidance_scale = torch.tensor(guidance_scale.to_host()) - if self.do_classifier_free_guidance: + if self.do_classifier_free_guidance and not self.use_punet: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_text - noise_pred_uncond diff --git a/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py index 8d707582c..12a4e3d3b 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py +++ b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py @@ -181,6 +181,12 @@ def is_valid_file(arg): help="Decompose attention for VAE decode only at fx graph level", ) +p.add_argument( + "--use_i8_punet", + action="store_true", + help="Use i8 quantized Partitioned UNet for inference", +) + ############################################################################## # SDXL script general options. ############################################################################## diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 2e0bbcc4d..a51612579 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -194,6 +194,8 @@ def get_sd_model_map(hf_model_name): "fp16": torch.float16, "float32": torch.float32, "float16": torch.float16, + "int8": torch.int8, + "i8": torch.int8, } class SharkSDPipeline(TurbinePipelineBase): def __init__( @@ -216,6 +218,7 @@ def __init__( cpu_scheduling: bool = True, scheduler_id: str = None, # compatibility only shift: float = 1.0, # compatibility only + use_i8_punet: bool = False, ): common_export_args = { "hf_model_name": None, @@ -296,6 +299,16 @@ def __init__( elif not self.is_sd3: self.tokenizer = CLIPTokenizer.from_pretrained(self.base_model_name, subfolder="tokenizer") + self.use_i8_punet = self.use_punet = use_i8_punet + if self.use_i8_punet: + self.map["unet"]["export_args"]["precision"] = "i8" + self.map["unet"]["export_args"]["use_punet"] = True + for i in self.map["unet"]["keywords"]: + i = i.replace("fp16", "i8").replace("fp32", "i8") + self.map["unet"]["keywords"].append("punet") + else: + self.map["unet"]["keywords"].append("!punet") + # RUN def encode_prompts_sdxl(self, prompt, negative_prompt): @@ -401,6 +414,8 @@ def load_scheduler( scheduler_device, latents_dtype=self.latents_dtype, ) + if self.use_punet: + self.scheduler.use_punet = True def _produce_latents_sd( self, @@ -447,13 +462,13 @@ def _produce_latents_sdxl( latents, add_time_ids, step_indexes, timesteps = self.prepare_latents( sample, self.num_inference_steps, image, strength ) - iree_inputs = [ - sample, + iree_unet_inputs = [ prompt_embeds, add_text_embeds, add_time_ids, - None, ] + if self.use_punet: + iree_unet_inputs.append(guidance_scale) for i, t in tqdm(enumerate(timesteps)): if self.cpu_scheduling: step_index = i @@ -468,9 +483,7 @@ def _produce_latents_sdxl( [ latent_model_input, t, - iree_inputs[1], - iree_inputs[2], - iree_inputs[3], + *iree_unet_inputs, ], ) latents = self.scheduler.step( @@ -605,6 +618,8 @@ def numpy_to_pil_image(images): args.num_inference_steps, args.cpu_scheduling, args.scheduler_id, + None, + args.use_i8_punet, ) sd_pipe.prepare_all() sd_pipe.load_map() diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 2099c4550..7dd045493 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -131,7 +131,7 @@ def compile_to_vmfb( save_mlir=True, attn_spec=None, winograd=False, - masked_attention=True, + masked_attention=False, debug=False, ): flags = [] diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index 5410c57b1..cdd548809 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -7,6 +7,7 @@ import copy import os import sys +import safetensors from iree import runtime as ireert from iree.compiler.ir import Context @@ -14,14 +15,15 @@ from shark_turbine.aot import * from shark_turbine.transforms.general.add_metadata import AddMetadataPass + from turbine_models.custom_models.sd_inference import utils import torch -import torch._dynamo as dynamo -from diffusers import UNet2DConditionModel +from huggingface_hub import hf_hub_download class UnetModel(torch.nn.Module): def __init__(self, hf_model_name, hf_auth_token=None, precision="fp32"): + from diffusers import UNet2DConditionModel super().__init__() if precision == "fp16": try: @@ -68,6 +70,79 @@ def forward( )[0] return noise_pred +def get_punet_model(hf_model_name, external_weight_path, precision="i8"): + from sharktank.models.punet.model import Unet2DConditionModel as punet_unet, ClassifierFreeGuidanceUnetModel as CFGPunetModel + + if precision == "i8": + repo_id = "amd-shark/sdxl-quant-models" + subfolder = "unet/int8" + revision = "82e06d6ea22ac78102a9aded69e8ddfb9fa4ae37" + elif precision in ["fp16", "fp32"]: + repo_id = hf_model_name + subfolder = "unet" + revision = "76d28af79639c28a79fa5c6c6468febd3490a37e" + def download(filename): + return hf_hub_download( + repo_id=repo_id, subfolder=subfolder, filename=filename, revision=revision + ) + results = { + "config.json": download("config.json"), + "params.safetensors": download("params.safetensors"), + } + if precision == "i8": + results["quant_params.json"] = download("quant_params.json") + output_path = external_weight_path.split("unet")[0] + "punet_dataset_i8.irpa" + ds = get_punet_i8_dataset(results["config.json"], results["quant_params.json"], results["params.safetensors"], output_path, base_params=None) + else: + ds = None # get_punet_dataset(results["config.json"], results["params.safetensors"], base_params=None) + + cond_unet = punet_unet.from_dataset(ds) + mdl = CFGPunetModel(cond_unet) + return mdl + +def get_punet_i8_dataset(config_json_path, quant_params_path, params_path, output_path="./punet_dataset_i8.irpa", quant_params_struct=None, base_params=None): + from sharktank.models.punet.tools.import_brevitas_dataset import ( + _load_json, + _load_theta, + _get_dataset_props, + apply_per_layer_quant, + Dataset, + Theta, + InferenceTensor, + ) + # Construct the pre-transform dataset. + dataset_props = _get_dataset_props(_load_json(config_json_path)) + quant_params_struct = _load_json(quant_params_path) + with safetensors.safe_open(params_path, framework="pt", device="cpu") as st: + quant_theta = _load_theta(st) + base_theta = None + if base_params is not None: + print("Initializing from base parameters:", args.base_params) + with safetensors.safe_open( + base_params, framework="pt", device="cpu" + ) as st: + base_theta = _load_theta(st) + + ds = Dataset(dataset_props, quant_theta if base_theta is None else base_theta) + + # The quant_params_struct has quantization parameter structs keyed by full + # layer name. We process each of these in turn to produce a per-layer + # quantization scheme where no quantized tensors escape their layer. + updated_tensors: dict[str, InferenceTensor] = {} + for layer_name, qp in quant_params_struct.items(): + print(f"Applying per-layer quants: {layer_name}") + apply_per_layer_quant(quant_theta, layer_name, qp, updated_tensors) + + # Apply updates into a new Theta. + theta = base_theta if base_theta is not None else quant_theta + flat_tensors = theta.flatten() + flat_tensors.update(updated_tensors) + ds.root_theta = Theta(flat_tensors) + + # TODO: Post-process to introduce fused cross-layer connections. + + ds.save(output_path, io_report_callback=print) + return ds @torch.no_grad() def export_unet_model( @@ -90,11 +165,17 @@ def export_unet_model( attn_spec=None, input_mlir=None, weights_only=False, + use_punet=False, ): - unet_model = UnetModel(hf_model_name, hf_auth_token, precision) + if use_punet: + unet_model = get_punet_model(hf_model_name, external_weight_path, precision) + submodel_name = "punet" + else: + unet_model = UnetModel(hf_model_name, hf_auth_token, precision) + submodel_name = "unet" safe_name = utils.create_safe_name( hf_model_name, - f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_unet", + f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_{submodel_name}", ) if pipeline_dir: safe_name = os.path.join(pipeline_dir, safe_name) @@ -131,6 +212,9 @@ def export_unet_model( if precision == "fp16": unet_model = unet_model.half() + if use_punet: + dtype = torch.float16 + utils.save_external_weights( mapper, unet_model, external_weights, external_weight_path ) @@ -143,13 +227,13 @@ def export_unet_model( sample = [ batch_size, - unet_model.unet.config.in_channels, + 4, height // 8, width // 8, ] prepared_latents = ( batch_size * init_batch_dim, - unet_model.unet.config.in_channels, + 4, height // 8, width // 8, ) @@ -164,7 +248,14 @@ def export_unet_model( torch.empty(text_embeds_shape, dtype=dtype), torch.empty(time_ids_shape, dtype=dtype), ] - + example_forward_args_dict = { + "sample": torch.rand(sample, dtype=dtype), + "timestep": torch.zeros(1, dtype=dtype), + "encoder_hidden_states": torch.rand(prompt_embeds_shape, dtype=dtype), + "text_embeds": torch.rand(text_embeds_shape, dtype=dtype), + "time_ids": torch.zeros(time_ids_shape, dtype=dtype), + "guidance_scale": torch.tensor([7.5], dtype=dtype), + } decomp_list = [] if decomp_attn == True: decomp_list = [ @@ -176,26 +267,33 @@ def export_unet_model( from_current=True, add_ops=decomp_list, ): - fxb = FxProgramsBuilder(unet_model) - - @fxb.export_program( - args=(example_forward_args,), - ) - def _forward( - module, - inputs, - ): - return module.forward(*inputs) - - class CompiledUnet(CompiledModule): - run_forward = _forward - if external_weights: externalize_module_parameters(unet_model) + if use_punet: + output = export( + unet_model, + kwargs=example_forward_args_dict, + module_name="compiled_unet", + ) + module = output.mlir_module + else: + fxb = FxProgramsBuilder(unet_model) + + @fxb.export_program( + args=(example_forward_args,), + ) + def _forward( + module, + inputs, + ): + return module.forward(*inputs) + + class CompiledUnet(CompiledModule): + run_forward = _forward - inst = CompiledUnet(context=Context(), import_to="IMPORT") + inst = CompiledUnet(context=Context(), import_to="IMPORT") - module = CompiledModule.get_mlir_module(inst) + module = CompiledModule.get_mlir_module(inst) model_metadata_run_forward = { "model_name": "sd_unet", @@ -210,6 +308,9 @@ class CompiledUnet(CompiledModule): "output_shapes": [sample], "output_dtypes": [np_dtype], } + if use_punet: + model_metadata_run_forward["input_shapes"].append((1,)) + model_metadata_run_forward["input_dtypes"].append(np_dtype) module = AddMetadataPass(module, model_metadata_run_forward, "run_forward").run() From a9ead64722219bbfe803957f4dc53782a793663c Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 2 Jul 2024 12:09:35 -0500 Subject: [PATCH 03/12] Formatting. --- .../custom_models/pipeline_base.py | 103 ++++-- .../custom_models/sd3_inference/sd3_mmdit.py | 2 +- .../sd3_inference/sd3_schedulers.py | 2 +- .../sd3_inference/sd3_text_encoders.py | 9 +- .../custom_models/sd_inference/clip.py | 18 +- .../custom_models/sd_inference/schedulers.py | 7 +- .../custom_models/sd_inference/sd_pipeline.py | 150 +++++---- .../sd_inference/tokenization.py | 314 +++++++++--------- .../custom_models/sd_inference/vae.py | 8 +- .../sdxl_inference/sdxl_prompt_encoder.py | 2 +- .../custom_models/sdxl_inference/unet.py | 41 ++- models/turbine_models/tests/pipeline_test.py | 3 + 12 files changed, 392 insertions(+), 267 deletions(-) diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py index 98601ba76..f224c4854 100644 --- a/models/turbine_models/custom_models/pipeline_base.py +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -38,6 +38,7 @@ "float32": torch.float32, } + def merge_arg_into_map(model_map, arg, arg_name): if isinstance(arg, dict): for key in arg.keys(): @@ -51,6 +52,7 @@ def merge_arg_into_map(model_map, arg, arg_name): model_map[key][arg_name] = arg return model_map + def merge_export_arg(model_map, arg, arg_name): if isinstance(arg, dict): for key in arg.keys(): @@ -64,7 +66,7 @@ def merge_export_arg(model_map, arg, arg_name): continue model_map[key]["export_args"][arg_name] = arg return model_map - + # def str_to_list(string): # out = string.strip("[]").replace(" ", "").split(";") @@ -72,6 +74,7 @@ def merge_export_arg(model_map, arg, arg_name): # item = ast.literal_eval(item) # return out + class PipelineComponent: """ Wraps a VMFB runner with attributes for embedded metadata, device info, utilities and @@ -98,7 +101,9 @@ def load( extra_plugin=None, ): self.module_name = module_name - print(f"Loading {module_name} from {vmfb_path} with external weights: {external_weight_path}.") + print( + f"Loading {module_name} from {vmfb_path} with external weights: {external_weight_path}." + ) self.runner = vmfbRunner( rt_device, vmfb_path, external_weight_path, extra_plugin ) @@ -117,7 +122,9 @@ def get_metadata(self): if any(x in function_name for x in ["$async", "__init"]): continue try: - self.metadata[function_name] = self.module[function_name].vm_function.reflection + self.metadata[function_name] = self.module[ + function_name + ].vm_function.reflection except: logging.warning( f"Could not get metadata for {self.module_name}['{function_name}']." @@ -126,10 +133,14 @@ def get_metadata(self): def _validate_or_convert_inputs(self, function_name, inputs): if self.metadata: - expected_input_shapes = self.metadata.get(function_name, {}).get("input_shapes") + expected_input_shapes = self.metadata.get(function_name, {}).get( + "input_shapes" + ) if expected_input_shapes: expected_input_shapes = ast.literal_eval(expected_input_shapes) - expected_input_dtypes = self.metadata.get(function_name, {}).get("input_dtypes", "") + expected_input_dtypes = self.metadata.get(function_name, {}).get( + "input_dtypes", "" + ) if expected_input_dtypes: expected_input_dtypes = ast.literal_eval(expected_input_dtypes) if not isinstance(expected_input_shapes, list): @@ -146,7 +157,9 @@ def _validate_or_convert_inputs(self, function_name, inputs): pass for i, input_dtype in enumerate(expected_input_dtypes): if not isinstance(inputs[i], ireert.DeviceArray): - if isinstance(inputs[i], torch.Tensor) or isinstance(inputs[i], torch.HalfTensor): + if isinstance(inputs[i], torch.Tensor) or isinstance( + inputs[i], torch.HalfTensor + ): new_input = inputs[i].float().cpu().numpy() else: new_input = inputs[i] @@ -175,7 +188,7 @@ def _validate_or_convert_inputs(self, function_name, inputs): for i in inputs: if not isinstance(i, ireert.DeviceArray): i = ireert.asdevicearray(self.device, i) - + def _output_cast(self, output): if isinstance(output, tuple): out_tuple = () @@ -185,16 +198,22 @@ def _output_cast(self, output): return out_tuple match self.dest_type: case "devicearray": - output = output.astype(self.dest_dtype) if output.dtype != self.dest_dtype else output + output = ( + output.astype(self.dest_dtype) + if output.dtype != self.dest_dtype + else output + ) return output case "torch": - output = torch.tensor(output.to_host(), dtype=torch_dtypes[self.dest_dtype]) + output = torch.tensor( + output.to_host(), dtype=torch_dtypes[self.dest_dtype] + ) return output case "numpy": return output.to_host().astype(np_dtypes[self.dest_dtype]) case _: return output - + def _run(self, function_name, inputs: list): return self.module[function_name](*inputs) @@ -314,7 +333,7 @@ def __init__( print(map_arguments) for arg in map_arguments.keys(): self.map = merge_arg_into_map(self.map, map_arguments[arg], arg) - + self.map = merge_arg_into_map( self.map, np_dtypes[self.map[submodel]["precision"]], "np_dtype" ) @@ -323,13 +342,17 @@ def __init__( ) for arg in common_export_args.keys(): for submodel in self.map.keys(): - self.map[submodel].get("export_args", {})[arg] = self.map[submodel].get(arg, common_export_args[arg]) + self.map[submodel].get("export_args", {})[arg] = self.map[submodel].get( + arg, common_export_args[arg] + ) for submodel in self.map.keys(): for key, value in map_arguments.items(): self.map = merge_export_arg(self.map, value, key) for key, value in self.map[submodel].get("export_args", {}).items(): if key == "hf_model_name": - self.map[submodel]["keywords"].append(utils.create_safe_name(value.split("/")[-1], "")) + self.map[submodel]["keywords"].append( + utils.create_safe_name(value.split("/")[-1], "") + ) if key == "decomp_attn": if not value: self.map[submodel]["keywords"].append("!decomp_attn") @@ -343,7 +366,6 @@ def __init__( elif key in ["max_length", "precision"]: self.map[submodel]["keywords"].append(str(value)) - self.pipeline_dir = pipeline_dir if not os.path.exists(self.pipeline_dir): os.makedirs(self.pipeline_dir) @@ -387,12 +409,16 @@ def prepare_all( for submodel in self.map.keys(): if not self.map[submodel].get("vmfb"): print("Fetching: ", submodel) - self.export_submodel(submodel, input_mlir=self.map[submodel].get("mlir")) + self.export_submodel( + submodel, input_mlir=self.map[submodel].get("mlir") + ) if not self.map[submodel]["export_args"]["external_weights"]: assert not self.map[submodel].get( "weights" ), f"External weights should not be used for a model with inlined params." - if not self.map[submodel].get("weights") and self.map[submodel]["export_args"].get("external_weights"): + if not self.map[submodel].get("weights") and self.map[submodel][ + "export_args" + ].get("external_weights"): self.export_submodel(submodel, weights_only=True) return self.prepare_all(mlirs, vmfbs, weights, interactive) @@ -434,9 +460,13 @@ def is_prepared(self, vmfbs, weights): avail_files = os.listdir(pipeline_dir) candidates = [] for filename in avail_files: - if all(str(x) in filename for x in keywords) and not any(x in filename for x in neg_keywords): + if all(str(x) in filename for x in keywords) and not any( + x in filename for x in neg_keywords + ): candidates.append(os.path.join(pipeline_dir, filename)) - if all(str(x) in filename for x in mlir_keywords) and not any(x in filename for x in neg_keywords): + if all(str(x) in filename for x in mlir_keywords) and not any( + x in filename for x in neg_keywords + ): self.map[key]["mlir"] = os.path.join(pipeline_dir, filename) if len(candidates) == 1: self.map[key]["vmfb"] = candidates[0] @@ -519,7 +549,9 @@ def export_submodel( self.map[submodel]["export_args"]["external_weight_path"] = os.path.join( self.external_weights_dir, - utils.create_safe_name(self.map[submodel]["export_args"].get("hf_model_name", ""), "") + utils.create_safe_name( + self.map[submodel]["export_args"].get("hf_model_name", ""), "" + ) + f"_{submodel}_{self.map[submodel]['precision']}." + self.map[submodel]["external_weights"], ) @@ -555,10 +587,18 @@ def export_submodel( self.map[submodel]["export_args"]["max_length"], "unet_loop", ) - dims = [self.map[submodel]["export_args"]["width"], self.map[submodel]["export_args"]["height"]] + dims = [ + self.map[submodel]["export_args"]["width"], + self.map[submodel]["export_args"]["height"], + ] dims = "x".join([str(x) for x in dims]) pipeline_keys = [ - utils.create_safe_name(self.map[submodel]["export_args"]["hf_model_name"].split("/")[-1], ""), + utils.create_safe_name( + self.map[submodel]["export_args"]["hf_model_name"].split("/")[ + -1 + ], + "", + ), "bs" + str(self.map[submodel]["export_args"]["batch_size"]), dims, self.map[submodel]["export_args"]["precision"], @@ -585,10 +625,18 @@ def export_submodel( self.map[submodel]["export_args"]["max_length"], "tokens_to_image", ) - dims = [self.map[submodel]["export_args"]["width"], self.map[submodel]["export_args"]["height"]] + dims = [ + self.map[submodel]["export_args"]["width"], + self.map[submodel]["export_args"]["height"], + ] dims = "x".join([str(x) for x in dims]) pipeline_keys = [ - utils.create_safe_name(self.map[submodel]["export_args"]["hf_model_name"].split("/")[-1], ""), + utils.create_safe_name( + self.map[submodel]["export_args"]["hf_model_name"].split("/")[ + -1 + ], + "", + ), "bs" + str(self.map[submodel]["export_args"]["batch_size"]), dims, self.map[submodel]["export_args"]["precision"], @@ -616,11 +664,14 @@ def export_submodel( exported = self.map[submodel]["export_fn"](**export_args) else: exported = self.map[submodel]["export_fn"]() - if not self.map[submodel].get("weights") and os.path.exists(self.map[submodel]["export_args"].get("external_weight_path")): - self.map[submodel]["weights"] = self.map[submodel]["export_args"].get("external_weight_path", None) + if not self.map[submodel].get("weights") and os.path.exists( + self.map[submodel]["export_args"].get("external_weight_path") + ): + self.map[submodel]["weights"] = self.map[submodel][ + "export_args" + ].get("external_weight_path", None) if not weights_only: self.map[submodel]["vmfb"] = exported - # LOAD def load_map(self): diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py index e19cac162..d87ff5993 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py @@ -249,7 +249,7 @@ class CompiledMmdit(CompiledModule): hidden_states_shape, encoder_hidden_states_shape, pooled_projections_shape, - init_batch_dim + init_batch_dim, ], "input_dtypes": [np_dtype for x in range(4)], "output_shapes": [hidden_states_shape], diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py b/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py index 6b4fe135b..2c1d04cf1 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py @@ -320,7 +320,7 @@ class CompiledScheduler(CompiledModule): inst = CompiledScheduler(context=Context(), import_to=import_to) module = CompiledModule.get_mlir_module(inst) - + model_metadata_run_init = { "model_name": "sd3_scheduler_FlowEulerDiscrete", "input_shapes": [sample], diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py b/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py index 2784e873e..33107aa9f 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py @@ -191,14 +191,17 @@ class CompiledTextEncoder(CompiledModule): save_module_parameters(external_weight_path, model) inst = CompiledTextEncoder(context=Context(), import_to="IMPORT") - + module = CompiledModule.get_mlir_module(inst) - + model_metadata_forward = { "model_name": "sd3_clip_t5xxl_text_encoders", "input_shapes": [(1, max_length, 2) for x in range(6)], "input_dtypes": ["int64" for x in range(6)], - "output_shapes": [(2*output_batchsize,max_length*2,4096), (2*output_batchsize,2048)], + "output_shapes": [ + (2 * output_batchsize, max_length * 2, 4096), + (2 * output_batchsize, 2048), + ], "output_dtypes": ["float32"], } module = AddMetadataPass(module, model_metadata_forward, "forward").run() diff --git a/models/turbine_models/custom_models/sd_inference/clip.py b/models/turbine_models/custom_models/sd_inference/clip.py index a4c177736..11705a916 100644 --- a/models/turbine_models/custom_models/sd_inference/clip.py +++ b/models/turbine_models/custom_models/sd_inference/clip.py @@ -15,6 +15,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, CLIPProcessor from turbine_models.turbine_tank import turbine_tank + @torch.no_grad() def export_clip_model( hf_model_name, @@ -87,9 +88,10 @@ def export_clip_model( ) if weights_only: return external_weight_path - + if "google/t5" in hf_model_name: input_shapes = [(batch_size, input_len), (batch_size, input_len)] + class CompiledTextEncoder(CompiledModule): if external_weights: params = export_parameters( @@ -109,8 +111,10 @@ def encode_tokens( return jittable(text_encoder_model.forward)( input_ids=inp, decoder_input_ids=decoder_input_ids ) + else: input_shapes = [str((batch_size, input_len)), str((batch_size, input_len))] + class CompiledTextEncoder(CompiledModule): if external_weights: params = export_parameters( @@ -127,7 +131,9 @@ def encode_tokens_attn_mask( inp=AbstractTensor(1, input_len, dtype=torch.int64), attn_mask=AbstractTensor(1, input_len, dtype=torch.int64), ): - return jittable(text_encoder_model.forward)(input_ids=inp, attention_mask=attn_mask) + return jittable(text_encoder_model.forward)( + input_ids=inp, attention_mask=attn_mask + ) def encode_tokens( self, @@ -142,16 +148,18 @@ def encode_tokens( model_metadata_attn_mask = { "model_name": hf_model_name + "_text_encoder", "input_shapes": input_shapes, - "input_dtypes": ['int64', 'int64'], + "input_dtypes": ["int64", "int64"], "use_attention_mask": True, } model_metadata_encode = { "model_name": hf_model_name + "_text_encoder", "input_shapes": input_shapes[0], - "input_dtypes": ['int64'], + "input_dtypes": ["int64"], "use_attention_mask": False, } - module = AddMetadataPass(module, model_metadata_attn_mask, "encode_tokens_attn_mask").run() + module = AddMetadataPass( + module, model_metadata_attn_mask, "encode_tokens_attn_mask" + ).run() module = AddMetadataPass(module, model_metadata_encode, "encode_tokens").run() module_str = str(module) diff --git a/models/turbine_models/custom_models/sd_inference/schedulers.py b/models/turbine_models/custom_models/sd_inference/schedulers.py index 4cab7e510..cba6cfdf6 100644 --- a/models/turbine_models/custom_models/sd_inference/schedulers.py +++ b/models/turbine_models/custom_models/sd_inference/schedulers.py @@ -132,7 +132,12 @@ def step(self, noise_pred, t, sample, guidance_scale, i): class SharkSchedulerCPUWrapper: @torch.no_grad() def __init__( - self, scheduler, batch_size, dest_device, latents_dtype, conditional_timesteps=False + self, + scheduler, + batch_size, + dest_device, + latents_dtype, + conditional_timesteps=False, ): self.do_classifier_free_guidance = True self.module = scheduler diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index a51612579..9c977208d 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -25,7 +25,10 @@ sd3_text_encoders, sd3_mmdit, ) -from turbine_models.custom_models.pipeline_base import TurbinePipelineBase, merge_arg_into_map +from turbine_models.custom_models.pipeline_base import ( + TurbinePipelineBase, + merge_arg_into_map, +) from turbine_models.custom_models.sd_inference.tokenization import encode_prompt from turbine_models.model_runner import vmfbRunner from transformers import CLIPTokenizer @@ -174,9 +177,10 @@ "num_channels": 16, "decomp_attn": None, }, - } + }, } + def get_sd_model_map(hf_model_name): if isinstance(hf_model_name, dict): name = hf_model_name["text_encoder"] @@ -188,7 +192,8 @@ def get_sd_model_map(hf_model_name): return sd3_model_map else: return sd1_sd2_model_map - + + torch_dtypes = { "fp32": torch.float32, "fp16": torch.float16, @@ -197,6 +202,8 @@ def get_sd_model_map(hf_model_name): "int8": torch.int8, "i8": torch.int8, } + + class SharkSDPipeline(TurbinePipelineBase): def __init__( self, @@ -239,7 +246,9 @@ def __init__( sd_model_map[submodel]["load"] = True sd_model_map[submodel]["export_args"]["batch_size"] = batch_size if "max_length" in sd_model_map[submodel]["export_args"]: - max_length_sub = max_length if isinstance(max_length, int) else max_length[submodel] + max_length_sub = ( + max_length if isinstance(max_length, int) else max_length[submodel] + ) sd_model_map[submodel]["export_args"]["max_length"] = max_length_sub if "height" in sd_model_map[submodel]["export_args"]: sd_model_map[submodel]["export_args"]["height"] = height @@ -264,8 +273,12 @@ def __init__( self.map[submodel]["export_args"]["hf_model_name"], f"_{submodel}_{self.map[submodel]['precision']}", ) - weights_filename += "." + self.map[submodel]["export_args"]["external_weights"] - self.map[submodel]["export_args"]["external_weight_path"] = weights_filename + weights_filename += ( + "." + self.map[submodel]["export_args"]["external_weights"] + ) + self.map[submodel]["export_args"][ + "external_weight_path" + ] = weights_filename self.batch_size = batch_size self.model_max_length = max_length @@ -287,17 +300,27 @@ def __init__( self.map.pop("unetloop") self.map.pop("fullpipeline") - self.base_model_name = hf_model_name if isinstance(hf_model_name, str) else hf_model_name.get("unet", hf_model_name.get("mmdit")) + self.base_model_name = ( + hf_model_name + if isinstance(hf_model_name, str) + else hf_model_name.get("unet", hf_model_name.get("mmdit")) + ) self.is_img2img = False self.is_sdxl = "xl" in self.base_model_name self.is_sd3 = "stable-diffusion-3" in self.base_model_name if self.is_sdxl: self.tokenizers = [ - CLIPTokenizer.from_pretrained(self.base_model_name, subfolder="tokenizer"), - CLIPTokenizer.from_pretrained(self.base_model_name, subfolder="tokenizer_2"), + CLIPTokenizer.from_pretrained( + self.base_model_name, subfolder="tokenizer" + ), + CLIPTokenizer.from_pretrained( + self.base_model_name, subfolder="tokenizer_2" + ), ] elif not self.is_sd3: - self.tokenizer = CLIPTokenizer.from_pretrained(self.base_model_name, subfolder="tokenizer") + self.tokenizer = CLIPTokenizer.from_pretrained( + self.base_model_name, subfolder="tokenizer" + ) self.use_i8_punet = self.use_punet = use_i8_punet if self.use_i8_punet: @@ -309,8 +332,48 @@ def __init__( else: self.map["unet"]["keywords"].append("!punet") + # LOAD + + def load_scheduler( + self, + scheduler_id: str, + steps: int = 30, + ): + self.scheduler = schedulers.get_scheduler( + self.base_model_name, self.scheduler_id + ) + if self.is_sd3: + scheduler_device = self.mmdit.device + else: + scheduler_device = self.unet.device + if not self.cpu_scheduling: + self.scheduler = None + self.num_inference_steps = steps + self.scheduler_id = scheduler_id + scheduler_path = f"{scheduler_id}Scheduler_{self.num_inference_steps}" + if not os.path.exists(scheduler_path): + scheduler_path, _ = self.export_submodel("scheduler") + try: + self.scheduler = schedulers.SharkSchedulerWrapper( + scheduler_device, + scheduler_path, + ) + except: + print("JIT export of scheduler failed. Loading CPU scheduler.") + self.cpu_scheduling = True + if self.cpu_scheduling: + scheduler = schedulers.get_scheduler(self.base_model_name, scheduler_id) + self.scheduler = schedulers.SharkSchedulerCPUWrapper( + scheduler, + self.batch_size, + scheduler_device, + latents_dtype=self.latents_dtype, + ) + if self.use_punet: + self.scheduler.use_punet = True # RUN + def encode_prompts_sdxl(self, prompt, negative_prompt): # Tokenize prompt and negative prompt. text_input_ids_list = [] @@ -346,13 +409,15 @@ def prepare_latents( self, noise, num_inference_steps, - image = None, - strength = None, + image=None, + strength=None, ): if self.is_img2img: raise NotImplementedError("Image-to-image not supported yet.") elif self.is_sdxl: - sample, add_time_ids, step_indexes, timesteps = self.scheduler.initialize_sdxl(noise, num_inference_steps) + sample, add_time_ids, step_indexes, timesteps = ( + self.scheduler.initialize_sdxl(noise, num_inference_steps) + ) return sample, add_time_ids, step_indexes, timesteps elif self.is_sd3: raise NotImplementedError("Stable Diffusion 3 not supported yet.") @@ -381,49 +446,13 @@ def get_rand_latents(self, seed, batch_count): samples.append(rand_sample) return samples - def load_scheduler( - self, - scheduler_id: str, - steps: int = 30, - ): - self.scheduler = schedulers.get_scheduler(self.base_model_name, self.scheduler_id) - if self.is_sd3: - scheduler_device = self.mmdit.device - else: - scheduler_device = self.unet.device - if not self.cpu_scheduling: - self.scheduler = None - self.num_inference_steps = steps - self.scheduler_id = scheduler_id - scheduler_path = f"{scheduler_id}Scheduler_{self.num_inference_steps}" - if not os.path.exists(scheduler_path): - scheduler_path, _ = self.export_submodel("scheduler") - try: - self.scheduler = schedulers.SharkSchedulerWrapper( - scheduler_device, - scheduler_path, - ) - except: - print("JIT export of scheduler failed. Loading CPU scheduler.") - self.cpu_scheduling = True - if self.cpu_scheduling: - scheduler = schedulers.get_scheduler(self.base_model_name, scheduler_id) - self.scheduler = schedulers.SharkSchedulerCPUWrapper( - scheduler, - self.batch_size, - scheduler_device, - latents_dtype=self.latents_dtype, - ) - if self.use_punet: - self.scheduler.use_punet = True - def _produce_latents_sd( self, sample, prompt_embeds, negative_prompt_embeds, steps, - guidance_scale, + guidance_scale, ): image = None strength = 0 @@ -440,10 +469,7 @@ def _produce_latents_sd( timestep, ] unet_inputs.extend([text_embeddings, guidance_scale]) - latents = self.unet( - "run_forward", - unet_inputs - ) + latents = self.unet("run_forward", unet_inputs) sample = self.scheduler.step( torch.tensor(latents.to_host(), dtype=self.torch_dtype), t, sample ).prev_sample @@ -455,7 +481,7 @@ def _produce_latents_sdxl( prompt_embeds, add_text_embeds, steps, - guidance_scale, + guidance_scale, ): image = None strength = 0 @@ -522,7 +548,7 @@ def generate_images( if steps and needs_new_scheduler: self.num_inference_steps = steps self.load_scheduler(scheduler_id, steps) - + pipe_start = time.time() numpy_images = [] @@ -530,7 +556,9 @@ def generate_images( # Tokenize prompt and negative prompt. if self.is_sdxl: - prompt_embeds, negative_embeds = self.encode_prompts_sdxl(prompt, negative_prompt) + prompt_embeds, negative_embeds = self.encode_prompts_sdxl( + prompt, negative_prompt + ) else: prompt_embeds, negative_embeds = encode_prompt( self, prompt, negative_prompt @@ -545,13 +573,9 @@ def generate_images( guidance_scale, ] if self.is_sdxl: - latents = self._produce_latents_sdxl( - *produce_latents_input - ) + latents = self._produce_latents_sdxl(*produce_latents_input) else: - latents = self._produce_latents_sd( - *produce_latents_input - ) + latents = self._produce_latents_sd(*produce_latents_input) image = self.vae("decode", [latents]) numpy_images.append(image) pipe_end = time.time() diff --git a/models/turbine_models/custom_models/sd_inference/tokenization.py b/models/turbine_models/custom_models/sd_inference/tokenization.py index d7c2bbf54..8c29d2d3c 100644 --- a/models/turbine_models/custom_models/sd_inference/tokenization.py +++ b/models/turbine_models/custom_models/sd_inference/tokenization.py @@ -4,163 +4,171 @@ import torch import numpy as np + # The following is copied from Diffusers' "encode_prompt" function in the StableDiffusion pipeline. # It has been lightly augmented to work with the SHARK-Turbine pipeline. def encode_prompt( - pipe, - prompt, - negative_prompt=None, - num_images_per_prompt = 1, - do_classifier_free_guidance = True, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - lora_scale (`float`, *optional*): - A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - """ - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - # if lora_scale is not None and pipe.use_lora: - # pipe._lora_scale = lora_scale - - # # dynamically adjust the LoRA scale - # if not USE_PEFT_BACKEND: - # adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale) - # else: - # scale_lora_layers(pipe.text_encoder, lora_scale) - - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) + pipe, + prompt, + negative_prompt=None, + num_images_per_prompt=1, + do_classifier_free_guidance=True, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, +): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + # if lora_scale is not None and pipe.use_lora: + # pipe._lora_scale = lora_scale + + # # dynamically adjust the LoRA scale + # if not USE_PEFT_BACKEND: + # adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale) + # else: + # scale_lora_layers(pipe.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + # if pipe.use_textual_inversion: + # prompt = pipe.maybe_convert_prompt(prompt, pipe.tokenizer) + + text_inputs = pipe.tokenizer( + prompt, + padding="max_length", + max_length=pipe.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = pipe.tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = pipe.tokenizer.batch_decode( + untruncated_ids[:, pipe.model_max_length - 1 : -1] + ) + print("The following text was removed due to truncation:", removed_text) + if pipe.text_encoder.metadata.get("use_attention_mask"): + attention_mask = text_inputs.attention_mask + prompt_embeds = pipe.text_encoder( + "encode_tokens_attn_mask", [text_input_ids, attention_mask] + ) else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - # textual inversion: process multi-vector tokens if necessary - # if pipe.use_textual_inversion: - # prompt = pipe.maybe_convert_prompt(prompt, pipe.tokenizer) - - text_inputs = pipe.tokenizer( - prompt, - padding="max_length", - max_length=pipe.model_max_length, - truncation=True, - return_tensors="pt", + attention_mask = None + prompt_embeds = pipe.text_encoder("encode_tokens", [text_input_ids]) + prompt_embeds = prompt_embeds[0] + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." ) - text_input_ids = text_inputs.input_ids - untruncated_ids = pipe.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = pipe.tokenizer.batch_decode( - untruncated_ids[:, pipe.model_max_length - 1 : -1] - ) - print("The following text was removed due to truncation:", removed_text) - if pipe.text_encoder.metadata.get("use_attention_mask"): - attention_mask = text_inputs.attention_mask - prompt_embeds = pipe.text_encoder("encode_tokens_attn_mask", [text_input_ids, attention_mask]) - else: - attention_mask = None - prompt_embeds = pipe.text_encoder("encode_tokens", [text_input_ids]) - prompt_embeds = prompt_embeds[0] - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - # textual inversion: process multi-vector tokens if necessary - # if pipe.use_textual_inversion: - # uncond_tokens = pipe.maybe_convert_prompt(uncond_tokens, pipe.tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = pipe.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + # if pipe.use_textual_inversion: + # uncond_tokens = pipe.maybe_convert_prompt(uncond_tokens, pipe.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = pipe.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if pipe.text_encoder.metadata.get("use_attention_mask"): + attention_mask = uncond_input.attention_mask + negative_prompt_embeds = pipe.text_encoder( + "encode_tokens_attn_mask", + [ + uncond_input.input_ids, + attention_mask, + ], + ) + else: + attention_mask = None + negative_prompt_embeds = pipe.text_encoder( + "encode_tokens", + [ + uncond_input.input_ids, + ], + ) + + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + # if pipe.use_lora: + # Retrieve the original scale by scaling back the LoRA layers + # unimplemented + # unscale_lora_layers(pipe.text_encoder, lora_scale) - if pipe.text_encoder.metadata.get("use_attention_mask"): - attention_mask = uncond_input.attention_mask - negative_prompt_embeds = pipe.text_encoder( - "encode_tokens_attn_mask", - [ - uncond_input.input_ids, - attention_mask, - ], - ) - else: - attention_mask = None - negative_prompt_embeds = pipe.text_encoder( - "encode_tokens", - [ - uncond_input.input_ids, - ], - ) - - - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - #if pipe.use_lora: - # Retrieve the original scale by scaling back the LoRA layers - # unimplemented - # unscale_lora_layers(pipe.text_encoder, lora_scale) - - return prompt_embeds, negative_prompt_embeds + return prompt_embeds, negative_prompt_embeds diff --git a/models/turbine_models/custom_models/sd_inference/vae.py b/models/turbine_models/custom_models/sd_inference/vae.py index f76b2710c..9645ebfdc 100644 --- a/models/turbine_models/custom_models/sd_inference/vae.py +++ b/models/turbine_models/custom_models/sd_inference/vae.py @@ -63,7 +63,8 @@ def decode(self, inp): def encode(self, inp): latents = self.vae.encode(inp).latent_dist.sample() return self.vae.config.scaling_factor * latents - + + class SD3VaeModel(torch.nn.Module): def __init__( self, @@ -92,6 +93,7 @@ def encode(self, inp): latent = self.vae.encode(image_torch) return latent + def export_vae_model( hf_model_name, batch_size, @@ -135,7 +137,7 @@ def export_vae_model( attn_spec=attn_spec, ) return vmfb_path - + if "stable-diffusion-3" in hf_model_name: vae_model = SD3VaeModel(hf_model_name) else: @@ -198,7 +200,7 @@ class CompiledVae(CompiledModule): inst = CompiledVae(context=Context(), import_to="IMPORT") module = CompiledModule.get_mlir_module(inst) - + model_metadata_decode = { "model_name": "vae_decode", "input_shapes": [input_latents_shape], diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py index 2b2ea00ba..a1de8e031 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py @@ -269,7 +269,7 @@ def encode_prompts_turbo( model_metadata_encode = { "model_name": hf_model_name + "_text_encoder", "input_shapes": [str((1, max_length)) for i in range(4)], - "input_dtypes": ['int64' for i in range(4)], + "input_dtypes": ["int64" for i in range(4)], "use_attention_mask": False, } module = AddMetadataPass(module, model_metadata_encode, "encode_prompts").run() diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index cdd548809..dc64f58dc 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -24,6 +24,7 @@ class UnetModel(torch.nn.Module): def __init__(self, hf_model_name, hf_auth_token=None, precision="fp32"): from diffusers import UNet2DConditionModel + super().__init__() if precision == "fp16": try: @@ -70,8 +71,12 @@ def forward( )[0] return noise_pred + def get_punet_model(hf_model_name, external_weight_path, precision="i8"): - from sharktank.models.punet.model import Unet2DConditionModel as punet_unet, ClassifierFreeGuidanceUnetModel as CFGPunetModel + from sharktank.models.punet.model import ( + Unet2DConditionModel as punet_unet, + ClassifierFreeGuidanceUnetModel as CFGPunetModel, + ) if precision == "i8": repo_id = "amd-shark/sdxl-quant-models" @@ -81,10 +86,12 @@ def get_punet_model(hf_model_name, external_weight_path, precision="i8"): repo_id = hf_model_name subfolder = "unet" revision = "76d28af79639c28a79fa5c6c6468febd3490a37e" + def download(filename): return hf_hub_download( repo_id=repo_id, subfolder=subfolder, filename=filename, revision=revision ) + results = { "config.json": download("config.json"), "params.safetensors": download("params.safetensors"), @@ -92,15 +99,29 @@ def download(filename): if precision == "i8": results["quant_params.json"] = download("quant_params.json") output_path = external_weight_path.split("unet")[0] + "punet_dataset_i8.irpa" - ds = get_punet_i8_dataset(results["config.json"], results["quant_params.json"], results["params.safetensors"], output_path, base_params=None) + ds = get_punet_i8_dataset( + results["config.json"], + results["quant_params.json"], + results["params.safetensors"], + output_path, + base_params=None, + ) else: - ds = None # get_punet_dataset(results["config.json"], results["params.safetensors"], base_params=None) - + ds = None # get_punet_dataset(results["config.json"], results["params.safetensors"], base_params=None) + cond_unet = punet_unet.from_dataset(ds) mdl = CFGPunetModel(cond_unet) return mdl -def get_punet_i8_dataset(config_json_path, quant_params_path, params_path, output_path="./punet_dataset_i8.irpa", quant_params_struct=None, base_params=None): + +def get_punet_i8_dataset( + config_json_path, + quant_params_path, + params_path, + output_path="./punet_dataset_i8.irpa", + quant_params_struct=None, + base_params=None, +): from sharktank.models.punet.tools.import_brevitas_dataset import ( _load_json, _load_theta, @@ -110,6 +131,7 @@ def get_punet_i8_dataset(config_json_path, quant_params_path, params_path, outpu Theta, InferenceTensor, ) + # Construct the pre-transform dataset. dataset_props = _get_dataset_props(_load_json(config_json_path)) quant_params_struct = _load_json(quant_params_path) @@ -118,9 +140,7 @@ def get_punet_i8_dataset(config_json_path, quant_params_path, params_path, outpu base_theta = None if base_params is not None: print("Initializing from base parameters:", args.base_params) - with safetensors.safe_open( - base_params, framework="pt", device="cpu" - ) as st: + with safetensors.safe_open(base_params, framework="pt", device="cpu") as st: base_theta = _load_theta(st) ds = Dataset(dataset_props, quant_theta if base_theta is None else base_theta) @@ -144,6 +164,7 @@ def get_punet_i8_dataset(config_json_path, quant_params_path, params_path, outpu ds.save(output_path, io_report_callback=print) return ds + @torch.no_grad() def export_unet_model( hf_model_name, @@ -214,7 +235,7 @@ def export_unet_model( if use_punet: dtype = torch.float16 - + utils.save_external_weights( mapper, unet_model, external_weights, external_weight_path ) @@ -287,7 +308,7 @@ def _forward( inputs, ): return module.forward(*inputs) - + class CompiledUnet(CompiledModule): run_forward = _forward diff --git a/models/turbine_models/tests/pipeline_test.py b/models/turbine_models/tests/pipeline_test.py index 4290e1446..658402652 100644 --- a/models/turbine_models/tests/pipeline_test.py +++ b/models/turbine_models/tests/pipeline_test.py @@ -43,6 +43,8 @@ def forward(self, x): torch.no_grad() + + def export_dummy_model(): model = TestModule() target = "x86_64-unknown-linux-gnu" @@ -132,5 +134,6 @@ def test_pipeline_metadata(self): expected, metadata ) + if __name__ == "__main__": unittest.main() From 0147cc90c5d1e468fb022822ade99ea19d834c60 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Mon, 8 Jul 2024 15:09:44 -0500 Subject: [PATCH 04/12] Partitioned Unet I8 support. --- .../custom_models/pipeline_base.py | 19 +++--- .../custom_models/sd_inference/sd_cmd_opts.py | 9 ++- .../custom_models/sd_inference/sd_pipeline.py | 65 ++++++++++++++----- .../custom_models/sd_inference/utils.py | 9 +-- .../custom_models/sdxl_inference/unet.py | 8 ++- 5 files changed, 75 insertions(+), 35 deletions(-) diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py index f224c4854..45cfa8edc 100644 --- a/models/turbine_models/custom_models/pipeline_base.py +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -185,9 +185,9 @@ def _validate_or_convert_inputs(self, function_name, inputs): logging.warning( f"No metadata found for {self.module_name}['{function_name}']." ) - for i in inputs: + for idx, i in enumerate(inputs): if not isinstance(i, ireert.DeviceArray): - i = ireert.asdevicearray(self.device, i) + inputs[idx] = ireert.asdevicearray(self.device, i) def _output_cast(self, output): if isinstance(output, tuple): @@ -233,9 +233,7 @@ def __call__(self, function_name, inputs: list): output = self._run_and_benchmark(function_name, inputs) else: output = self._run(function_name, inputs) - print("Output before cast: ", output) output = self._output_cast(output) - print("Output after cast: ", output) return output @@ -441,24 +439,24 @@ def is_prepared(self, vmfbs, weights): mlir_keywords.extend( [ "mlir", - self.map[key]["precision"], ] ) keywords.extend( [ "vmfb", self.map[key]["target"], - self.map[key]["precision"], ] ) - print(keywords) neg_keywords = [] for kw in keywords: if kw.startswith("!"): neg_keywords.append(kw.strip("!")) keywords.remove(kw) + mlir_keywords.remove(kw) avail_files = os.listdir(pipeline_dir) candidates = [] + # print("MLIR KEYS: ", mlir_keywords) + # print("AVAILABLE FILES: ", avail_files) for filename in avail_files: if all(str(x) in filename for x in keywords) and not any( x in filename for x in neg_keywords @@ -575,7 +573,9 @@ def export_submodel( input_mlir = None else: input_mlir = None - self.map[submodel]["export_args"]["input_mlir"] = input_mlir + self.map[submodel]["export_args"]["input_mlir"] = self.map[submodel].get( + "mlir", input_mlir + ) match submodel: case "unetloop": # SDXL ONLY FOR NOW @@ -656,10 +656,9 @@ def export_submodel( self.map[submodel]["weights"] = None case _: export_args = self.map[submodel].get("export_args", {}) - if self.map[submodel].get("input_mlir"): - export_args["input_mlir"] = self.map[submodel].get("mlir") if weights_only: export_args["weights_only"] = True + export_args["input_mlir"] = None if export_args: exported = self.map[submodel]["export_fn"](**export_args) else: diff --git a/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py index 12a4e3d3b..8c68ad06c 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py +++ b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py @@ -176,11 +176,16 @@ def is_valid_file(arg): p.add_argument( "--vae_decomp_attn", - type=bool, - default=False, + action="store_true", help="Decompose attention for VAE decode only at fx graph level", ) +p.add_argument( + "--unet_decomp_attn", + action="store_true", + help="Decompose attention for unet only at fx graph level", +) + p.add_argument( "--use_i8_punet", action="store_true", diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 9c977208d..eaf5cb15a 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -253,6 +253,10 @@ def __init__( if "height" in sd_model_map[submodel]["export_args"]: sd_model_map[submodel]["export_args"]["height"] = height sd_model_map[submodel]["export_args"]["width"] = width + if "decomp_attn" in sd_model_map[submodel]["export_args"]: + sd_model_map[submodel]["export_args"]["decomp_attn"] = decomp_attn[ + submodel + ] super().__init__( sd_model_map, device, @@ -326,11 +330,19 @@ def __init__( if self.use_i8_punet: self.map["unet"]["export_args"]["precision"] = "i8" self.map["unet"]["export_args"]["use_punet"] = True - for i in self.map["unet"]["keywords"]: - i = i.replace("fp16", "i8").replace("fp32", "i8") self.map["unet"]["keywords"].append("punet") + self.map["unet"]["module_name"] = "module" + self.map["unet"]["function_name"] = "main" + self.map["unet"]["export_args"]["external_weight_path"] = ( + utils.create_safe_name(self.base_model_name) + "_punet_dataset_i8.irpa" + ) + for idx, word in enumerate(self.map["unet"]["keywords"]): + if word in ["fp32", "fp16"]: + self.map["unet"]["keywords"][idx] = "i8" + break else: self.map["unet"]["keywords"].append("!punet") + self.map["unet"]["function_name"] = "run_forward" # LOAD @@ -469,7 +481,7 @@ def _produce_latents_sd( timestep, ] unet_inputs.extend([text_embeddings, guidance_scale]) - latents = self.unet("run_forward", unet_inputs) + latents = self.unet(self.map["unet"]["function_name"], unet_inputs) sample = self.scheduler.step( torch.tensor(latents.to_host(), dtype=self.torch_dtype), t, sample ).prev_sample @@ -488,13 +500,6 @@ def _produce_latents_sdxl( latents, add_time_ids, step_indexes, timesteps = self.prepare_latents( sample, self.num_inference_steps, image, strength ) - iree_unet_inputs = [ - prompt_embeds, - add_text_embeds, - add_time_ids, - ] - if self.use_punet: - iree_unet_inputs.append(guidance_scale) for i, t in tqdm(enumerate(timesteps)): if self.cpu_scheduling: step_index = i @@ -504,13 +509,32 @@ def _produce_latents_sdxl( latents, t, ) + unet_inputs = [ + latent_model_input, + t, + prompt_embeds, + add_text_embeds, + add_time_ids, + ] + if self.use_punet: + unet_inputs.append( + ireert.asdevicearray( + self.unet.device, + [guidance_scale], + dtype=self.map["unet"]["np_dtype"], + ) + ) + unet_inputs[1] = ireert.asdevicearray( + self.unet.device, t, dtype="int32" + ) + for inp_idx, inp in enumerate(unet_inputs): + if not isinstance(inp, ireert.DeviceArray): + unet_inputs[inp_idx] = ireert.asdevicearray( + self.unet.device, inp, dtype=self.map["unet"]["np_dtype"] + ) noise_pred = self.unet( - "run_forward", - [ - latent_model_input, - t, - *iree_unet_inputs, - ], + self.map["unet"]["function_name"], + unet_inputs, ) latents = self.scheduler.step( noise_pred, @@ -623,7 +647,14 @@ def numpy_to_pil_image(images): } if not args.pipeline_dir: args.pipeline_dir = utils.create_safe_name(args.hf_model_name, "") - + if any(x for x in [args.vae_decomp_attn, args.unet_decomp_attn]): + args.decomp_attn = { + "text_encoder": args.decomp_attn, + "unet": ( + args.unet_decomp_attn if args.unet_decomp_attn else args.decomp_attn + ), + "vae": args.vae_decomp_attn if args.vae_decomp_attn else args.decomp_attn, + } sd_pipe = SharkSDPipeline( args.hf_model_name, args.height, diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 7dd045493..4673f4f24 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -229,7 +229,7 @@ def compile_to_vmfb( # This 'attn_spec' handles a linalg_ext.attention op lowering to mfma instructions for capable targets. # This is a temporary solution, and should be removed or largely disabled once the functionality of # the TD spec is implemented in C++. - if attn_spec in ["default", "mfma"]: + if attn_spec in ["default", "mfma", "i8"]: attn_spec = get_mfma_spec_path( target_triple, os.path.dirname(safe_name), masked_attention ) @@ -300,7 +300,7 @@ def compile_to_vmfb( return safe_vmfb_name + ".vmfb" -def create_safe_name(hf_model_name, model_name_str): +def create_safe_name(hf_model_name, model_name_str=""): safe_name = hf_model_name.split("/")[-1].strip() + model_name_str safe_name = re.sub("-", "_", safe_name) safe_name = re.sub("\.", "_", safe_name) @@ -309,7 +309,7 @@ def create_safe_name(hf_model_name, model_name_str): def get_mfma_spec_path(target_chip, save_dir, masked_attention=False): if not masked_attention: - url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/no_pad/attention_and_matmul_spec_mfma.mlir" + url = "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/specs/attention_and_matmul_spec.mlir" else: url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx942.mlir" attn_spec = urlopen(url).read().decode("utf-8") @@ -331,7 +331,8 @@ def get_wmma_spec_path(target_chip, save_dir, masked_attention=False): else: return None attn_spec = urlopen(url).read().decode("utf-8") - spec_path = os.path.join(save_dir, "attention_and_matmul_spec_wmma.mlir") + suffix = "masked" if masked_attention else "" + spec_path = os.path.join(save_dir, f"attention_and_matmul_spec_wmma{suffix}.mlir") with open(spec_path, "w") as f: f.write(attn_spec) return spec_path diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index dc64f58dc..d8b6c41d2 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -189,11 +189,11 @@ def export_unet_model( use_punet=False, ): if use_punet: - unet_model = get_punet_model(hf_model_name, external_weight_path, precision) submodel_name = "punet" else: - unet_model = UnetModel(hf_model_name, hf_auth_token, precision) submodel_name = "unet" + if (not decomp_attn) and use_punet: + attn_spec = "i8" safe_name = utils.create_safe_name( hf_model_name, f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_{submodel_name}", @@ -216,6 +216,10 @@ def export_unet_model( attn_spec=attn_spec, ) return vmfb_path + elif use_punet: + unet_model = get_punet_model(hf_model_name, external_weight_path, precision) + else: + unet_model = UnetModel(hf_model_name, hf_auth_token, precision) mapper = {} np_dtypes = { From 4bc57c22b27eb8b3b80b71ad2f1a851b56127d30 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Mon, 8 Jul 2024 15:55:37 -0500 Subject: [PATCH 05/12] Fixes for non-punet attn spec --- .../custom_models/sd_inference/utils.py | 20 ++++++++++++------- .../custom_models/sdxl_inference/unet.py | 6 +++++- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 4673f4f24..9b4e6159b 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -229,11 +229,14 @@ def compile_to_vmfb( # This 'attn_spec' handles a linalg_ext.attention op lowering to mfma instructions for capable targets. # This is a temporary solution, and should be removed or largely disabled once the functionality of # the TD spec is implemented in C++. - if attn_spec in ["default", "mfma", "i8"]: + + if attn_spec in ["default", "mfma", "punet"]: + use_punet = True if attn_spec in ["punet", "i8"] else False attn_spec = get_mfma_spec_path( - target_triple, os.path.dirname(safe_name), masked_attention + target_triple, os.path.dirname(safe_name), masked_attention, use_punet=use_punet ) flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) + elif attn_spec in ["wmma"] or ("gfx11" in target_triple and not attn_spec): attn_spec = get_wmma_spec_path( target_triple, os.path.dirname(safe_name), masked_attention @@ -307,15 +310,18 @@ def create_safe_name(hf_model_name, model_name_str=""): return safe_name -def get_mfma_spec_path(target_chip, save_dir, masked_attention=False): - if not masked_attention: +def get_mfma_spec_path(target_chip, save_dir, masked_attention=False, use_punet=False): + if use_punet: + suffix = "_punet" url = "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/specs/attention_and_matmul_spec.mlir" + elif not masked_attention: + suffix = "" + url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/no_pad/attention_and_matmul_spec_mfma.mlir" else: + suffix = "_pad" url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx942.mlir" attn_spec = urlopen(url).read().decode("utf-8") - spec_path = os.path.join(save_dir, "attention_and_matmul_spec_mfma.mlir") - if os.path.exists(spec_path): - return spec_path + spec_path = os.path.join(save_dir, f"attention_and_matmul_spec_mfma{suffix}.mlir") with open(spec_path, "w") as f: f.write(attn_spec) return spec_path diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index d8b6c41d2..39429e36a 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -193,7 +193,11 @@ def export_unet_model( else: submodel_name = "unet" if (not decomp_attn) and use_punet: - attn_spec = "i8" + attn_spec = "punet" + elif (not decomp_attn) and "gfx9" in target: + attn_spec = "mfma" + elif (not decomp_attn) and "gfx11" in target: + attn_spec = "wmma" safe_name = utils.create_safe_name( hf_model_name, f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_{submodel_name}", From 34d3d84facbde7c2113c62cf75ea3fb06b1e7169 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 9 Jul 2024 15:07:56 -0500 Subject: [PATCH 06/12] Fixes to punet, device IDs --- models/requirements.txt | 4 +- .../custom_models/pipeline_base.py | 5 +- .../custom_models/sd_inference/sd_pipeline.py | 6 +- .../custom_models/sd_inference/utils.py | 7 +- .../sdxl_inference/sdxl_prompt_encoder.py | 1 + .../custom_models/sdxl_inference/unet.py | 90 +++++++------------ 6 files changed, 48 insertions(+), 65 deletions(-) diff --git a/models/requirements.txt b/models/requirements.txt index bdd1892e8..b7b7d8d2b 100644 --- a/models/requirements.txt +++ b/models/requirements.txt @@ -1,5 +1,5 @@ protobuf -shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main +gguf transformers==4.37.1 torchsde accelerate @@ -12,3 +12,5 @@ azure-storage-blob einops pytest scipy +shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main +sharktank @ git+https://github.com/nod-ai/sharktank@main#subdirectory=sharktank diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py index 45cfa8edc..18ebd3cc6 100644 --- a/models/turbine_models/custom_models/pipeline_base.py +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -309,7 +309,7 @@ def __init__( assert ( submodel in target.keys() ), f"Target arch for {submodel} not found." - self.map[submodel]["device"] = device[submodel] + self.map[submodel]["device"] = device[submodel].split("://")[0] self.map[submodel]["driver"] = utils.iree_device_map(device[submodel]) self.map[submodel]["target"] = target[submodel] else: @@ -317,7 +317,7 @@ def __init__( target, str ), "Device and target triple must be both dicts or both strings." for submodel in self.map.keys(): - self.map[submodel]["device"] = device + self.map[submodel]["device"] = device.split("://")[0] self.map[submodel]["driver"] = utils.iree_device_map(device) self.map[submodel]["target"] = target map_arguments = { @@ -328,7 +328,6 @@ def __init__( "external_weights": external_weights, "hf_model_name": hf_model_name, } - print(map_arguments) for arg in map_arguments.keys(): self.map = merge_arg_into_map(self.map, map_arguments[arg], arg) diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index eaf5cb15a..3b2449453 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -331,7 +331,7 @@ def __init__( self.map["unet"]["export_args"]["precision"] = "i8" self.map["unet"]["export_args"]["use_punet"] = True self.map["unet"]["keywords"].append("punet") - self.map["unet"]["module_name"] = "module" + self.map["unet"]["module_name"] = "compiled_punet" self.map["unet"]["function_name"] = "main" self.map["unet"]["export_args"]["external_weight_path"] = ( utils.create_safe_name(self.base_model_name) + "_punet_dataset_i8.irpa" @@ -525,7 +525,9 @@ def _produce_latents_sdxl( ) ) unet_inputs[1] = ireert.asdevicearray( - self.unet.device, t, dtype="int32" + self.unet.device, + t, + dtype=self.map["unet"]["np_dtype"], ) for inp_idx, inp in enumerate(unet_inputs): if not isinstance(inp, ireert.DeviceArray): diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 9b4e6159b..800a66c31 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -233,10 +233,13 @@ def compile_to_vmfb( if attn_spec in ["default", "mfma", "punet"]: use_punet = True if attn_spec in ["punet", "i8"] else False attn_spec = get_mfma_spec_path( - target_triple, os.path.dirname(safe_name), masked_attention, use_punet=use_punet + target_triple, + os.path.dirname(safe_name), + masked_attention, + use_punet=use_punet, ) flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) - + elif attn_spec in ["wmma"] or ("gfx11" in target_triple and not attn_spec): attn_spec = get_wmma_spec_path( target_triple, os.path.dirname(safe_name), masked_attention diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py index a1de8e031..e55eada39 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py @@ -171,6 +171,7 @@ def export_prompt_encoder( attn_spec=None, weights_only=False, batch_input=False, + decomp_attn=False, # Compatibility ): do_classifier_free_guidance = True diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index 39429e36a..7ae48836c 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -74,9 +74,10 @@ def forward( def get_punet_model(hf_model_name, external_weight_path, precision="i8"): from sharktank.models.punet.model import ( - Unet2DConditionModel as punet_unet, - ClassifierFreeGuidanceUnetModel as CFGPunetModel, + Unet2DConditionModel as sharktank_unet2d, + ClassifierFreeGuidanceUnetModel as sharktank_CFGPunetModel, ) + from sharktank.utils import cli if precision == "i8": repo_id = "amd-shark/sdxl-quant-models" @@ -99,70 +100,45 @@ def download(filename): if precision == "i8": results["quant_params.json"] = download("quant_params.json") output_path = external_weight_path.split("unet")[0] + "punet_dataset_i8.irpa" - ds = get_punet_i8_dataset( + ds = get_punet_dataset( results["config.json"], - results["quant_params.json"], results["params.safetensors"], output_path, + results["quant_params.json"], base_params=None, ) else: - ds = None # get_punet_dataset(results["config.json"], results["params.safetensors"], base_params=None) + ds = get_punet_dataset( + results["config.json"], + results["params.safetensors"], + output_path, + base_params=None, + ) - cond_unet = punet_unet.from_dataset(ds) - mdl = CFGPunetModel(cond_unet) + cond_unet = sharktank_unet2d.from_dataset(ds) + mdl = sharktank_CFGPunetModel(cond_unet) return mdl -def get_punet_i8_dataset( +def get_punet_dataset( config_json_path, - quant_params_path, params_path, output_path="./punet_dataset_i8.irpa", + quant_params_path=None, quant_params_struct=None, base_params=None, ): - from sharktank.models.punet.tools.import_brevitas_dataset import ( - _load_json, - _load_theta, - _get_dataset_props, - apply_per_layer_quant, - Dataset, - Theta, - InferenceTensor, + from sharktank.models.punet.tools import import_brevitas_dataset + + import_brevitas_dataset.main( + [ + f"--config-json={config_json_path}", + f"--params={params_path}", + f"--quant-params={quant_params_path}", + f"--output-irpa-file={output_path}", + ] ) - - # Construct the pre-transform dataset. - dataset_props = _get_dataset_props(_load_json(config_json_path)) - quant_params_struct = _load_json(quant_params_path) - with safetensors.safe_open(params_path, framework="pt", device="cpu") as st: - quant_theta = _load_theta(st) - base_theta = None - if base_params is not None: - print("Initializing from base parameters:", args.base_params) - with safetensors.safe_open(base_params, framework="pt", device="cpu") as st: - base_theta = _load_theta(st) - - ds = Dataset(dataset_props, quant_theta if base_theta is None else base_theta) - - # The quant_params_struct has quantization parameter structs keyed by full - # layer name. We process each of these in turn to produce a per-layer - # quantization scheme where no quantized tensors escape their layer. - updated_tensors: dict[str, InferenceTensor] = {} - for layer_name, qp in quant_params_struct.items(): - print(f"Applying per-layer quants: {layer_name}") - apply_per_layer_quant(quant_theta, layer_name, qp, updated_tensors) - - # Apply updates into a new Theta. - theta = base_theta if base_theta is not None else quant_theta - flat_tensors = theta.flatten() - flat_tensors.update(updated_tensors) - ds.root_theta = Theta(flat_tensors) - - # TODO: Post-process to introduce fused cross-layer connections. - - ds.save(output_path, io_report_callback=print) - return ds + return import_brevitas_dataset.Dataset.load(output_path) @torch.no_grad() @@ -238,15 +214,17 @@ def export_unet_model( } dtype = torch_dtypes[precision] np_dtype = np_dtypes[precision] - if precision == "fp16": + + if precision == "fp16" and not use_punet: unet_model = unet_model.half() if use_punet: dtype = torch.float16 - utils.save_external_weights( - mapper, unet_model, external_weights, external_weight_path - ) + if not use_punet: + utils.save_external_weights( + mapper, unet_model, external_weights, external_weight_path + ) if weights_only: return external_weight_path @@ -296,13 +274,11 @@ def export_unet_model( from_current=True, add_ops=decomp_list, ): - if external_weights: - externalize_module_parameters(unet_model) if use_punet: output = export( unet_model, kwargs=example_forward_args_dict, - module_name="compiled_unet", + module_name="compiled_punet", ) module = output.mlir_module else: @@ -398,7 +374,7 @@ class CompiledUnet(CompiledModule): exit() safe_name = utils.create_safe_name( args.hf_model_name, - f"_bs{args.batch_size}_{args.max_length}_{args.height}x{args.width}_{args.precision}_unet", + f"_bs{args.batch_size}_{args.max_length}_{args.height}x{args.width}_{args.precision}_{'p' if args.use_i8_punet else ''}unet", ) if args.compile_to != "vmfb": with open(f"{safe_name}.mlir", "w+") as f: From 86f3a80f51eb0ca4f61c34297bffa3a292817244 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 9 Jul 2024 15:15:50 -0500 Subject: [PATCH 07/12] Add TODO to encode comment in VAE export script. --- models/turbine_models/custom_models/sd_inference/vae.py | 1 + 1 file changed, 1 insertion(+) diff --git a/models/turbine_models/custom_models/sd_inference/vae.py b/models/turbine_models/custom_models/sd_inference/vae.py index 9645ebfdc..6fdbfd958 100644 --- a/models/turbine_models/custom_models/sd_inference/vae.py +++ b/models/turbine_models/custom_models/sd_inference/vae.py @@ -183,6 +183,7 @@ def export_vae_model( ): fxb = FxProgramsBuilder(vae_model) + # TODO: fix issues with exporting the encode function. # @fxb.export_program(args=(encode_args,)) # def _encode(module, inputs,): # return module.encode(*inputs) From 4c01167965d2c53112c095858b5fd224b1370341 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 9 Jul 2024 22:45:40 -0500 Subject: [PATCH 08/12] Test fixes, reduce unet model signature variance --- .../custom_models/pipeline_base.py | 6 +- .../custom_models/sd_inference/clip_runner.py | 2 +- .../custom_models/sd_inference/schedulers.py | 17 +- .../custom_models/sd_inference/sd_pipeline.py | 33 +- .../custom_models/sd_inference/unet.py | 18 +- .../custom_models/sd_inference/unet_runner.py | 51 ++- .../custom_models/sd_inference/vae.py | 1 + .../custom_models/sd_inference/vae_runner.py | 12 +- .../sdxl_inference/sdxl_prompt_encoder.py | 6 +- .../custom_models/sdxl_inference/unet.py | 37 ++- .../sdxl_inference/unet_runner.py | 4 +- .../sdxl_inference/vae_runner.py | 6 +- models/turbine_models/tests/conftest.py | 4 +- models/turbine_models/tests/sd3_test.py | 229 +++----------- models/turbine_models/tests/sd_test.py | 292 +++++------------- models/turbine_models/tests/sdxl_test.py | 268 +++++++--------- 16 files changed, 330 insertions(+), 656 deletions(-) diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py index 18ebd3cc6..1da3aacfb 100644 --- a/models/turbine_models/custom_models/pipeline_base.py +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -662,9 +662,9 @@ def export_submodel( exported = self.map[submodel]["export_fn"](**export_args) else: exported = self.map[submodel]["export_fn"]() - if not self.map[submodel].get("weights") and os.path.exists( - self.map[submodel]["export_args"].get("external_weight_path") - ): + if not self.map[submodel].get("weights") and self.map[submodel][ + "export_args" + ].get("external_weights", None): self.map[submodel]["weights"] = self.map[submodel][ "export_args" ].get("external_weight_path", None) diff --git a/models/turbine_models/custom_models/sd_inference/clip_runner.py b/models/turbine_models/custom_models/sd_inference/clip_runner.py index fe5310ff6..da0908fad 100644 --- a/models/turbine_models/custom_models/sd_inference/clip_runner.py +++ b/models/turbine_models/custom_models/sd_inference/clip_runner.py @@ -55,7 +55,7 @@ def run_clip( if "google/t5" in hf_model_name: inp += [ireert.asdevicearray(runner.config.device, example_input)] - results = runner.ctx.modules.compiled_clip["main"](*inp) + results = runner.ctx.modules.compiled_text_encoder["encode_tokens"](*inp) return results diff --git a/models/turbine_models/custom_models/sd_inference/schedulers.py b/models/turbine_models/custom_models/sd_inference/schedulers.py index cba6cfdf6..1a8cd8858 100644 --- a/models/turbine_models/custom_models/sd_inference/schedulers.py +++ b/models/turbine_models/custom_models/sd_inference/schedulers.py @@ -139,11 +139,12 @@ def __init__( latents_dtype, conditional_timesteps=False, ): - self.do_classifier_free_guidance = True self.module = scheduler self.dest = dest_device self.batch_size = batch_size self.timesteps = None + self.do_guidance = True + self.repeat_sample = True # Enable this on init for models that use a pair of timestep values per unet step. # this includes sd3 and some others we don't support yet. @@ -152,7 +153,6 @@ def __init__( self.conditional_timesteps = conditional_timesteps self.dtype = latents_dtype - self.use_punet = False self.torch_dtype = ( torch.float32 if latents_dtype == "float32" else torch.float16 ) @@ -170,7 +170,7 @@ def initialize_sdxl(self, sample, num_inference_steps): crops_coords_top_left = (0, 0) add_time_ids = list(original_size + crops_coords_top_left + target_size) add_time_ids = torch.tensor([add_time_ids], dtype=self.torch_dtype) - if self.do_classifier_free_guidance: + if self.do_guidance: add_time_ids = torch.cat([add_time_ids] * 2, dim=0) add_time_ids = add_time_ids.repeat(self.batch_size, 1).type( self.torch_dtype @@ -189,7 +189,7 @@ def initialize_sd(self, sample, num_inference_steps): return sample, timesteps def scale_model_input(self, sample, t, t_uncond=None): - if self.do_classifier_free_guidance and not self.use_punet: + if self.repeat_sample: sample = torch.cat([sample] * 2) if self.conditional_timesteps: if t_uncond: @@ -201,14 +201,16 @@ def scale_model_input(self, sample, t, t_uncond=None): scaled = self.module.scale_model_input(sample, t) return scaled, t - def step(self, noise_pred, t, latents, guidance_scale): + def step(self, noise_pred, t, latents, guidance_scale=None): if isinstance(t, ireert.DeviceArray): t = torch.tensor(t.to_host()) if isinstance(noise_pred, ireert.DeviceArray): noise_pred = torch.tensor(noise_pred.to_host()) + elif isinstance(noise_pred, np.ndarray): + noise_pred = torch.tensor(noise_pred) if isinstance(guidance_scale, ireert.DeviceArray): guidance_scale = torch.tensor(guidance_scale.to_host()) - if self.do_classifier_free_guidance and not self.use_punet: + if self.do_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_text - noise_pred_uncond @@ -217,8 +219,7 @@ def step(self, noise_pred, t, latents, guidance_scale): noise_pred, t, latents, - return_dict=False, - )[0] + ).prev_sample @torch.no_grad() diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 3b2449453..868872479 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -300,9 +300,6 @@ def __init__( self.scheduler = None self.split_scheduler = True - if self.split_scheduler: - self.map.pop("unetloop") - self.map.pop("fullpipeline") self.base_model_name = ( hf_model_name @@ -313,6 +310,9 @@ def __init__( self.is_sdxl = "xl" in self.base_model_name self.is_sd3 = "stable-diffusion-3" in self.base_model_name if self.is_sdxl: + if self.split_scheduler: + self.map.pop("unetloop") + self.map.pop("fullpipeline") self.tokenizers = [ CLIPTokenizer.from_pretrained( self.base_model_name, subfolder="tokenizer" @@ -472,7 +472,7 @@ def _produce_latents_sd( sample, self.num_inference_steps, image, strength ) text_embeddings = torch.cat((negative_prompt_embeds, prompt_embeds), dim=0) - + self.scheduler.do_guidance = False for i, t in tqdm(enumerate(timesteps)): latent_model_input, _ = self.scheduler.scale_model_input(sample, t) timestep = torch.tensor([t]) @@ -480,11 +480,15 @@ def _produce_latents_sd( latent_model_input, timestep, ] - unet_inputs.extend([text_embeddings, guidance_scale]) + unet_inputs.extend([text_embeddings, [guidance_scale]]) latents = self.unet(self.map["unet"]["function_name"], unet_inputs) sample = self.scheduler.step( - torch.tensor(latents.to_host(), dtype=self.torch_dtype), t, sample - ).prev_sample + torch.tensor( + latents, dtype=torch_dtypes[self.map["unet"]["precision"]] + ), + t, + sample, + ) return sample def _produce_latents_sdxl( @@ -500,6 +504,8 @@ def _produce_latents_sdxl( latents, add_time_ids, step_indexes, timesteps = self.prepare_latents( sample, self.num_inference_steps, image, strength ) + self.scheduler.do_guidance = False + self.scheduler.repeat_sample = False for i, t in tqdm(enumerate(timesteps)): if self.cpu_scheduling: step_index = i @@ -515,15 +521,13 @@ def _produce_latents_sdxl( prompt_embeds, add_text_embeds, add_time_ids, + ireert.asdevicearray( + self.unet.device, + [guidance_scale], + dtype=self.map["unet"]["np_dtype"], + ), ] if self.use_punet: - unet_inputs.append( - ireert.asdevicearray( - self.unet.device, - [guidance_scale], - dtype=self.map["unet"]["np_dtype"], - ) - ) unet_inputs[1] = ireert.asdevicearray( self.unet.device, t, @@ -542,7 +546,6 @@ def _produce_latents_sdxl( noise_pred, t, latents, - guidance_scale, ) return latents diff --git a/models/turbine_models/custom_models/sd_inference/unet.py b/models/turbine_models/custom_models/sd_inference/unet.py index 4b37178d4..f111fb765 100644 --- a/models/turbine_models/custom_models/sd_inference/unet.py +++ b/models/turbine_models/custom_models/sd_inference/unet.py @@ -29,16 +29,24 @@ class UnetModel(torch.nn.Module): def __init__(self, hf_model_name): super().__init__() + self.do_classifier_free_guidance = True self.unet = UNet2DConditionModel.from_pretrained( hf_model_name, subfolder="unet", ) - def forward(self, latent_model_input, timestep, encoder_hidden_states): - unet_out = self.unet.forward( + def forward( + self, latent_model_input, timestep, encoder_hidden_states, guidance_scale + ): + noise_pred = self.unet.forward( latent_model_input, timestep, encoder_hidden_states, return_dict=False )[0] - return unet_out + if self.do_classifier_free_guidance: + noise_preds = noise_pred.chunk(2) + noise_pred = noise_preds[0] + guidance_scale * ( + noise_preds[1] - noise_preds[0] + ) + return noise_pred def export_unet_model( @@ -119,6 +127,7 @@ def export_unet_model( torch.empty(sample, dtype=dtype), torch.empty(1, dtype=dtype), torch.empty(encoder_hidden_states_sizes, dtype=dtype), + torch.empty(1, dtype=dtype), ] decomp_list = [] if decomp_attn == True: @@ -158,8 +167,9 @@ class CompiledUnet(CompiledModule): sample, (1,), encoder_hidden_states_sizes, + (1,), ], - "input_dtypes": [np_dtype for x in range(3)], + "input_dtypes": [np_dtype for x in range(4)], "output_shapes": [sample], "output_dtypes": [np_dtype], } diff --git a/models/turbine_models/custom_models/sd_inference/unet_runner.py b/models/turbine_models/custom_models/sd_inference/unet_runner.py index 172229e77..12e420960 100644 --- a/models/turbine_models/custom_models/sd_inference/unet_runner.py +++ b/models/turbine_models/custom_models/sd_inference/unet_runner.py @@ -15,16 +15,18 @@ def run_unet( hf_model_name, hf_auth_token, external_weight_path, + iree_dtype, ): runner = vmfbRunner(device, vmfb_path, external_weight_path) - inputs = [ - ireert.asdevicearray(runner.config.device, sample), - ireert.asdevicearray(runner.config.device, timestep), - ireert.asdevicearray(runner.config.device, encoder_hidden_states), - ireert.asdevicearray(runner.config.device, guidance_scale), + ireert.asdevicearray(runner.config.device, sample, dtype=iree_dtype), + ireert.asdevicearray(runner.config.device, timestep, dtype=iree_dtype), + ireert.asdevicearray( + runner.config.device, encoder_hidden_states, dtype=iree_dtype + ), + ireert.asdevicearray(runner.config.device, guidance_scale, dtype=iree_dtype), ] - results = runner.ctx.modules.compiled_unet["main"](*inputs) + results = runner.ctx.modules.compiled_unet["run_forward"](*inputs) return results @@ -36,32 +38,10 @@ def run_torch_unet( encoder_hidden_states, guidance_scale, ): - from diffusers import UNet2DConditionModel - - class UnetModel(torch.nn.Module): - def __init__(self, hf_model_name, hf_auth_token): - super().__init__() - self.unet = UNet2DConditionModel.from_pretrained( - hf_model_name, - subfolder="unet", - token=hf_auth_token, - ) - self.guidance_scale = 7.5 - - def forward(self, sample, timestep, encoder_hidden_states, guidance_scale): - samples = torch.cat([sample] * 2) - unet_out = self.unet.forward( - samples, timestep, encoder_hidden_states, return_dict=False - )[0] - noise_pred_uncond, noise_pred_text = unet_out.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - return noise_pred + from turbine_models.custom_models.sd_inference.unet import UnetModel unet_model = UnetModel( hf_model_name, - hf_auth_token, ) results = unet_model.forward( sample, timestep, encoder_hidden_states, guidance_scale @@ -72,15 +52,21 @@ def forward(self, sample, timestep, encoder_hidden_states, guidance_scale): if __name__ == "__main__": args = parser.parse_args() + iree_dtypes = { + "fp16": "float16", + "fp32": "float32", + } sample = torch.rand( - args.batch_size, 4, args.height // 8, args.width // 8, dtype=torch.float32 + args.batch_size * 2, 4, args.height // 8, args.width // 8, dtype=torch.float32 ) timestep = torch.zeros(1, dtype=torch.float32) guidance_scale = torch.Tensor([7.5], dtype=torch.float32) if args.hf_model_name == "CompVis/stable-diffusion-v1-4": - encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32) + encoder_hidden_states = torch.rand(2, args.max_length, 768, dtype=torch.float32) elif args.hf_model_name == "stabilityai/stable-diffusion-2-1-base": - encoder_hidden_states = torch.rand(2, 77, 1024, dtype=torch.float32) + encoder_hidden_states = torch.rand( + 2, args.max_length, 1024, dtype=torch.float32 + ) turbine_output = run_unet( args.device, @@ -92,6 +78,7 @@ def forward(self, sample, timestep, encoder_hidden_states, guidance_scale): args.hf_model_name, args.hf_auth_token, args.external_weight_path, + iree_dtypes[args.precision], ) print( "TURBINE OUTPUT:", diff --git a/models/turbine_models/custom_models/sd_inference/vae.py b/models/turbine_models/custom_models/sd_inference/vae.py index 6fdbfd958..d9c0fd743 100644 --- a/models/turbine_models/custom_models/sd_inference/vae.py +++ b/models/turbine_models/custom_models/sd_inference/vae.py @@ -113,6 +113,7 @@ def export_vae_model( attn_spec=None, input_mlir=None, weights_only=False, + upload_ir=False, ): dtype = torch.float16 if precision == "fp16" else torch.float32 np_dtype = "float16" if precision == "fp16" else "float32" diff --git a/models/turbine_models/custom_models/sd_inference/vae_runner.py b/models/turbine_models/custom_models/sd_inference/vae_runner.py index cded33824..166021631 100644 --- a/models/turbine_models/custom_models/sd_inference/vae_runner.py +++ b/models/turbine_models/custom_models/sd_inference/vae_runner.py @@ -5,17 +5,19 @@ import torch -def run_vae(device, example_input, vmfb_path, hf_model_name, external_weight_path): +def run_vae_decode( + device, example_input, vmfb_path, hf_model_name, external_weight_path +): runner = vmfbRunner(device, vmfb_path, external_weight_path) inputs = [ireert.asdevicearray(runner.config.device, example_input)] - results = runner.ctx.modules.compiled_vae["main"](*inputs).to_host() + results = runner.ctx.modules.compiled_vae["decode"](*inputs).to_host() return results -def run_torch_vae(hf_model_name, variant, example_input): +def run_torch_vae_decode(hf_model_name, variant, example_input): from diffusers import AutoencoderKL class VaeModel(torch.nn.Module): @@ -87,7 +89,7 @@ def encode_inp(self, inp): args.batch_size, 3, args.height, args.width, dtype=torch.float32 ) print("generating turbine output:") - turbine_results = run_vae( + turbine_results = run_vae_decode( args.device, example_input, args.vmfb_path, @@ -104,7 +106,7 @@ def encode_inp(self, inp): print("generating torch output: ") from turbine_models.custom_models.sd_inference import utils - torch_output = run_torch_vae( + torch_output = run_torch_vae_decode( args.hf_model_name, args.hf_auth_token, args.variant, example_input ) print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py index e55eada39..00b02d028 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py @@ -165,7 +165,7 @@ def export_prompt_encoder( device=None, target=None, ireec_flags=None, - exit_on_vmfb=True, + exit_on_vmfb=False, pipeline_dir=None, input_mlir=None, attn_spec=None, @@ -277,7 +277,7 @@ def encode_prompts_turbo( module_str = str(module) if compile_to != "vmfb": - return module_str, tokenizers + return module_str else: vmfb_path = utils.compile_to_vmfb( module_str, @@ -289,7 +289,7 @@ def encode_prompts_turbo( const_expr_hoisting=True, attn_spec=attn_spec, ) - return module_str, vmfb_path + return vmfb_path if __name__ == "__main__": diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index 7ae48836c..aab370143 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -49,18 +49,23 @@ def __init__(self, hf_model_name, hf_auth_token=None, precision="fp32"): auth_token=hf_auth_token, low_cpu_mem_usage=False, ) - # if "turbo" in hf_model_name: - # self.do_classifier_free_guidance = False - # else: self.do_classifier_free_guidance = True def forward( - self, latent_model_input, timestep, prompt_embeds, text_embeds, time_ids + self, + latent_model_input, + timestep, + prompt_embeds, + text_embeds, + time_ids, + guidance_scale, ): added_cond_kwargs = { "text_embeds": text_embeds, "time_ids": time_ids, } + if self.do_classifier_free_guidance: + latent_model_input = torch.cat([latent_model_input] * 2) noise_pred = self.unet.forward( latent_model_input, timestep, @@ -69,6 +74,11 @@ def forward( added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] + if self.do_classifier_free_guidance: + noise_preds = noise_pred.chunk(2) + noise_pred = noise_preds[0] + guidance_scale * ( + noise_preds[1] - noise_preds[0] + ) return noise_pred @@ -238,22 +248,17 @@ def export_unet_model( height // 8, width // 8, ] - prepared_latents = ( - batch_size * init_batch_dim, - 4, - height // 8, - width // 8, - ) time_ids_shape = (init_batch_dim * batch_size, 6) prompt_embeds_shape = (init_batch_dim * batch_size, max_length, 2048) text_embeds_shape = (init_batch_dim * batch_size, 1280) example_forward_args = [ - torch.empty(prepared_latents, dtype=dtype), + torch.empty(sample, dtype=dtype), torch.empty(1, dtype=dtype), torch.empty(prompt_embeds_shape, dtype=dtype), torch.empty(text_embeds_shape, dtype=dtype), torch.empty(time_ids_shape, dtype=dtype), + torch.tensor([7.5], dtype=dtype), ] example_forward_args_dict = { "sample": torch.rand(sample, dtype=dtype), @@ -282,6 +287,8 @@ def export_unet_model( ) module = output.mlir_module else: + if external_weights: + externalize_module_parameters(unet_model) fxb = FxProgramsBuilder(unet_model) @fxb.export_program( @@ -303,19 +310,17 @@ class CompiledUnet(CompiledModule): model_metadata_run_forward = { "model_name": "sd_unet", "input_shapes": [ - prepared_latents, + sample, (1,), prompt_embeds_shape, text_embeds_shape, time_ids_shape, + (1,), ], - "input_dtypes": [np_dtype for x in range(5)], + "input_dtypes": [np_dtype for x in range(6)], "output_shapes": [sample], "output_dtypes": [np_dtype], } - if use_punet: - model_metadata_run_forward["input_shapes"].append((1,)) - model_metadata_run_forward["input_dtypes"].append(np_dtype) module = AddMetadataPass(module, model_metadata_run_forward, "run_forward").run() diff --git a/models/turbine_models/custom_models/sdxl_inference/unet_runner.py b/models/turbine_models/custom_models/sdxl_inference/unet_runner.py index 9d0b405c3..c474982d7 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet_runner.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet_runner.py @@ -31,6 +31,7 @@ def run_unet( ireert.asdevicearray(runner.config.device, prompt_embeds), ireert.asdevicearray(runner.config.device, text_embeds), ireert.asdevicearray(runner.config.device, time_ids), + ireert.asdevicearray(runner.config.device, guidance_scale), ] results = runner.ctx.modules.compiled_unet["run_forward"](*inputs) @@ -56,6 +57,7 @@ def run_unet_steps( ireert.asdevicearray(runner.config.device, prompt_embeds), ireert.asdevicearray(runner.config.device, text_embeds), ireert.asdevicearray(runner.config.device, time_ids), + ireert.asdevicearray(runner.config.device, guidance_scale), ] for i, t in tqdm(enumerate(scheduler.timesteps)): timestep = t @@ -116,7 +118,7 @@ def run_torch_unet( sample = torch.rand( args.batch_size, 4, args.height // 8, args.width // 8, dtype=dtype ) - timestep = torch.ones(1, dtype=torch.int64) + timestep = torch.ones(1, dtype=dtype) prompt_embeds = torch.rand(2 * args.batch_size, args.max_length, 2048, dtype=dtype) text_embeds = torch.rand(2 * args.batch_size, 1280, dtype=dtype) time_ids = torch.rand(2 * args.batch_size, 6, dtype=dtype) diff --git a/models/turbine_models/custom_models/sdxl_inference/vae_runner.py b/models/turbine_models/custom_models/sdxl_inference/vae_runner.py index 98aae9a28..01767d322 100644 --- a/models/turbine_models/custom_models/sdxl_inference/vae_runner.py +++ b/models/turbine_models/custom_models/sdxl_inference/vae_runner.py @@ -15,7 +15,7 @@ def run_vae( ): runner = vmfbRunner(device, vmfb_path, external_weight_path) inputs = [ireert.asdevicearray(runner.config.device, example_input)] - results = runner.ctx.modules.compiled_vae["main"](*inputs) + results = runner.ctx.modules.compiled_vae["decode"](*inputs) return results @@ -28,9 +28,9 @@ def run_torch_vae(hf_model_name, custom_vae, variant, example_input): ) if variant == "decode": - results = vae_model.decode_inp(example_input) + results = vae_model.decode(example_input) elif variant == "encode": - results = vae_model.encode_inp(example_input) + results = vae_model.encode(example_input) np_torch_output = results.detach().cpu().numpy() return np_torch_output diff --git a/models/turbine_models/tests/conftest.py b/models/turbine_models/tests/conftest.py index 1c1952605..d93aa2e60 100644 --- a/models/turbine_models/tests/conftest.py +++ b/models/turbine_models/tests/conftest.py @@ -18,7 +18,7 @@ def pytest_addoption(parser): action="store", default="blurry, unsaturated, watermark, noisy, grainy, out of focus", ) - parser.addoption("--num_inference_steps", type=int, action="store", default=5) + parser.addoption("--num_inference_steps", type=int, action="store", default=2) parser.addoption("--guidance_scale", type=float, action="store", default=7.5) parser.addoption("--seed", type=float, action="store", default=0.0) parser.addoption("--vmfb_path", action="store", default="") @@ -50,4 +50,4 @@ def pytest_addoption(parser): parser.addoption("--in_channels", type=int, action="store", default=4) parser.addoption("--benchmark", action="store_true", default=False) parser.addoption("--tracy_profile", action="store_true", default=False) - parser.addoption("--compiled_pipeline", type=bool, default=True) + parser.addoption("--compiled_pipeline", type=bool, default=False) diff --git a/models/turbine_models/tests/sd3_test.py b/models/turbine_models/tests/sd3_test.py index 95309947d..a627eb287 100644 --- a/models/turbine_models/tests/sd3_test.py +++ b/models/turbine_models/tests/sd3_test.py @@ -354,191 +354,54 @@ def test03_ExportVaeModelDecode(self): np.testing.assert_allclose(torch_output, turbine, rtol, atol) + @pytest.mark.skip("Waiting on inference plumbing for generalized sd pipeline") + def test04SDPipeline(self): + from turbine_models.custom_models.sd_inference.sd_pipeline import ( + SharkSDPipeline, + ) -# def test04_ExportVaeModelEncode(self): -# if arguments["device"] in ["cpu", "vulkan", "cuda", "rocm"]: -# self.skipTest( -# "Compilation error on cpu, vulkan and rocm; To be tested on cuda." -# ) -# vae.export_vae_model( -# vae_model=self.vae_model, -# # This is a public model, so no auth required -# hf_model_name=arguments["hf_model_name"], -# batch_size=arguments["batch_size"], -# height=arguments["height"], -# width=arguments["width"], -# precision=arguments["precision"], -# compile_to="vmfb", -# external_weights=arguments["external_weights"], -# external_weight_path=self.safe_model_name -# + "_" -# + arguments["precision"] -# + "_vae_encode." -# + arguments["external_weights"], -# device=arguments["device"], -# target_triple=arguments["iree_target_triple"], -# ireec_flags=arguments["ireec_flags"], -# variant="encode", -# decomp_attn=arguments["decomp_attn"], -# exit_on_vmfb=True, -# ) -# arguments["external_weight_path"] = ( -# self.safe_model_name -# + "_" -# + arguments["precision"] -# + "_vae_encode." -# + arguments["external_weights"] -# ) -# arguments["vmfb_path"] = ( -# self.safe_model_name -# + "_" -# + str(arguments["height"]) -# + "x" -# + str(arguments["width"]) -# + "_" -# + arguments["precision"] -# + "_vae_encode_" -# + arguments["device"] -# + ".vmfb" -# ) -# example_input = torch.ones( -# arguments["batch_size"], -# 3, -# arguments["height"], -# arguments["width"], -# dtype=torch.float32, -# ) -# example_input_torch = example_input -# if arguments["precision"] == "fp16": -# example_input = example_input.half() -# turbine = vae_runner.run_vae( -# arguments["rt_device"], -# example_input, -# arguments["vmfb_path"], -# arguments["hf_model_name"], -# arguments["external_weight_path"], -# ) -# torch_output = vae_runner.run_torch_vae( -# arguments["hf_model_name"], -# ( -# "madebyollin/sdxl-vae-fp16-fix" -# if arguments["precision"] == "fp16" -# else "" -# ), -# "encode", -# example_input_torch, -# ) -# if arguments["benchmark"] or arguments["tracy_profile"]: -# run_benchmark( -# "vae_encode", -# arguments["vmfb_path"], -# arguments["external_weight_path"], -# arguments["rt_device"], -# height=arguments["height"], -# width=arguments["width"], -# precision=arguments["precision"], -# tracy_profile=arguments["tracy_profile"], -# ) -# rtol = 4e-2 -# atol = 4e-2 -# np.testing.assert_allclose(torch_output, turbine, rtol, atol) + current_args = copy.deepcopy(default_arguments) + decomp_attn = { + "text_encoder": False, + "unet": False, + "vae": current_args["vae_decomp_attn"], + } + sd_pipe = SharkSDPipeline( + current_args["hf_model_name"], + current_args["height"], + current_args["width"], + current_args["batch_size"], + current_args["max_length"], + current_args["precision"], + current_args["device"], + current_args["iree_target_triple"], + ireec_flags=None, # ireec_flags + attn_spec=current_args["attn_spec"], + decomp_attn=decomp_attn, + pipeline_dir="test_vmfbs", # pipeline_dir + external_weights_dir="test_weights", # external_weights_dir + external_weights=current_args["external_weights"], + num_inference_steps=current_args["num_inference_steps"], + cpu_scheduling=True, + scheduler_id=current_args["scheduler_id"], + shift=None, # shift + use_i8_punet=False, + ) + sd_pipe.prepare_all() + sd_pipe.load_map() + output = sd_pipe.generate_images( + current_args["prompt"], + current_args["negative_prompt"], + current_args["num_inference_steps"], + 1, # batch count + current_args["guidance_scale"], + current_args["seed"], + current_args["cpu_scheduling"], + current_args["scheduler_id"], + True, # return_img + ) + assert output is not None -# def test05_t2i_generate_images(self): -# if arguments["device"] in ["vulkan", "cuda", "rocm"]: -# self.skipTest( -# "Have issues with submodels on vulkan, cuda; ROCM hangs on mi250 despite submodels working." -# ) -# mlirs = { -# "vae_decode": None, -# "prompt_encoder": None, -# "scheduled_unet": None, -# "pipeline": None, -# "full_pipeline": None, -# } -# vmfbs = { -# "vae_decode": None, -# "prompt_encoder": None, -# "scheduled_unet": None, -# "pipeline": None, -# "full_pipeline": None, -# } -# weights = { -# "vae_decode": None, -# "prompt_encoder": None, -# "scheduled_unet": None, -# "pipeline": None, -# "full_pipeline": None, -# } -# -# if not arguments["pipeline_dir"]: -# pipe_id_list = [ -# "sdxl_1_0", -# str(arguments["height"]), -# str(arguments["width"]), -# str(arguments["max_length"]), -# arguments["precision"], -# arguments["device"], -# ] -# arguments["pipeline_dir"] = os.path.join( -# ".", -# "_".join(pipe_id_list), -# ) -# ireec_flags = { -# "unet": arguments["ireec_flags"], -# "vae": arguments["ireec_flags"], -# "clip": arguments["ireec_flags"], -# "pipeline": arguments["ireec_flags"], -# } -# user_mlir_list = [] -# for submodel_id, mlir_path in zip(mlirs.keys(), user_mlir_list): -# if submodel_id in mlir_path: -# mlirs[submodel_id] = mlir_path -# external_weights_dir = arguments["pipeline_dir"] -# sdxl_pipe = sdxl_compiled_pipeline.SharkSDXLPipeline( -# arguments["hf_model_name"], -# arguments["scheduler_id"], -# arguments["height"], -# arguments["width"], -# arguments["precision"], -# arguments["max_length"], -# arguments["batch_size"], -# arguments["num_inference_steps"], -# arguments["device"], -# arguments["iree_target_triple"], -# ireec_flags, -# arguments["attn_spec"], -# arguments["decomp_attn"], -# arguments["pipeline_dir"], -# external_weights_dir, -# arguments["external_weights"], -# ) -# vmfbs, weights = sdxl_pipe.check_prepared( -# mlirs, vmfbs, weights, interactive=False -# ) -# sdxl_pipe.load_pipeline( -# vmfbs, weights, arguments["rt_device"], arguments["compiled_pipeline"] -# ) -# sdxl_pipe.generate_images( -# arguments["prompt"], -# arguments["negative_prompt"], -# 1, -# arguments["guidance_scale"], -# arguments["seed"], -# ) -# print("Image generation complete.") -# os.remove(os.path.join(arguments["pipeline_dir"], "prompt_encoder.vmfb")) -# os.remove( -# os.path.join( -# arguments["pipeline_dir"], -# arguments["scheduler_id"] -# + "_unet_" -# + str(arguments["num_inference_steps"]) -# + ".vmfb", -# ) -# ) -# os.remove(os.path.join(arguments["pipeline_dir"], "vae_decode.vmfb")) -# os.remove(os.path.join(arguments["pipeline_dir"], "full_pipeline.vmfb")) -# if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) diff --git a/models/turbine_models/tests/sd_test.py b/models/turbine_models/tests/sd_test.py index 7af7dcb10..738738702 100644 --- a/models/turbine_models/tests/sd_test.py +++ b/models/turbine_models/tests/sd_test.py @@ -23,6 +23,7 @@ import os import copy import platform +from PIL import Image from turbine_models.turbine_tank import turbine_tank @@ -30,8 +31,8 @@ "hf_auth_token": None, "hf_model_name": "CompVis/stable-diffusion-v1-4", "safe_model_name": "stable-diffusion_v1_4", - "scheduler_id": "PNDM", - "num_inference_steps": 5, + "scheduler_id": "EulerDiscrete", + "num_inference_steps": 2, "batch_size": 1, "height": 512, "width": 512, @@ -47,91 +48,39 @@ "rt_device": "local-task", "iree_target_triple": "x86_64-linux-gnu", "prompt": "a photograph of an astronaut riding a horse", + "negative_prompt": "blurry, out of focus", "in_channels": 4, + "vae_decomp_attn": True, + "seed": 0, + "use_i8_punet": False, + "attn_spec": None, + "cpu_scheduling": True, } UPLOAD_IR = os.environ.get("TURBINE_TANK_ACTION", "not_upload") == "upload" -unet_model = unet.UnetModel( - # This is a public model, so no auth required - default_arguments["hf_model_name"], -) - -vae_model = vae.VaeModel( - # This is a public model, so no auth required - default_arguments["hf_model_name"], - custom_vae=None, -) - -scheduler = schedulers.get_scheduler( - default_arguments["hf_model_name"], default_arguments["scheduler_id"] -) -scheduler_module = schedulers.SchedulingModel( - scheduler, - default_arguments["height"], - default_arguments["width"], - default_arguments["num_inference_steps"], - default_arguments["precision"], -) - - # TODO: this is a mess, don't share args across tests, create a copy for each test class StableDiffusionTest(unittest.TestCase): - def testExportT5Model(self): + def testExportClipModel(self): current_args = copy.deepcopy(default_arguments) - current_args["hf_model_name"] = "google/t5-v1_1-small" - blob_name = clip.export_clip_model( - hf_model_name=current_args["hf_model_name"], - max_length=64, - precision=current_args["precision"], - compile_to="vmfb", - external_weights=None, - external_weight_path=None, - device="cpu", - target_triple=None, - exit_on_vmfb=False, - upload_ir=UPLOAD_IR, - ) - current_args["vmfb_path"] = blob_name - turbine = clip_runner.run_clip( - current_args["rt_device"], - current_args["prompt"], - current_args["vmfb_path"], - current_args["hf_model_name"], - current_args["hf_auth_token"], - None, - ) - torch_output = clip_runner.run_torch_clip( - current_args["hf_model_name"], - current_args["hf_auth_token"], - current_args["prompt"], + current_args["hf_model_name"] = "CompVis/stable-diffusion-v1-4" + safe_prefix = utils.create_safe_name( + current_args["hf_model_name"].split("/")[-1], "clip" ) - err = utils.largest_error(torch_output, turbine[0]) - assert err < 9e-4 - if UPLOAD_IR: - new_blob_name = blob_name.split(".") - new_blob_name = new_blob_name[0] + "-pass.mlir" - turbine_tank.changeBlobName(blob_name, new_blob_name) - del current_args - - def testExportClipVitLarge14(self): - current_args = copy.deepcopy(default_arguments) - current_args["hf_model_name"] = "openai/clip-vit-large-patch14" - safe_prefix = "clip_vit_large_patch14" blob_name = clip.export_clip_model( hf_model_name=current_args["hf_model_name"], - max_length=64, + max_length=current_args["max_length"], precision=current_args["precision"], compile_to="vmfb", external_weights="safetensors", external_weight_path=safe_prefix + ".safetensors", device="cpu", - target_triple=None, + target=current_args["iree_target_triple"], exit_on_vmfb=False, upload_ir=UPLOAD_IR, ) current_args["external_weight_path"] = safe_prefix + ".safetensors" - current_args["vmfb_path"] = safe_prefix + "_clip.vmfb" + current_args["vmfb_path"] = blob_name turbine = clip_runner.run_clip( current_args["rt_device"], current_args["prompt"], @@ -155,67 +104,26 @@ def testExportClipVitLarge14(self): os.remove(current_args["external_weight_path"]) os.remove(current_args["vmfb_path"]) - def testExportClipModel(self): + def testExportUnetModel(self): current_args = copy.deepcopy(default_arguments) - current_args["hf_model_name"] = "CompVis/stable-diffusion-v1-4" - blob_name = clip.export_clip_model( + blob_name = unet.export_unet_model( hf_model_name=current_args["hf_model_name"], - max_length=64, + batch_size=current_args["batch_size"], + height=current_args["height"], + width=current_args["width"], precision=current_args["precision"], + max_length=current_args["max_length"], compile_to="vmfb", external_weights="safetensors", - external_weight_path=safe_prefix + ".safetensors", + external_weight_path="stable_diffusion_unet.safetensors", device="cpu", - target_triple=None, - exit_on_vmfb=False, - upload_ir=UPLOAD_IR, - ) - current_args["external_weight_path"] = "stable_diffusion_v1_4_clip.safetensors" - current_args["vmfb_path"] = "stable_diffusion_v1_4_clip.vmfb" - turbine = clip_runner.run_clip( - current_args["rt_device"], - current_args["prompt"], - current_args["vmfb_path"], - current_args["hf_model_name"], - current_args["hf_auth_token"], - current_args["external_weight_path"], - ) - torch_output = clip_runner.run_torch_clip( - current_args["hf_model_name"], - current_args["hf_auth_token"], - current_args["prompt"], - ) - err = utils.largest_error(torch_output, turbine[0]) - assert err < 9e-5 - if UPLOAD_IR: - new_blob_name = blob_name.split(".") - new_blob_name = new_blob_name[0] + "-pass.mlir" - turbine_tank.changeBlobName(blob_name, new_blob_name) - if platform.system() != "Windows": - os.remove(current_args["external_weight_path"]) - os.remove(current_args["vmfb_path"]) - - def testExportUnetModel(self): - current_args = copy.deepcopy(default_arguments) - blob_name = unet.export_unet_model( - unet_model, - current_args["hf_model_name"], - current_args["batch_size"], - current_args["height"], - current_args["width"], - current_args["precision"], - current_args["max_length"], - None, - "vmfb", - "safetensors", - "stable_diffusion_unet.safetensors", - "cpu", + target=current_args["iree_target_triple"], upload_ir=UPLOAD_IR, ) current_args["external_weight_path"] = "stable_diffusion_unet.safetensors" current_args["vmfb_path"] = blob_name sample = torch.rand( - current_args["batch_size"], + current_args["batch_size"] * 2, current_args["in_channels"], current_args["height"] // 8, current_args["width"] // 8, @@ -245,6 +153,7 @@ def testExportUnetModel(self): current_args["hf_model_name"], current_args["hf_auth_token"], current_args["external_weight_path"], + "float32", ) torch_output = unet_runner.run_torch_unet( current_args["hf_model_name"], @@ -268,17 +177,17 @@ def testExportUnetModel(self): def testExportVaeModelDecode(self): current_args = copy.deepcopy(default_arguments) blob_name = vae.export_vae_model( - vae_model, - current_args["hf_model_name"], - current_args["batch_size"], - current_args["height"], - current_args["width"], - current_args["precision"], - "vmfb", - "safetensors", - "stable_diffusion_v1_4_vae.safetensors", - "cpu", - variant="decode", + hf_model_name=current_args["hf_model_name"], + batch_size=current_args["batch_size"], + height=current_args["height"], + width=current_args["width"], + precision=current_args["precision"], + compile_to="vmfb", + external_weights="safetensors", + external_weight_path="stable_diffusion_v1_4_vae.safetensors", + device="cpu", + target=current_args["iree_target_triple"], + decomp_attn=current_args["vae_decomp_attn"], upload_ir=UPLOAD_IR, ) current_args["external_weight_path"] = "stable_diffusion_v1_4_vae.safetensors" @@ -290,14 +199,14 @@ def testExportVaeModelDecode(self): current_args["width"] // 8, dtype=torch.float32, ) - turbine = vae_runner.run_vae( + turbine = vae_runner.run_vae_decode( current_args["rt_device"], example_input, current_args["vmfb_path"], current_args["hf_model_name"], current_args["external_weight_path"], ) - torch_output = vae_runner.run_torch_vae( + torch_output = vae_runner.run_torch_vae_decode( current_args["hf_model_name"], "decode", example_input, @@ -311,107 +220,54 @@ def testExportVaeModelDecode(self): del torch_output del turbine os.remove("stable_diffusion_v1_4_vae.safetensors") - os.remove("blob_name") + os.remove(blob_name) - def testExportVaeModelEncode(self): - current_args = copy.deepcopy(default_arguments) - blob_name = vae.export_vae_model( - vae_model, - current_args["hf_model_name"], - current_args["batch_size"], - current_args["height"], - current_args["width"], - current_args["precision"], - "vmfb", - "safetensors", - "stable_diffusion_v1_4_vae.safetensors", - "cpu", - variant="encode", - upload_ir=UPLOAD_IR, + def testSDPipeline(self): + from turbine_models.custom_models.sd_inference.sd_pipeline import ( + SharkSDPipeline, ) - current_args["external_weight_path"] = "stable_diffusion_v1_4_vae.safetensors" - current_args["vmfb_path"] = blob_name - example_input = torch.rand( - current_args["batch_size"], - 3, - current_args["height"], - current_args["width"], - dtype=torch.float32, - ) - turbine = vae_runner.run_vae( - current_args["rt_device"], - example_input, - current_args["vmfb_path"], - current_args["hf_model_name"], - current_args["external_weight_path"], - ) - torch_output = vae_runner.run_torch_vae( - current_args["hf_model_name"], - "encode", - example_input, - ) - err = utils.largest_error(torch_output, turbine) - assert err < 3e-3 - if UPLOAD_IR: - new_blob_name = blob_name.split(".") - new_blob_name = new_blob_name[0] + "-pass.mlir" - turbine_tank.changeBlobName(blob_name, new_blob_name) - os.remove("stable_diffusion_v1_4_vae.safetensors") - os.remove(blob_name) - @unittest.expectedFailure - def testExportPNDMScheduler(self): current_args = copy.deepcopy(default_arguments) - safe_name = "stable_diffusion_v1_4_scheduler" - blob_name = schedulers.export_scheduler_model( + decomp_attn = { + "text_encoder": False, + "unet": False, + "vae": current_args["vae_decomp_attn"], + } + sd_pipe = SharkSDPipeline( current_args["hf_model_name"], - current_args["scheduler_id"], - current_args["batch_size"], current_args["height"], current_args["width"], - current_args["num_inference_steps"], + current_args["batch_size"], + current_args["max_length"], current_args["precision"], - "vmfb", current_args["device"], current_args["iree_target_triple"], - upload_ir=UPLOAD_IR, - ) - current_args["external_weight_path"] = safe_name + ".safetensors" - current_args["vmfb_path"] = blob_name - sample = torch.rand( - current_args["batch_size"], - 4, - current_args["height"] // 8, - current_args["width"] // 8, - dtype=torch.float32, + ireec_flags=None, # ireec_flags + attn_spec=current_args["attn_spec"], + decomp_attn=decomp_attn, + pipeline_dir="test_vmfbs", # pipeline_dir + external_weights_dir="test_weights", # external_weights_dir + external_weights=current_args["external_weights"], + num_inference_steps=current_args["num_inference_steps"], + cpu_scheduling=True, + scheduler_id=current_args["scheduler_id"], + shift=None, # shift + use_i8_punet=current_args["use_i8_punet"], ) - encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32) - turbine = schedulers_runner.run_scheduler( - current_args["rt_device"], - sample, - encoder_hidden_states, - current_args["vmfb_path"], - current_args["hf_model_name"], - current_args["hf_auth_token"], - current_args["external_weight_path"], - ) - torch_output = schedulers_runner.run_torch_scheduler( - current_args["hf_model_name"], - scheduler, + sd_pipe.prepare_all() + sd_pipe.load_map() + output = sd_pipe.generate_images( + current_args["prompt"], + current_args["negative_prompt"], current_args["num_inference_steps"], - sample, - encoder_hidden_states, + 1, # batch count + current_args["guidance_scale"], + current_args["seed"], + current_args["cpu_scheduling"], + current_args["scheduler_id"], + True, # return_img ) - err = utils.largest_error(torch_output, turbine) - assert err < 9e-3 - if UPLOAD_IR: - new_blob_name = blob_name.split(".") - new_blob_name = new_blob_name[0] + "-pass.mlir" - turbine_tank.changeBlobName(blob_name, new_blob_name) - os.remove("stable_diffusion_v1_4_scheduler.safetensors") - os.remove("stable_diffusion_v1_4_scheduler.vmfb") - del torch_output - del turbine + assert output is not None if __name__ == "__main__": diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index fa44673ac..031dd43ec 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -9,7 +9,7 @@ import torch from transformers import CLIPTokenizer from turbine_models.custom_models.sd_inference.utils import create_safe_name -from turbine_models.custom_models.sd_inference import schedulers +from turbine_models.custom_models.sd_inference import schedulers, vae from turbine_models.custom_models.sdxl_inference import ( sdxl_prompt_encoder, sdxl_prompt_encoder_runner, @@ -17,7 +17,6 @@ unet_runner, sdxl_scheduled_unet, sdxl_scheduled_unet_runner, - vae, vae_runner, sdxl_compiled_pipeline, ) @@ -81,20 +80,6 @@ def command_line_args(request): class StableDiffusionXLTest(unittest.TestCase): def setUp(self): self.safe_model_name = create_safe_name(arguments["hf_model_name"], "") - self.unet_model = unet.UnetModel( - # This is a public model, so no auth required - arguments["hf_model_name"], - precision=arguments["precision"], - ) - self.vae_model = vae.VaeModel( - # This is a public model, so no auth required - arguments["hf_model_name"], - custom_vae=( - "madebyollin/sdxl-vae-fp16-fix" - if arguments["precision"] == "fp16" - else None - ), - ) def test01_ExportPromptEncoder(self): if arguments["device"] in ["vulkan", "cuda"]: @@ -104,23 +89,17 @@ def test01_ExportPromptEncoder(self): arguments["external_weight_path"] = ( "prompt_encoder." + arguments["external_weights"] ) - _, prompt_encoder_vmfb = sdxl_prompt_encoder.export_prompt_encoder( + prompt_encoder_vmfb = sdxl_prompt_encoder.export_prompt_encoder( arguments["hf_model_name"], - None, - arguments["max_length"], - arguments["precision"], - "vmfb", - "safetensors", - arguments["external_weight_path"], - arguments["device"], - arguments["iree_target_triple"], - arguments["ireec_flags"], - False, - None, - None, - arguments["attn_spec"], - False, - arguments["batch_size"], + hf_auth_token=None, + max_length=arguments["max_length"], + batch_size=arguments["batch_size"], + precision=arguments["precision"], + compile_to="vmfb", + external_weights="safetensors", + external_weight_path=arguments["external_weight_path"], + device=arguments["device"], + target=arguments["iree_target_triple"], ) tokenizer_1 = CLIPTokenizer.from_pretrained( arguments["hf_model_name"], @@ -177,9 +156,7 @@ def test01_ExportPromptEncoder(self): def test02_ExportUnetModel(self): if arguments["device"] in ["vulkan", "cuda"]: self.skipTest("Unknown error on vulkan; To be tested on cuda.") - unet.export_unet_model( - unet_model=self.unet_model, - # This is a public model, so no auth required + unet_vmfb = unet.export_unet_model( hf_model_name=arguments["hf_model_name"], batch_size=arguments["batch_size"], height=arguments["height"], @@ -195,10 +172,11 @@ def test02_ExportUnetModel(self): + "_unet." + arguments["external_weights"], device=arguments["device"], - target_triple=arguments["iree_target_triple"], + target=arguments["iree_target_triple"], ireec_flags=arguments["ireec_flags"], decomp_attn=arguments["decomp_attn"], attn_spec=arguments["attn_spec"], + exit_on_vmfb=False, ) arguments["external_weight_path"] = ( self.safe_model_name @@ -207,20 +185,7 @@ def test02_ExportUnetModel(self): + "_unet." + arguments["external_weights"] ) - arguments["vmfb_path"] = ( - self.safe_model_name - + "_" - + str(arguments["max_length"]) - + "_" - + str(arguments["height"]) - + "x" - + str(arguments["width"]) - + "_" - + arguments["precision"] - + "_unet_" - + arguments["device"] - + ".vmfb" - ) + arguments["vmfb_path"] = unet_vmfb dtype = torch.float16 if arguments["precision"] == "fp16" else torch.float32 sample = torch.rand( ( @@ -231,7 +196,7 @@ def test02_ExportUnetModel(self): ), dtype=dtype, ) - timestep = torch.zeros(1, dtype=torch.int64) + timestep = torch.zeros(1, dtype=dtype) prompt_embeds = torch.rand( (2 * arguments["batch_size"], arguments["max_length"], 2048), dtype=dtype, @@ -286,9 +251,7 @@ def test02_ExportUnetModel(self): def test03_ExportVaeModelDecode(self): if arguments["device"] in ["vulkan", "cuda"]: self.skipTest("Compilation error on vulkan; To be tested on cuda.") - vae.export_vae_model( - vae_model=self.vae_model, - # This is a public model, so no auth required + vae_vmfb = vae.export_vae_model( hf_model_name=arguments["hf_model_name"], batch_size=arguments["batch_size"], height=arguments["height"], @@ -302,12 +265,11 @@ def test03_ExportVaeModelDecode(self): + "_vae_decode." + arguments["external_weights"], device=arguments["device"], - target_triple=arguments["iree_target_triple"], + target=arguments["iree_target_triple"], ireec_flags=arguments["ireec_flags"], - variant="decode", - decomp_attn=arguments["decomp_attn"], + decomp_attn=True, attn_spec=arguments["attn_spec"], - exit_on_vmfb=True, + exit_on_vmfb=False, ) arguments["external_weight_path"] = ( self.safe_model_name @@ -316,18 +278,7 @@ def test03_ExportVaeModelDecode(self): + "_vae_decode." + arguments["external_weights"] ) - arguments["vmfb_path"] = ( - self.safe_model_name - + "_" - + str(arguments["height"]) - + "x" - + str(arguments["width"]) - + "_" - + arguments["precision"] - + "_vae_decode_" - + arguments["device"] - + ".vmfb" - ) + arguments["vmfb_path"] = vae_vmfb example_input = torch.ones( arguments["batch_size"], 4, @@ -376,7 +327,7 @@ def test04_ExportVaeModelEncode(self): self.skipTest( "Compilation error on cpu, vulkan and rocm; To be tested on cuda." ) - vae.export_vae_model( + vae_vmfb = vae.export_vae_model( vae_model=self.vae_model, # This is a public model, so no auth required hf_model_name=arguments["hf_model_name"], @@ -392,10 +343,9 @@ def test04_ExportVaeModelEncode(self): + "_vae_encode." + arguments["external_weights"], device=arguments["device"], - target_triple=arguments["iree_target_triple"], + target=arguments["iree_target_triple"], ireec_flags=arguments["ireec_flags"], - variant="encode", - decomp_attn=arguments["decomp_attn"], + decomp_attn=True, exit_on_vmfb=True, ) arguments["external_weight_path"] = ( @@ -405,18 +355,7 @@ def test04_ExportVaeModelEncode(self): + "_vae_encode." + arguments["external_weights"] ) - arguments["vmfb_path"] = ( - self.safe_model_name - + "_" - + str(arguments["height"]) - + "x" - + str(arguments["width"]) - + "_" - + arguments["precision"] - + "_vae_encode_" - + arguments["device"] - + ".vmfb" - ) + arguments["vmfb_path"] = vae_vmfb example_input = torch.ones( arguments["batch_size"], 3, @@ -460,100 +399,105 @@ def test04_ExportVaeModelEncode(self): np.testing.assert_allclose(torch_output, turbine, rtol, atol) def test05_t2i_generate_images(self): - if arguments["device"] in ["vulkan", "cuda", "rocm"]: + if arguments["device"] in ["vulkan", "cuda"]: self.skipTest( "Have issues with submodels on vulkan, cuda; ROCM hangs on mi250 despite submodels working." ) - mlirs = { - "vae_decode": None, - "prompt_encoder": None, - "scheduled_unet": None, - "pipeline": None, - "full_pipeline": None, - } - vmfbs = { - "vae_decode": None, - "prompt_encoder": None, - "scheduled_unet": None, - "pipeline": None, - "full_pipeline": None, - } - weights = { - "vae_decode": None, - "prompt_encoder": None, - "scheduled_unet": None, - "pipeline": None, - "full_pipeline": None, - } + from turbine_models.custom_models.sd_inference.sd_pipeline import ( + SharkSDPipeline, + ) - if not arguments["pipeline_dir"]: - pipe_id_list = [ - "sdxl_1_0", - str(arguments["height"]), - str(arguments["width"]), - str(arguments["max_length"]), - arguments["precision"], - arguments["device"], - ] - arguments["pipeline_dir"] = os.path.join( - ".", - "_".join(pipe_id_list), - ) - ireec_flags = { - "unet": arguments["ireec_flags"], - "vae": arguments["ireec_flags"], - "clip": arguments["ireec_flags"], - "pipeline": arguments["ireec_flags"], + decomp_attn = { + "text_encoder": False, + "unet": False, + "vae": True, } - user_mlir_list = [] - for submodel_id, mlir_path in zip(mlirs.keys(), user_mlir_list): - if submodel_id in mlir_path: - mlirs[submodel_id] = mlir_path - external_weights_dir = arguments["pipeline_dir"] - sdxl_pipe = sdxl_compiled_pipeline.SharkSDXLPipeline( + sd_pipe = SharkSDPipeline( arguments["hf_model_name"], - arguments["scheduler_id"], arguments["height"], arguments["width"], - arguments["precision"], - arguments["max_length"], arguments["batch_size"], - arguments["num_inference_steps"], + arguments["max_length"], + arguments["precision"], arguments["device"], arguments["iree_target_triple"], - ireec_flags, - arguments["attn_spec"], - arguments["decomp_attn"], - arguments["pipeline_dir"], - external_weights_dir, - arguments["external_weights"], - ) - vmfbs, weights = sdxl_pipe.check_prepared( - mlirs, vmfbs, weights, interactive=False - ) - sdxl_pipe.load_pipeline( - vmfbs, weights, arguments["rt_device"], arguments["compiled_pipeline"] - ) - sdxl_pipe.generate_images( + ireec_flags=None, # ireec_flags + attn_spec=arguments["attn_spec"], + decomp_attn=decomp_attn, + pipeline_dir="test_vmfbs", # pipeline_dir + external_weights_dir="test_weights", # external_weights_dir + external_weights=arguments["external_weights"], + num_inference_steps=arguments["num_inference_steps"], + cpu_scheduling=True, + scheduler_id=arguments["scheduler_id"], + shift=None, # shift + use_i8_punet=False, + ) + sd_pipe.prepare_all() + sd_pipe.load_map() + output = sd_pipe.generate_images( arguments["prompt"], arguments["negative_prompt"], - 1, + arguments["num_inference_steps"], + 1, # batch count arguments["guidance_scale"], arguments["seed"], + True, + arguments["scheduler_id"], + True, # return_img ) - print("Image generation complete.") - os.remove(os.path.join(arguments["pipeline_dir"], "prompt_encoder.vmfb")) - os.remove( - os.path.join( - arguments["pipeline_dir"], - arguments["scheduler_id"] - + "_unet_" - + str(arguments["num_inference_steps"]) - + ".vmfb", + assert output is not None + + @pytest.mark.skip(reason="Needs sdxl_quantized branch of IREE") + def test06_t2i_generate_images_punet(self): + if arguments["device"] in ["vulkan", "cuda", "rocm"]: + self.skipTest( + "Have issues with submodels on vulkan, cuda; ROCM hangs on mi250 despite submodels working." ) + from turbine_models.custom_models.sd_inference.sd_pipeline import ( + SharkSDPipeline, + ) + + decomp_attn = { + "text_encoder": False, + "unet": False, + "vae": True, + } + sd_pipe = SharkSDPipeline( + arguments["hf_model_name"], + arguments["height"], + arguments["width"], + arguments["batch_size"], + arguments["max_length"], + arguments["precision"], + arguments["device"], + arguments["iree_target_triple"], + ireec_flags=None, # ireec_flags + attn_spec=arguments["attn_spec"], + decomp_attn=decomp_attn, + pipeline_dir="test_vmfbs", # pipeline_dir + external_weights_dir="test_weights", # external_weights_dir + external_weights=arguments["external_weights"], + num_inference_steps=arguments["num_inference_steps"], + cpu_scheduling=True, + scheduler_id=arguments["scheduler_id"], + shift=None, # shift + use_i8_punet=True, + ) + sd_pipe.prepare_all() + sd_pipe.load_map() + output = sd_pipe.generate_images( + arguments["prompt"], + arguments["negative_prompt"], + arguments["num_inference_steps"], + 1, # batch count + arguments["guidance_scale"], + arguments["seed"], + True, + arguments["scheduler_id"], + True, # return_img ) - os.remove(os.path.join(arguments["pipeline_dir"], "vae_decode.vmfb")) - os.remove(os.path.join(arguments["pipeline_dir"], "full_pipeline.vmfb")) + assert output is not None if __name__ == "__main__": From 22eb78c96d72d2a51dad2a4cacf6b451fdf38b52 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Tue, 9 Jul 2024 22:49:42 -0500 Subject: [PATCH 09/12] Update test_shark.yml --- .github/workflows/test_shark.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_shark.yml b/.github/workflows/test_shark.yml index 301376a47..6f2e4b4ed 100644 --- a/.github/workflows/test_shark.yml +++ b/.github/workflows/test_shark.yml @@ -20,7 +20,7 @@ jobs: strategy: matrix: version: [3.11] - os: [nodai-ubuntu-builder-large] + os: [nodai-amdgpu-mi250-x86-64] runs-on: ${{matrix.os}} steps: From 4d8960d0f20cb3eedd840768fb1a2297acb3a004 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 10 Jul 2024 00:33:58 -0500 Subject: [PATCH 10/12] Use hip driver unless rocm-legacy is used as device string. --- .../custom_models/pipeline_base.py | 5 +- .../custom_models/sd_inference/utils.py | 50 +++++++++++++------ 2 files changed, 38 insertions(+), 17 deletions(-) diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py index 1da3aacfb..d46e20b84 100644 --- a/models/turbine_models/custom_models/pipeline_base.py +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -309,7 +309,7 @@ def __init__( assert ( submodel in target.keys() ), f"Target arch for {submodel} not found." - self.map[submodel]["device"] = device[submodel].split("://")[0] + self.map[submodel]["device"] = utils.iree_backend_map(device[submodel]) self.map[submodel]["driver"] = utils.iree_device_map(device[submodel]) self.map[submodel]["target"] = target[submodel] else: @@ -317,9 +317,10 @@ def __init__( target, str ), "Device and target triple must be both dicts or both strings." for submodel in self.map.keys(): - self.map[submodel]["device"] = device.split("://")[0] + self.map[submodel]["device"] = utils.iree_backend_map(device) self.map[submodel]["driver"] = utils.iree_device_map(device) self.map[submodel]["target"] = target + map_arguments = { "ireec_flags": ireec_flags, "precision": precision, diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 800a66c31..e891642e6 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -12,17 +12,6 @@ # DPMSolverSDEScheduler, ) -_IREE_DEVICE_MAP = { - "cpu": "local-task", - "cpu-task": "local-task", - "cpu-sync": "local-sync", - "cuda": "cuda", - "vulkan": "vulkan", - "metal": "metal", - "rocm": "rocm", - "hip": "hip", - "intel-gpu": "level_zero", -} # If flags are verified to work on a specific model and improve performance without regressing numerics, add them to this dictionary. If you are working with bleeding edge flags, please add them manually with the --ireec_flags argument. MI_flags = { "all": [ @@ -102,22 +91,53 @@ ], } +_IREE_DRIVER_MAP = { + "cpu": "local-task", + "cpu-task": "local-task", + "cpu-sync": "local-sync", + "cuda": "cuda", + "vulkan": "vulkan", + "metal": "metal", + "rocm": "hip", + "rocm-legacy": "rocm", + "hip": "hip", + "intel-gpu": "level_zero", +} + +_IREE_BACKEND_MAP = { + "cpu": "llvm-cpu", + "rocm": "rocm", + "rocm-legacy": "rocm", + "hip": "rocm", + "cuda": "cuda", + "vulkan": "vulkan-spirv", + "metal": "metal", +} + def iree_device_map(device): uri_parts = device.split("://", 2) iree_driver = ( - _IREE_DEVICE_MAP[uri_parts[0]] - if uri_parts[0] in _IREE_DEVICE_MAP + _IREE_DRIVER_MAP[uri_parts[0]] + if uri_parts[0] in _IREE_DRIVER_MAP else uri_parts[0] ) if len(uri_parts) == 1: return iree_driver - elif "rocm" in uri_parts: - return "rocm" else: return f"{iree_driver}://{uri_parts[1]}" +def iree_backend_map(device): + uri_parts = device.split("://", 2) + iree_device = ( + _IREE_BACKEND_MAP[uri_parts[0]] + if uri_parts[0] in _IREE_BACKEND_MAP + else uri_parts[0] + ) + return iree_device + + def compile_to_vmfb( module_str, device, From e8de325a422221ea592dbbe6d064be5ef4d40bc1 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 10 Jul 2024 03:10:01 -0500 Subject: [PATCH 11/12] Switch device args to hip in CI and xfail prompt encoder test for now. --- models/turbine_models/custom_models/sd_inference/utils.py | 2 +- models/turbine_models/tests/sdxl_test.py | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index e891642e6..447076d42 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -194,7 +194,7 @@ def compile_to_vmfb( ] ) device = "vulkan-spirv" - elif device == "rocm": + elif device in ["rocm", "hip"]: flags.extend( [ "--iree-hal-target-backends=rocm", diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index 031dd43ec..da9dfdafe 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -82,9 +82,9 @@ def setUp(self): self.safe_model_name = create_safe_name(arguments["hf_model_name"], "") def test01_ExportPromptEncoder(self): - if arguments["device"] in ["vulkan", "cuda"]: + if arguments["device"] in ["vulkan", "cuda", "rocm"]: self.skipTest( - "Compilation error on vulkan; Runtime error on rocm; To be tested on cuda." + "Compilation error on vulkan; recent numerics regression (nans) on hip driver, To be tested on cuda." ) arguments["external_weight_path"] = ( "prompt_encoder." + arguments["external_weights"] @@ -400,9 +400,7 @@ def test04_ExportVaeModelEncode(self): def test05_t2i_generate_images(self): if arguments["device"] in ["vulkan", "cuda"]: - self.skipTest( - "Have issues with submodels on vulkan, cuda; ROCM hangs on mi250 despite submodels working." - ) + self.skipTest("Have issues with submodels on vulkan, cuda") from turbine_models.custom_models.sd_inference.sd_pipeline import ( SharkSDPipeline, ) From 708b8d0a17aebc98a86c43a1dac086e2eee8fe78 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 10 Jul 2024 03:32:41 -0500 Subject: [PATCH 12/12] Address comments --- .../sd_inference/tokenization.py | 5 ++- .../custom_models/sd_inference/unet.py | 2 +- .../custom_models/sdxl_inference/unet.py | 35 +++++++++++-------- 3 files changed, 26 insertions(+), 16 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/tokenization.py b/models/turbine_models/custom_models/sd_inference/tokenization.py index 8c29d2d3c..e35d37e06 100644 --- a/models/turbine_models/custom_models/sd_inference/tokenization.py +++ b/models/turbine_models/custom_models/sd_inference/tokenization.py @@ -3,6 +3,7 @@ import re import torch import numpy as np +import warnings # The following is copied from Diffusers' "encode_prompt" function in the StableDiffusion pipeline. @@ -86,7 +87,9 @@ def encode_prompt( removed_text = pipe.tokenizer.batch_decode( untruncated_ids[:, pipe.model_max_length - 1 : -1] ) - print("The following text was removed due to truncation:", removed_text) + warnings.warn( + "The following text was removed due to truncation: " + removed_text + ) if pipe.text_encoder.metadata.get("use_attention_mask"): attention_mask = text_inputs.attention_mask prompt_embeds = pipe.text_encoder( diff --git a/models/turbine_models/custom_models/sd_inference/unet.py b/models/turbine_models/custom_models/sd_inference/unet.py index f111fb765..dac967b8a 100644 --- a/models/turbine_models/custom_models/sd_inference/unet.py +++ b/models/turbine_models/custom_models/sd_inference/unet.py @@ -130,7 +130,7 @@ def export_unet_model( torch.empty(1, dtype=dtype), ] decomp_list = [] - if decomp_attn == True: + if decomp_attn: decomp_list = [ torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, torch.ops.aten._scaled_dot_product_flash_attention.default, diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index aab370143..4d3af598c 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -107,22 +107,31 @@ def download(filename): "config.json": download("config.json"), "params.safetensors": download("params.safetensors"), } + output_dir = os.path.dirname(external_weight_path) + if precision == "i8": results["quant_params.json"] = download("quant_params.json") - output_path = external_weight_path.split("unet")[0] + "punet_dataset_i8.irpa" + ds_filename = ( + os.path.basename(external_weight_path).split("unet")[0] + + "punet_dataset_i8.irpa" + ) + output_path = os.path.join(output_dir, ds_filename) ds = get_punet_dataset( results["config.json"], results["params.safetensors"], output_path, results["quant_params.json"], - base_params=None, ) else: + ds_filename = ( + os.path.basename(external_weight_path).split("unet")[0] + + f"punet_dataset_{precision}.irpa" + ) + output_path = os.path.join(output_dir, ds_filename) ds = get_punet_dataset( results["config.json"], results["params.safetensors"], output_path, - base_params=None, ) cond_unet = sharktank_unet2d.from_dataset(ds) @@ -133,21 +142,19 @@ def download(filename): def get_punet_dataset( config_json_path, params_path, - output_path="./punet_dataset_i8.irpa", + output_path, quant_params_path=None, - quant_params_struct=None, - base_params=None, ): from sharktank.models.punet.tools import import_brevitas_dataset - import_brevitas_dataset.main( - [ - f"--config-json={config_json_path}", - f"--params={params_path}", - f"--quant-params={quant_params_path}", - f"--output-irpa-file={output_path}", - ] - ) + ds_import_args = [ + f"--config-json={config_json_path}", + f"--params={params_path}", + f"--output-irpa-file={output_path}", + ] + if quant_params_path: + ds_import_args.extend([f"--quant-params={quant_params_path}"]) + import_brevitas_dataset.main(ds_import_args) return import_brevitas_dataset.Dataset.load(output_path)