From 7388e14fe838c81616ca45585adae46e3d02dd7a Mon Sep 17 00:00:00 2001
From: Ean Garvey
Date: Thu, 20 Jun 2024 16:22:34 -0500
Subject: [PATCH] SDXL: fix scheduled unet modes

---
 .../custom_models/sd_inference/schedulers.py  |   1 -
 .../custom_models/sd_inference/utils.py       |  20 +++-
 .../sdxl_inference/sdxl_cmd_opts.py           |   2 +-
 .../sdxl_inference/sdxl_compiled_pipeline.py  | 112 ++++++++++--------
 .../sdxl_inference/sdxl_scheduled_unet.py     |  12 +-
 5 files changed, 91 insertions(+), 56 deletions(-)

diff --git a/models/turbine_models/custom_models/sd_inference/schedulers.py b/models/turbine_models/custom_models/sd_inference/schedulers.py
index 7b2248152..2c8d618c6 100644
--- a/models/turbine_models/custom_models/sd_inference/schedulers.py
+++ b/models/turbine_models/custom_models/sd_inference/schedulers.py
@@ -223,7 +223,6 @@ def export_scheduler_model(
         f"{height}x{width}",
         precision,
         str(num_inference_steps),
-        target_triple,
     ]
     vmfb_name = "_".join(vmfb_names)
     safe_name = utils.create_safe_name(hf_model_name, "_" + vmfb_name)
diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py
index 84700185c..8822d0144 100644
--- a/models/turbine_models/custom_models/sd_inference/utils.py
+++ b/models/turbine_models/custom_models/sd_inference/utils.py
@@ -240,6 +240,16 @@
             flags.pop(idx)
     print("Compiling to", device, "with flags:", flags)
 
+    # Enforce a consistent file naming convention:
+    # the .vmfb filename always includes the target triple,
+    # while the .mlir filename never does.
+    if target_triple not in safe_name:
+        safe_vmfb_name = safe_name + "_" + target_triple
+        safe_mlir_name = safe_name
+    else:
+        safe_vmfb_name = safe_name
+        safe_mlir_name = "".join(safe_name.split(target_triple))
+
     if mlir_source == "file":
         flatbuffer_blob = ireec.compile_file(
             module_str,
@@ -249,9 +259,9 @@
     elif mlir_source == "str":
         if save_mlir:
-            with open(f"{safe_name}.mlir", "w+") as f:
+            with open(f"{safe_mlir_name}.mlir", "w+") as f:
                 f.write(module_str)
-            print("Saved to", safe_name + ".mlir")
+            print("Saved to", safe_mlir_name + ".mlir")
         flatbuffer_blob = ireec.compile_str(
             module_str,
             target_backends=[device],
@@ -260,11 +270,11 @@
     else:
         raise ValueError("mlir_source must be either 'file' or 'str'")
-    with open(f"{safe_name}.vmfb", "wb+") as f:
+    with open(f"{safe_vmfb_name}.vmfb", "wb+") as f:
         f.write(flatbuffer_blob)
-    print("Saved to", safe_name + ".vmfb")
+    print(f"Saved to {safe_vmfb_name}.vmfb")
     if return_path == True:
-        return safe_name + ".vmfb"
+        return safe_vmfb_name + ".vmfb"
 
 
 def create_safe_name(hf_model_name, model_name_str):
diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py
index 5d5bde32f..c1c21301b 100644
--- a/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py
+++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py
@@ -125,7 +125,7 @@ def is_valid_file(arg):
 
 p.add_argument(
     "--split_scheduler",
-    default=True,
+    default=False,
     action="store_true",
     help="Use a decoupled unet and scheduler for better QOL.",
 )
diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py
index 71e5730b4..8d6ce4ed1 100644
--- a/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py
+++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py
@@ -119,6 +119,7 @@ def __init__(
         self.custom_vae = custom_vae
         self.cpu_scheduling = cpu_scheduling
         self.compiled_pipeline = False
+        self.split_scheduler = False
         # TODO: set this based on user-inputted guidance scale and negative prompt.
         self.do_classifier_free_guidance = True  # False if any(x in hf_model_name for x in ["turbo", "lightning"]) else True
         self._interrupt = False
@@ -166,7 +167,7 @@ def check_prepared(
                 print("There was an error generating the necessary files.")
                 exit()
         else:
-            print("All necessary files found. Loading pipeline.")
+            print("All necessary files found.")
         return vmfbs, weights
 
     def is_prepared(self, vmfbs, weights):
@@ -175,10 +176,11 @@ def is_prepared(self, vmfbs, weights):
         for key in vmfbs:
             if key == "scheduled_unet":
                 keywords = [
-                    "unet",
+                    "DiffusionModule",
                     self.scheduler_id,
-                    self.num_inference_steps,
+                    str(self.num_inference_steps),
                     self.precision,
+                    self.max_length,
                     dims,
                 ]
                 device_key = "unet"
@@ -192,38 +194,45 @@ def is_prepared(self, vmfbs, weights):
                 device_key = "clip"
             else:
                 keywords = [key, self.precision, self.max_length, dims]
-                device_key = key
-            avail_files = os.listdir(self.pipeline_dir)
-            keywords.append("vmfb")
-            keywords.append(
-                utils.create_safe_name(self.hf_model_name.split("/")[-1], "")
+                device_key = "unet"
+            keywords.extend(
+                [
+                    utils.create_safe_name(self.hf_model_name.split("/")[-1], ""),
+                    "vmfb",
+                    "bs" + str(self.batch_size),
+                    self.devices[device_key]["target"],
+                ]
             )
-            keywords.append(self.devices[device_key]["target"])
+            avail_files = os.listdir(self.pipeline_dir)
             for filename in avail_files:
                 if all(str(x) in filename for x in keywords):
                     vmfbs[key] = os.path.join(self.pipeline_dir, filename)
             if not vmfbs[key]:
                 missing.append(key + " vmfb")
+
         for w_key in weights:
-            if any(x in w_key for x in ["pipeline", "scheduler"]):
-                continue
-            if weights[w_key] is not None:
-                continue
-            if self.external_weights is None:
-                continue
-            default_name = os.path.join(
-                self.external_weights_dir, w_key + "." + self.external_weights
-            )
-            if weights[w_key] is None and os.path.exists(default_name):
-                weights[w_key] = os.path.join(default_name)
-            elif w_key in ["scheduled_unet"] and os.path.exists(
-                os.path.join(self.external_weights_dir, "unet." + self.external_weights)
+            if any(x in w_key for x in ["pipeline", "scheduler"]) or (
+                self.external_weights is None
             ):
-                weights[w_key] = os.path.join(
-                    self.external_weights_dir, "unet." + self.external_weights
+                continue
+            elif weights[w_key] is not None:
+                print("Weights already found for", w_key, "at:", weights[w_key])
+                continue
+            elif w_key == "vae_decode":
+                keywords = ["vae", self.vae_precision]
+            elif w_key in ["prompt_encoder", "clip"]:
+                keywords = ["prompt_encoder", self.precision]
+            elif w_key in ["scheduled_unet", "unet"]:
+                keywords = ["unet", self.precision]
+            avail_weights = os.listdir(self.external_weights_dir)
+            for filename in avail_weights:
+                if all(str(x) in filename for x in keywords):
+                    weights[w_key] = os.path.join(self.external_weights_dir, filename)
+            if not weights[w_key]:
+                missing.append(
+                    " ".join([keywords[0], keywords[1], self.external_weights])
                 )
-            else:
-                missing.append(w_key + "." + self.external_weights)
+
         if len(missing) > 0:
             print(f"Missing files: " + ", ".join(missing))
             return False, vmfbs, weights
@@ -476,12 +484,20 @@ def export_submodel(
                     self.max_length,
                     "unet_loop",
                 )
+                pipeline_keys = [
+                    utils.create_safe_name(self.hf_model_name.split("/")[-1], ""),
+                    "bs" + str(self.batch_size),
+                    f"{str(self.width)}x{str(self.height)}",
+                    self.precision,
+                    str(self.max_length),
+                    "pipeline",
+                ]
                 pipeline_vmfb = utils.compile_to_vmfb(
                     pipeline_file,
                     self.devices["unet"]["device"],
                     self.devices["unet"]["target"],
                     self.ireec_flags["pipeline"],
-                    os.path.join(self.pipeline_dir, "pipeline"),
+                    os.path.join(self.pipeline_dir, "_".join(pipeline_keys)),
                     return_path=True,
                     mlir_source="str",
                 )
@@ -495,12 +511,20 @@ def export_submodel(
                     self.max_length,
                     "tokens_to_image",
                 )
+                pipeline_keys = [
+                    utils.create_safe_name(self.hf_model_name.split("/")[-1], ""),
+                    "bs" + str(self.batch_size),
+                    f"{str(self.width)}x{str(self.height)}",
+                    self.precision,
+                    str(self.max_length),
+                    "full_pipeline",
+                ]
                 pipeline_vmfb = utils.compile_to_vmfb(
                     pipeline_file,
                     self.devices["unet"]["device"],
                     self.devices["unet"]["target"],
                     self.ireec_flags["pipeline"],
-                    os.path.join(self.pipeline_dir, "full_pipeline"),
+                    os.path.join(self.pipeline_dir, "_".join(pipeline_keys)),
                     return_path=True,
                     mlir_source="str",
                 )
@@ -631,9 +655,13 @@ def generate_images(
         progress=None,
     ):
         needs_new_scheduler = (
-            steps and steps != self.num_inference_steps
-        ) or cpu_scheduling != self.cpu_scheduling
+            (steps and steps != self.num_inference_steps)
+            or (cpu_scheduling != self.cpu_scheduling)
+            and self.split_scheduler
+        )
+        self.cpu_scheduling = cpu_scheduling
+
         if steps and not self.compiled_pipeline and needs_new_scheduler:
             self.num_inference_steps = steps
             if (
@@ -953,13 +981,10 @@ def numpy_to_pil_image(images):
     map = empty_pipe_dict
     if args.split_scheduler:
-        map["scheduler"] = None
         map["unet"] = None
         map.pop("scheduled_unet")
         map.pop("pipeline")
         map.pop("full_pipeline")
-        if args.cpu_scheduling:
-            map.pop("scheduler")
     mlirs = copy.deepcopy(map)
     vmfbs = copy.deepcopy(map)
     weights = copy.deepcopy(map)
@@ -1002,20 +1027,12 @@ def numpy_to_pil_image(images):
         "scheduler": args.ireec_flags,
     }
     if not args.pipeline_dir:
-        pipe_id_list = [
-            args.hf_model_name.split("/")[-1],
-            str(args.height),
-            str(args.width),
-            str(args.max_length),
-            args.precision,
-            args.device,
-        ]
-        if args.decomp_attn:
-            pipe_id_list.append("decomp")
         args.pipeline_dir = os.path.join(
             ".",
-            "_".join(pipe_id_list),
+            utils.create_safe_name(args.hf_model_name, ""),
         )
+    if not os.path.exists(args.pipeline_dir):
+        os.makedirs(args.pipeline_dir, exist_ok=True)
     if args.input_mlir:
         user_mlir_list = args.input_mlir.split(",")
     else:
@@ -1027,14 +1044,15 @@ def numpy_to_pil_image(images):
         args.external_weights_dir = args.pipeline_dir
     sdxl_pipe = SharkSDXLPipeline(
         args.hf_model_name,
-        args.scheduler_id,
         args.height,
         args.width,
         args.precision,
         args.max_length,
         args.batch_size,
+        args.num_inference_steps,
         devices,
         targets,
+        args.scheduler_id,
         ireec_flags,
         args.attn_spec,
         args.decomp_attn,
@@ -1045,9 +1063,9 @@ def numpy_to_pil_image(images):
         custom_vae=None,
         vae_precision=args.vae_precision,
     )
+
     vmfbs, weights = sdxl_pipe.check_prepared(mlirs, vmfbs, weights)
-    if args.cpu_scheduling:
-        vmfbs["scheduler"] = None
+
     if args.npu_delegate_path:
         extra_device_args = {"npu_delegate_path": args.npu_delegate_path}
     else:
diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py
index b8bffe768..fd9adaa8f 100644
--- a/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py
+++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py
@@ -171,10 +171,18 @@ def export_scheduled_unet_model(
     # else:
     #     do_classifier_free_guidance = True
     do_classifier_free_guidance = True
-
+    filename_keys = [
+        f"bs{batch_size}",
+        str(max_length),
+        f"{height}x{width}",
+        precision,
+        scheduler_id,
+        "DiffusionModule",
+        str(num_inference_steps),
+    ]
     safe_name = utils.create_safe_name(
         hf_model_name,
-        f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_scheduled_unet_{str(num_inference_steps)}",
+        "_".join(filename_keys),
     )
     if pipeline_dir:
         safe_name = os.path.join(pipeline_dir, safe_name)
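
Note on the utils.compile_to_vmfb change above: it standardizes artifact names so that the compiled .vmfb always carries the target triple while the saved .mlir never does. A minimal standalone sketch of the rule follows; the helper name and example values are illustrative, not part of the patch:

    def artifact_names(safe_name, target_triple):
        # .vmfb filenames always end with the target triple;
        # .mlir filenames never include it.
        if target_triple not in safe_name:
            vmfb = safe_name + "_" + target_triple
            mlir = safe_name
        else:
            vmfb = safe_name
            mlir = "".join(safe_name.split(target_triple))
        return vmfb + ".vmfb", mlir + ".mlir"

    # artifact_names("sdxl_unet_fp16", "gfx942")
    # -> ("sdxl_unet_fp16_gfx942.vmfb", "sdxl_unet_fp16.mlir")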
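Note on the reworked is_prepared: instead of reconstructing exact filenames, it now scans the artifact directory and accepts any file whose name contains every expected keyword (model name, precision, "bs<N>", target triple, "vmfb", and so on). A minimal sketch of that substring matching, with an illustrative helper name:

    import os

    def find_artifact(search_dir, keywords):
        # Return the first file whose name contains every keyword, else None.
        for filename in os.listdir(search_dir):
            if all(str(x) in filename for x in keywords):
                return os.path.join(search_dir, filename)
        return None

    # e.g. find_artifact(".", ["DiffusionModule", "fp16", "bs1", "gfx942", "vmfb"])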