From 575bcd04ac7cc6892461e398e3491c5208af1179 Mon Sep 17 00:00:00 2001
From: Ean Garvey
Date: Thu, 11 Apr 2024 10:59:59 -0500
Subject: [PATCH] Cleanup comments and redundant code.

---
 core/shark_turbine/aot/builtins/jittable.py |   9 +-
 .../sd_inference/sdxl_split_schedulers.py   | 280 ------------------
 .../custom_models/sd_inference/unet.py      |  14 +-
 .../custom_models/sd_inference/utils.py     |   3 +-
 .../custom_models/sd_inference/vae.py       |  14 +-
 .../sdxl_inference/sdxl_prompt_encoder.py   |  39 ---
 .../sdxl_inference/sdxl_scheduled_unet.py   |  12 -
 .../sdxl_inference/sdxl_schedulers.py       |   5 -
 8 files changed, 18 insertions(+), 358 deletions(-)
 delete mode 100644 models/turbine_models/custom_models/sd_inference/sdxl_split_schedulers.py

diff --git a/core/shark_turbine/aot/builtins/jittable.py b/core/shark_turbine/aot/builtins/jittable.py
index 6542750e3..29a90617b 100644
--- a/core/shark_turbine/aot/builtins/jittable.py
+++ b/core/shark_turbine/aot/builtins/jittable.py
@@ -214,13 +214,6 @@ def flat_wrapped_f(*args):
         if "functorch_functionalize" in self._passes:
             transformed_f = functorch_functionalize(transformed_f, *flat_pytorch_args)

-        for node in transformed_f.graph.nodes:  # type: ignore
-            if node.op == "call_function":
-                if node.target == torch._ops.ops.aten.lift_fresh_copy.default:
-                    print(f"replaced lift_fresh_copy")
-                    node.target = torch._ops.ops.aten.clone.default
-        transformed_f.recompile()  # type: ignore
-
         # Ask dynamo to give us an aten graph.
         # TODO: Cache this for repeated calls.
         logger.debug("Performing dynamo.export(constraints=%r)", constraints)
@@ -233,7 +226,7 @@ def flat_wrapped_f(*args):
         )
         logger.debug("Invoking dynamo trace")
         gm, guards = exported_f(*flat_pytorch_args)
-        logger.debug("Dyanmo trace complete")
+        logger.debug("Dynamo trace complete")

         # TODO: Add debug logging for the exported graph module.
         # gm.print_readable()
diff --git a/models/turbine_models/custom_models/sd_inference/sdxl_split_schedulers.py b/models/turbine_models/custom_models/sd_inference/sdxl_split_schedulers.py
deleted file mode 100644
index 80ebf6dd2..000000000
--- a/models/turbine_models/custom_models/sd_inference/sdxl_split_schedulers.py
+++ /dev/null
@@ -1,280 +0,0 @@
-# Copyright 2023 Nod Labs, Inc
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-# from @aviator19941's gist : https://gist.github.com/aviator19941/4e7967bd1787c83ee389a22637c6eea7
-
-import os
-import sys
-
-from iree import runtime as ireert
-from iree.compiler.ir import Context
-import numpy as np
-from shark_turbine.aot import *
-from turbine_models.custom_models.sd_inference import utils
-import torch
-import torch._dynamo as dynamo
-from diffusers import UNet2DConditionModel
-from shark_turbine.dynamo.passes import (
-    DEFAULT_DECOMPOSITIONS,
-)
-
-import safetensors
-import argparse
-
-parser = argparse.ArgumentParser()
-parser.add_argument(
-    "--hf_auth_token",
-    type=str,
-    help="The Hugging Face auth token, required",
-    default=None,
-)
-parser.add_argument(
-    "--hf_model_name",
-    type=str,
-    help="HF model name",
-    default="stabilityai/stable-diffusion-xl-base-1.0",
-)
-parser.add_argument(
-    "--scheduler_id",
-    type=str,
-    help="Scheduler ID",
-    default="PNDM",
-)
-parser.add_argument(
-    "--num_inference_steps", type=int, default=30, help="Number of inference steps"
-)
-parser.add_argument(
-    "--batch_size", type=int, default=1, help="Batch size for inference"
-)
-parser.add_argument(
-    "--height", type=int, default=1024, help="Height of Stable Diffusion"
-)
-parser.add_argument("--width", type=int, default=1024, help="Width of Stable Diffusion")
-parser.add_argument(
-    "--precision", type=str, default="fp32", help="Precision of Stable Diffusion"
-)
-parser.add_argument(
-    "--max_length", type=int, default=77, help="Sequence Length of Stable Diffusion"
-)
-parser.add_argument(
-    "--compile_to", type=str, default="torch", help="torch, linalg, vmfb"
-)
-parser.add_argument("--external_weight_path", type=str, default="")
-parser.add_argument(
-    "--external_weights",
-    type=str,
-    default=None,
-    help="saves ir/vmfb without global weights for size and readability, options [safetensors]",
-)
-parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm")
-# TODO: Bring in detection for target triple
-parser.add_argument(
-    "--iree_target_triple",
-    type=str,
-    default="x86_64-unknown-unknown-eabi-elf",
-    help="Specify vulkan target triple or rocm/cuda target device.",
-)
-parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")
-
-
-class SDXLScheduler(torch.nn.Module):
-    def __init__(
-        self,
-        hf_model_name,
-        num_inference_steps,
-        scheduler,
-        hf_auth_token=None,
-        precision="fp32",
-    ):
-        super().__init__()
-        self.scheduler = scheduler
-        self.scheduler.set_timesteps(num_inference_steps)
-        self.guidance_scale = 7.5
-
-    def schd_add_init_noise(self, sample):
-        # print(sample, self.scheduler.init_noise_sigma)
-        sample = sample * self.scheduler.init_noise_sigma
-        return sample
-
-    def schd_scale_model_input(self, sample, t):
-        latent_model_input = torch.cat([sample] * 2)
-        t = t.unsqueeze(0)
-        # print('UNSQUEEZE T:', t)
-        latent_model_input = self.scheduler.scale_model_input(
-            latent_model_input, timestep=t
-        )
-        return latent_model_input
-
-    def schd_step(self, sample, t, noise_pred):
-        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-        noise_pred = noise_pred_uncond + self.guidance_scale * (
-            noise_pred_text - noise_pred_uncond
-        )
-        sample = self.scheduler.step(noise_pred, t, sample, return_dict=False)[0]
-        return sample
-
-
-def export_scheduler(
-    scheduler,
-    hf_model_name,
-    batch_size,
-    height,
-    width,
-    hf_auth_token=None,
-    compile_to="torch",
-    external_weights=None,
-    external_weight_path=None,
-    device=None,
-    target_triple=None,
-    max_alloc=None,
-):
-    mapper = {}
-    utils.save_external_weights(
-        mapper, scheduler, external_weights, external_weight_path
-    )
-
-    decomp_list = DEFAULT_DECOMPOSITIONS
-
-    decomp_list.extend(
-        [
-            torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
-            torch.ops.aten._scaled_dot_product_flash_attention.default,
-        ]
-    )
-    # encoder_hidden_states_sizes = (2, 77, 768)
-    # if hf_model_name == "stabilityai/stable-diffusion-2-1-base":
-    #     encoder_hidden_states_sizes = (2, 77, 1024)
-
-    # tensor shapes for tracing
-    # sample = torch.randn(1, 4, 128, 128)
-    sample = (batch_size, 4, height // 8, width // 8)
-    noise_pred = (batch_size * 2, 4, height // 8, width // 8)
-
-    class CompiledScheduler(CompiledModule):
-        if external_weights:
-            params = export_parameters(
-                scheduler, external=True, external_scope="", name_mapper=mapper.get
-            )
-        else:
-            params = export_parameters(scheduler)
-
-        def main_init_noise(
-            self,
-            sample=AbstractTensor(*sample, dtype=torch.float32),
-        ):
-            return jittable(scheduler.schd_add_init_noise)(sample)
-
-        def main_scale_model(
-            self,
-            sample=AbstractTensor(*sample, dtype=torch.float32),
-            t=AbstractTensor(1, dtype=torch.int32),
-        ):
-            return jittable(scheduler.schd_scale_model_input)(sample, t)
-
-        def main_step(
-            self,
-            noise_pred=AbstractTensor(*noise_pred, dtype=torch.float32),
-            t=AbstractTensor(1, dtype=torch.int32),
-        ):
-            return jittable(scheduler.schd_step)(noise_pred, t)
-
-    import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
-    inst = CompiledScheduler(context=Context(), import_to=import_to)
-
-    module_str = str(CompiledModule.get_mlir_module(inst))
-
-    safe_name = utils.create_safe_name(hf_model_name, "-scheduler")
-    with open(f"{safe_name}.mlir", "w+") as f:
-        f.write(module_str)
-    print("Saved to", safe_name + ".mlir")
-
-    if compile_to != "vmfb":
-        return module_str
-    else:
-        utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name)
-
-
-# hf_model_name = "stabilityai/stable-diffusion-xl-base-1.0"
-# from diffusers import (
-#     EulerDiscreteScheduler,
-# )
-# scheduler = EulerDiscreteScheduler.from_pretrained(hf_model_name, subfolder="scheduler")
-# scheduler_module = SDXLScheduler(hf_model_name, 3, scheduler, hf_auth_token=None, precision="fp32")
-# sample = torch.randn(1, 4, 128, 128)
-# prompt_embeds = torch.randn(2, 77, 2048)
-# text_embeds = torch.randn(2, 1280)
-# time_ids = torch.randn(2, 6)
-
-# sample = (1, 4, 128, 128)
-# prompt_embeds = (2, 77, 2048)
-# text_embeds = (2, 1280)
-# time_ids = (2, 6)
-# sample=AbstractTensor(*sample, dtype=torch.float32),
-# prompt_embeds=AbstractTensor(*prompt_embeds, dtype=torch.float32),
-# text_embeds = AbstractTensor(*text_embeds, dtype=torch.float32),
-# time_ids = AbstractTensor(*time_ids, dtype=torch.float32),
-
-# inputs = (sample, prompt_embeds, text_embeds, time_ids,)
-
-# print(scheduler_module.forward(sample, prompt_embeds, text_embeds, time_ids))
-
-
-# from torch.fx.experimental.proxy_tensor import make_fx
-# fx_g = make_fx(
-#     scheduler_module,
-#     decomposition_table={},
-#     tracing_mode="symbolic",
-#     _allow_non_fake_inputs=True,
-#     _allow_fake_constant=False,
-# )(*inputs)
-# print(fx_g)
-
-
-if __name__ == "__main__":
-    args = parser.parse_args()
-    hf_model_name = "stabilityai/stable-diffusion-xl-base-1.0"
-    from diffusers import (
-        EulerDiscreteScheduler,
-    )
-
-    scheduler = EulerDiscreteScheduler.from_pretrained(
-        hf_model_name, subfolder="scheduler"
-    )
-    scheduler_module = SDXLScheduler(
-        args.hf_model_name,
-        args.num_inference_steps,
-        scheduler,
-        hf_auth_token=None,
-        precision="fp32",
-    )
-
-    # sample = torch.randn((1, 4, 128, 128))
-    # # sample = (batch_size, 4, height // 8, width // 8)
-    # prompt_embeds = torch.randn((2, 77, 2048))
-    # text_embeds = torch.randn((2, 1280))
-    # time_ids = torch.randn((2, 6), dtype=torch.int32)
-    # print(scheduler_module.forward(sample, prompt_embeds, text_embeds, time_ids))
-
-    print("export scheduler begin")
-    mod_str = export_scheduler(
-        scheduler_module,
-        args.hf_model_name,
-        args.batch_size,
-        args.height,
-        args.width,
-        args.hf_auth_token,
-        args.compile_to,
-        args.external_weights,
-        args.external_weight_path,
-        args.device,
-        args.iree_target_triple,
-        args.vulkan_max_allocation,
-    )
-    print("export scheduler complete")
-    safe_name = utils.create_safe_name(args.hf_model_name, "-scheduler")
-    with open(f"{safe_name}.mlir", "w+") as f:
-        f.write(mod_str)
-    print("Saved to", safe_name + ".mlir")
diff --git a/models/turbine_models/custom_models/sd_inference/unet.py b/models/turbine_models/custom_models/sd_inference/unet.py
index 8e47ceea9..18657ae86 100644
--- a/models/turbine_models/custom_models/sd_inference/unet.py
+++ b/models/turbine_models/custom_models/sd_inference/unet.py
@@ -101,15 +101,17 @@ def export_unet_model(
     target_triple=None,
     max_alloc=None,
     upload_ir=False,
+    decomp_attn=True,
 ):
     mapper = {}
     decomp_list = DEFAULT_DECOMPOSITIONS
-    decomp_list.extend(
-        [
-            torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
-            torch.ops.aten._scaled_dot_product_flash_attention.default,
-        ]
-    )
+    if decomp_attn:
+        decomp_list.extend(
+            [
+                torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
+                torch.ops.aten._scaled_dot_product_flash_attention.default,
+            ]
+        )
     dtype = torch.float16 if precision == "fp16" else torch.float32
     unet_model = unet_model.to(dtype)
     utils.save_external_weights(
diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py
index dc6ec1c7e..2ce0ef601 100644
--- a/models/turbine_models/custom_models/sd_inference/utils.py
+++ b/models/turbine_models/custom_models/sd_inference/utils.py
@@ -8,12 +8,10 @@
     EulerDiscreteScheduler,
 )

-winograd_params = "keys=unet.down_blocks.2.resnets.0.conv2.weight keys=unet.down_blocks.2.resnets.1.conv1.weight keys=unet.down_blocks.2.resnets.1.conv2.weight keys=unet.mid_block.resnets.0.conv1.weight keys=unet.mid_block.resnets.0.conv2.weight keys=unet.mid_block.resnets.1.conv1.weight keys=unet.mid_block.resnets.1.conv2.weight keys=unet.up_blocks.0.resnets.0.conv2.weight keys=unet.up_blocks.0.resnets.1.conv2.weight keys=unet.up_blocks.0.resnets.2.conv2.weight keys=unet.up_blocks.0.resnets.0.conv1.weight keys=unet.up_blocks.0.resnets.1.conv1.weight keys=unet.up_blocks.0.resnets.2.conv1.weight keys=unet.up_blocks.0.upsamplers.0.conv.weight"
 # If flags are verified to work on a specific model and improve performance without regressing numerics, add them to this dictionary. If you are working with bleeding edge flags, please add them manually with the --ireec_flags argument.
 gfx94X_flags = {
     "all": [
         "--iree-global-opt-propagate-transposes=true",
-        "--iree-opt-const-eval=false",
         "--iree-opt-outer-dim-concat=true",
         "--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode",
         "--iree-vm-target-truncate-unsupported-floats",
@@ -95,6 +93,7 @@ def compile_to_vmfb(
                 "--iree-hal-target-backends=rocm",
                 "--iree-rocm-target-chip=" + target_triple,
                 "--verify=false",
+                "--iree-opt-const-eval=false",
             ]
         )
     elif device == "cuda":
diff --git a/models/turbine_models/custom_models/sd_inference/vae.py b/models/turbine_models/custom_models/sd_inference/vae.py
index 2c83d1b72..0916acda0 100644
--- a/models/turbine_models/custom_models/sd_inference/vae.py
+++ b/models/turbine_models/custom_models/sd_inference/vae.py
@@ -116,15 +116,17 @@ def export_vae_model(
     max_alloc=None,
     variant="decode",
     upload_ir=False,
+    decomp_attn=True,
 ):
     mapper = {}
     decomp_list = DEFAULT_DECOMPOSITIONS
-    decomp_list.extend(
-        [
-            torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
-            torch.ops.aten._scaled_dot_product_flash_attention.default,
-        ]
-    )
+    if decomp_attn:
+        decomp_list.extend(
+            [
+                torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
+                torch.ops.aten._scaled_dot_product_flash_attention.default,
+            ]
+        )
     dtype = torch.float16 if precision == "fp16" else torch.float32
     vae_model = vae_model.to(dtype)
     utils.save_external_weights(
diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py
index aca838c3d..1c6b6331c 100644
--- a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py
+++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py
@@ -39,45 +39,6 @@ def __init__(
         )
         self.do_classifier_free_guidance = do_classifier_free_guidance

-        # self.tokenizer_1 = CLIPTokenizer.from_pretrained(
-        #     hf_model_name,
-        #     subfolder="tokenizer",
-        #     token=hf_auth_token,
-        #     model_max_length=max_length,
-        # )
-        # self.tokenizer_2 = CLIPTokenizer.from_pretrained(
-        #     hf_model_name,
-        #     subfolder="tokenizer_2",
-        #     token=hf_auth_token,
-        #     model_max_length=max_length,
-        # )
-    # def tokenize(self, prompt, negative_prompt):
-    #     text_input_ids_1 = self.tokenizer_1(
-    #         prompt,
-    #         padding="max_length",
-    #         truncation=True,
-    #         return_tensors="pt",
-    #     ).input_ids
-    #     uncond_input_ids_1 = self.tokenizer_2(
-    #         negative_prompt,
-    #         padding="max_length",
-    #         truncation=True,
-    #         return_tensors="pt",
-    #     ).input_ids
-    #     text_input_ids_2 = self.tokenizer_2(
-    #         prompt,
-    #         padding="max_length",
-    #         truncation=True,
-    #         return_tensors="pt",
-    #     ).input_ids
-    #     uncond_input_ids_2 = self.tokenizer_2(
-    #         negative_prompt,
-    #         padding="max_length",
-    #         truncation=True,
-    #         return_tensors="pt",
-    #     ).input_ids
-    #     return text_input_ids_1, uncond_input_ids_1, text_input_ids_2, uncond_input_ids_2
-
     def forward(
         self, text_input_ids_1, text_input_ids_2, uncond_input_ids_1, uncond_input_ids_2
     ):
diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py
index d00ca1c35..f74c707e7 100644
--- a/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py
+++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py
@@ -276,18 +276,6 @@ def export_pipeline_module(args):
     full_pipeline_file = (
         pipe_prefix + "f32" if args.precision == "fp32" else pipe_prefix + "f16"
     )
-    # pipeline_vmfb_path = utils.compile_to_vmfb(
-    #     os.path.join(
-    #         os.path.realpath(os.path.dirname(__file__)), pipeline_file + ".mlir"
-    #     ),
-    #     args.device,
-    #     args.iree_target_triple,
-    #     args.ireec_flags,
-    #     "sdxl_pipeline_" + args.precision + "_" + args.iree_target_triple,
-    #     return_path=True,
-    #     const_expr_hoisting=False,
-    #     mlir_source="file",
-    # )
     full_pipeline_vmfb_path = utils.compile_to_vmfb(
         os.path.join(
             os.path.realpath(os.path.dirname(__file__)), full_pipeline_file + ".mlir"
diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_schedulers.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_schedulers.py
index 568d616b2..a3ae29595 100644
--- a/models/turbine_models/custom_models/sdxl_inference/sdxl_schedulers.py
+++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_schedulers.py
@@ -120,12 +120,7 @@ def export_scheduler(
             torch.ops.aten._scaled_dot_product_flash_attention.default,
         ]
     )
-    # encoder_hidden_states_sizes = (2, 77, 768)
-    # if hf_model_name == "stabilityai/stable-diffusion-2-1-base":
-    #     encoder_hidden_states_sizes = (2, 77, 1024)
-
     # tensor shapes for tracing
-    # sample = torch.randn(1, 4, 128, 128)
     sample = (batch_size, 4, height // 8, width // 8)
     prompt_embeds = (2, 77, 2048)
     text_embeds = (2, 1280)