diff --git a/.gitignore b/.gitignore index f5fe49941..54f4c40cc 100644 --- a/.gitignore +++ b/.gitignore @@ -28,4 +28,7 @@ wheelhouse *.safetensors *.gguf *.vmfb -*.mlir \ No newline at end of file +*.mlir +*.npy +*.png +*tmp* diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py b/models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py new file mode 100644 index 000000000..535135daa --- /dev/null +++ b/models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py @@ -0,0 +1,345 @@ +import argparse +import os +from pathlib import Path + + +def path_expand(s): + return Path(s).expanduser().resolve() + + +def is_valid_file(arg): + if not os.path.exists(arg): + return None + else: + return arg + + +# Note: this is where command-line options for the scripts in this directory +# are defined along with their defaults. Thus, they should not be referenced +# within modelling or inference code, only at the entry point to the script. + +# We should consider separating out the options that are "model configs" from +# the options that control the compiler, runtime, and script behavior, +# when applicable, as the formermost would best be kept in a separate +# config or imported from huggingface. + +p = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter +) + +############################################################################## +# SD3 Source Options +############################################################################## + +p.add_argument( + "--hf_auth_token", + type=str, + help="The Hugging Face auth token, if required", + default=None, +) +p.add_argument( + "--hf_model_name", + type=str, + help="HF model name", + default="stabilityai/stable-diffusion-3-medium-diffusers", +) +p.add_argument( + "--scheduler_id", + type=str, + help="Scheduler ID", + default="EulerDiscrete", +) +p.add_argument( + "--model_path", + type=str, + help="Path to model .safetensors from which the model is defined.", + default=None, +) +p.add_argument( + "--vae_model_path", + type=str, + help="Path to vae model .safetensors from which the model is defined.", + default=None, +) + +############################################################################## +# SD3 Inference Options +# These options are used to control runtime parameters for SD3 inference. +############################################################################## + +p.add_argument( + "--prompt", + type=str, + default=" a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + help="Prompt input to stable diffusion.", +) + +p.add_argument( + "--negative_prompt", + type=str, + default="Watermark, blurry, oversaturated, low resolution, pollution", + help="Negative prompt input to stable diffusion.", +) + +p.add_argument( + "--num_inference_steps", type=int, default=30, help="Number of UNet inference steps" +) + +p.add_argument( + "--batch_count", + type=int, + default=1, + help="Number of batches to run for a single prompt", +) + +p.add_argument( + "--guidance_scale", + type=float, + default=7.5, + help="Scale by which to adjust prompt guidance to the unconditional noise prediction output of UNet after each iteration.", +) + +p.add_argument( + "--seed", type=float, default=0, help="Seed for random number/latents generation." +) + +p.add_argument( + "--denoise", + type=float, + default=1.0, + help="Denoising factor for image to image", +) + +p.add_argument( + "--external_weight_path", + type=str, + default="", + help="Path to external weights file, for jobs with one weights filepath. When importing, this is used to specify where to save the model weights, and at runtime, this is used to specify where to load the model weights from.", +) + +p.add_argument( + "--external_weights_dir", + type=str, + default="", + help="Directory containing external weights for a job that requires more than one weights file. When importing, this is used to specify where to save the model weights, and at runtime, this is used to specify where to load the model weights from. Files will then be saved according to the parameters that make them unique, i.e. ___.", +) + +p.add_argument( + "--vmfb_path", type=str, default="", help="path to vmfb containing compiled module" +) + +p.add_argument( + "--pipeline_vmfb_path", + type=str, + default="", + help="path to vmfb containing compiled meta-module", +) + +p.add_argument( + "--scheduler_vmfb_path", + type=str, + default="", + help="path to vmfb containing compiled scheduler", +) + +p.add_argument( + "--split_scheduler", + default=False, + action="store_true", + help="Use a decoupled unet and scheduler for better QOL.", +) + +p.add_argument( + "--cpu_scheduling", + default=False, + action="store_true", + help="Run scheduling on torch cpu (will be slower due to data movement costs).", +) + +p.add_argument( + "--external_weight_file", + type=str, + default=None, + help="Path to external weights, used in benchmark scripts.", +) + +p.add_argument( + "--pipeline_dir", + type=str, + default=None, + help="Directory to save pipeline artifacts", +) + +p.add_argument( + "--compiled_pipeline", + default=False, + action="store_true", + help="Do one-shot inference from tokens to image in a shrink-wrapped pipeline binary.", +) + +############################################################################## +# SD3 Modelling Options +# These options are used to control model defining parameters for SD3. +# These are MLIR - changing variables! If you change them, you will need +# to import/download and recompile the model. +############################################################################## + +p.add_argument("--batch_size", type=int, default=1, help="Batch size for inference") +p.add_argument( + "--height", type=int, default=1024, help="Height of Stable Diffusion output image." +) +p.add_argument( + "--width", type=int, default=1024, help="Width of Stable Diffusion output image" +) +p.add_argument( + "--precision", + type=str, + default="fp16", + help="Precision of Stable Diffusion weights and graph.", +) +p.add_argument( + "--max_length", type=int, default=77, help="Sequence Length of Stable Diffusion" +) +p.add_argument("--vae_variant", type=str, default="decode", help="encode, decode") +p.add_argument( + "--shift", type=float, default=3, help="Sampling shift value for sd3 scheduling" +) +p.add_argument( + "--vae_decomp_attn", + type=bool, + default=False, + help="Decompose attention for VAE decode only at fx graph level", +) + +############################################################################## +# SD3 script general options. +############################################################################## + +p.add_argument("--compile_to", type=str, default="mlir", help="torch, linalg, vmfb") +p.add_argument( + "--init_image", + type=str, + default=None, + help="Path to initial image for inference", +) +p.add_argument( + "--external_weights", + type=str, + default=None, + choices=["safetensors", "irpa", "gguf", None], + help="Externalizes model weights from the torch dialect IR and its successors", +) + +# See --external_weight_path and --external_weight_dir to specify where to save the model weights. +p.add_argument( + "--weights_only", + action="store_true", + help="Just grab the weights for your model and exit instead of exporting any IR.", +) +p.add_argument( + "--compare_vs_torch", + action="store_true", + help="Runs both turbine vmfb and a torch model to compare results", +) +p.add_argument( + "--decomp_attn", + default=False, + action="store_true", + help="Decompose attention at fx graph level", +) +p.add_argument( + "--exit_on_vmfb", + default=True, + action="store_false", + help="Exit program on vmfb compilation completion. Most scripts will also save .mlir if this is disabled.", +) +p.add_argument( + "--input_mlir", + type=str, + default=None, + help="Path to input mlir file to compile. Comma-separate paths to provide more than one input to pipelines.", +) +p.add_argument( + "--download_mlir", + default=False, + action="store_true", + help="Download missing mlir files from Azure storage.", +) +p.add_argument( + "--container_name", + type=str, + default=None, + help="Azure storage container name to download mlir files from.", +) +p.add_argument( + "--export", + type=str, + default="all", + help="clip, mmdit, vae, all") +p.add_argument( + "--output", + type=str, + default="SD3_output.png", + help="Path to output file for generated images.", +) + +############################################################################## +# IREE Compiler Options +############################################################################## + +p.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm") + +p.add_argument( + "--rt_device", + type=str, + default="local-task", + help="local-task, local-sync, vulkan://0, rocm://0, cuda://0, etc.", +) + +# TODO: Bring in detection for target triple +p.add_argument( + "--iree_target_triple", + type=str, + default="", + help="Specify vulkan target triple or rocm/cuda target device.", +) + +p.add_argument("--ireec_flags", type=str, default="", help="extra iree-compile options") + +p.add_argument( + "--attn_flags", + type=str, + default="", + help="extra iree-compile options for models with iree_linalg_ext.attention ops.", +) + +p.add_argument( + "--attn_spec", + type=str, + default=None, + help="extra iree-compile options for models with iree_linalg_ext.attention ops. Set this to 'default' if you are using mfma-capable hardware with ROCM.", +) + +p.add_argument( + "--clip_flags", + type=str, + default="", + help="extra iree-compile options to send for compiling CLIP/prompt_encoder. Only use this for testing bleeding edge flags! Any default options should be added to sd_inference/utils.py", +) + +p.add_argument( + "--vae_flags", + type=str, + default="", + help="extra iree-compile options to send for compiling VAE. Only use this for testing bleeding edge flags! Any default options should be added to sd_inference/utils.py", +) + +p.add_argument( + "--unet_flags", + type=str, + default="", + help="extra iree-compile options to send for compiling unet. Only use this for testing bleeding edge flags! Any default options should be added to sd_inference/utils.py", +) + + +args, unknown = p.parse_known_args() diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_full.py b/models/turbine_models/custom_models/sd3_inference/sd3_full.py new file mode 100644 index 000000000..f1335b4de --- /dev/null +++ b/models/turbine_models/custom_models/sd3_inference/sd3_full.py @@ -0,0 +1,266 @@ +# Copyrigh 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import sys + +from iree import runtime as ireert +from iree.compiler.ir import Context +import numpy as np +from shark_turbine.aot import * +from turbine_models.custom_models.sd_inference import utils +import torch +import torch._dynamo as dynamo + +import safetensors +import argparse +from turbine_models.turbine_tank import turbine_tank +SEED = 1 + +def export_vae(model, + height, + width, + compile_to="torch", + external_weight_prefix=None, + device=None, + target_triple=None, + max_alloc="", + upload_ir=False, + dtype=torch.float32, + ): + + mapper = {} + utils.save_external_weights( + mapper, model, "safetensors", external_weight_prefix + ) + latent_shape = [1, 16, height//8, width//8] + input_arg = torch.empty(latent_shape) + input_arg = (input_arg.to(dtype),) + if external_weight_prefix != None and len(external_weight_prefix)>1: + externalize_module_parameters(model) + + exported = export(model, args=input_arg) + + module_str = str(exported.mlir_module) + safe_name = utils.create_safe_name(str(dtype).lstrip("torch."), "_mmdit") + if compile_to != "vmfb": + return module_str + else: + print("compiling to vmfb") + utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name) + return module_str + +def export_unet_dynamic( + unet_model, + height, + width, + compile_to="torch", + external_weight_path=None, + device=None, + target_triple=None, + max_alloc="", + upload_ir=False, + dtype=torch.float32, +): + + cond_shape = [1,154, 4096]#77, 4096] + pool_shape = [1, 2048] + latent_shape = [1, 16, height//8, width//8] + if dtype == torch.float16: + unet_model=unet_model.half() + mapper = {} + utils.save_external_weights( + mapper, unet_model, "safetensors", external_weight_path + ) + + if weights_only: + return external_weight_path + + fxb = FxProgramsBuilder(unet_model) + + sigmas = torch.export.Dim("sigmas") + dynamic_shapes = {"sigmas": {0: sigmas}, "latent": {}, "noise": {}} + example_init_args = [ + torch.empty([19], dtype=dtype), + torch.empty(latent_shape, dtype=dtype), + torch.empty(latent_shape, dtype=dtype), + ] + example_sampling_args = [ + torch.empty(latent_shape, dtype=dtype), + torch.empty(1, dtype=dtype), + torch.empty(1, dtype=dtype), + torch.empty(cond_shape, dtype=dtype), + torch.empty(pool_shape, dtype=dtype), + torch.empty(cond_shape, dtype=dtype), + torch.empty(pool_shape, dtype=dtype), + torch.empty(1, dtype=dtype), + ] + @fxb.export_program( + args=(example_init_args,), + dynamic_shapes=dynamic_shapes + ) + def _initialize(module, inputs): + # 1.0 is denoise currently symfloat not supported in fx_importer + return module.init_dynamic(*inputs) + + @fxb.export_program(args=(example_sampling_args,)) + def _do_sampling(module, inputs): + return module.do_sampling(*inputs) + + class CompiledTresleches(CompiledModule): + initialize = _initialize + do_sampling = _do_sampling + # _vae_decode = vae_decode + + if external_weights: + externalize_module_parameters(unet_model) + save_module_parameters(external_weight_path, unet_model) + + inst = CompiledTresleches(context=Context(), import_to="IMPORT") + module_str = str(CompiledModule.get_mlir_module(inst)) + print("exported model") + + safe_name = utils.create_safe_name(str(dtype).lstrip("torch."), "_mmdit") + if compile_to != "vmfb": + return module_str + else: + print("compiling to vmfb") + utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name) + return module_str + +def export_preprocessor( + model, + compile_to="torch", + external_weight_path=None, + device=None, + target_triple=None, + max_alloc="", + dtype=torch.float32, + height=512, + width=512, + ): + external_weights="safetensors" + def get_noise(): + latent = torch.ones(1, 16, height // 8, width // 8, device="cpu") * 0.0609 + generator = torch.manual_seed(SEED) + return torch.randn(latent.size(), dtype=latent.dtype, layout=latent.layout, generator=generator, device="cpu") + + input_args = [torch.empty([1,77,2], dtype=torch.int64) for x in range(6)] + input_args += get_noise() + if dtype==torch.float16: + model = model.half() + + mapper = {} + + utils.save_external_weights( + mapper, model, external_weights, external_weight_path + ) + + if external_weight_path !=None and len(external_weight_path)>1: + print("externalizing weights") + externalize_module_parameters(model) + + exported = export(model, args=tuple(input_args)) + print("exported model") + + #import_to = "INPUT" if compile_to == "linalg" else "IMPORT" + #inst = CompiledTresleches(context=Context(), import_to=import_to) + + #module_str = str(CompiledModule.get_mlir_module(inst)) + module_str = str(exported.mlir_module) + safe_name = utils.create_safe_name("sd3", "clips") + if compile_to != "vmfb": + return module_str + else: + print("compiling to vmfb") + utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name) + return module_str + +@torch.no_grad() +def main(args): + import turbine_sd3 + from safetensors import safe_open + vulkan_max_allocation="4294967296" if args.device=="vulkan" else "" + #st_file = "/mnt2/tresleches/models/sd3_8b_beta.safetensors" + st_file = "/mnt2/tresleches/models/sd3_2b_512_alpha.safetensors" + dtype = torch.float32 + if args.precision == "f16": + dtype = torch.float16 + elif args.precision == "bf16": + dtype = torch.bfloat16 + print(args.export) + + + if args.export in ["dynamic"]: + print("exporting dynamic") + unet_model = turbine_sd3.SD3Inferencer(model=st_file, vae=turbine_sd3.VAEFile, shift=1.0, dtype=dtype).eval() + mod_str = export_unet_dynamic( + unet_model=unet_model, + height=args.height, + width=args.width, + compile_to=args.compile_to, + external_weight_path=args.external_weight_path, + device=args.device, + target_triple=args.iree_target_triple, + max_alloc=vulkan_max_allocation, + upload_ir=False, + dtype=dtype, + ) + safe_name = utils.create_safe_name("hc_sd3", "-unet") + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") + export_pre = args.export in ["all", "clip"] + print(export_pre) + if export_pre: + print("exporting preprocessor") + pre = turbine_sd3.Preprocess() + mod_str = export_preprocessor( + model=pre, + compile_to=args.compile_to, + external_weight_path=args.external_weight_path, + device=args.device, + target_triple=args.iree_target_triple, + max_alloc=vulkan_max_allocation, + dtype=dtype, + height=args.height, + width=args.width, + ) + safe_name = utils.create_safe_name("hc_sd3", "_preprocess") + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") + should_export_vae = args.export in ["all", "vae"] + if should_export_vae: + print("exporting vae") + from turbine_impls import SDVAE + with turbine_sd3.safe_open(turbine_sd3.VAEFile, framework="pt", device="cpu") as f: + vae = SDVAE(device="cpu", dtype=dtype).eval().cpu() + prefix = "" + if any(k.startswith("first_stage_model.") for k in f.keys()): + prefix = "first_stage_model." + turbine_sd3.load_into(f, vae, prefix, "cpu", dtype) + print("Something") + mod_str = export_vae( + model=vae, + height=args.height, + width=args.width, + compile_to=args.compile_to, + external_weight_prefix=args.external_weight_path, + device=args.device, + target_triple=args.iree_target_triple, + max_alloc=vulkan_max_allocation, + dtype=dtype, + ) + safe_name = utils.create_safe_name("hc_sd3", "_vae") + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") + +if __name__ == "__main__": + from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args + torch._dynamo.config.capture_scalar_outputs = True + main(args) diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_log.txt b/models/turbine_models/custom_models/sd3_inference/sd3_log.txt new file mode 100644 index 000000000..292bf6be8 Binary files /dev/null and b/models/turbine_models/custom_models/sd3_inference/sd3_log.txt differ diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py new file mode 100644 index 000000000..fe3ae2b4e --- /dev/null +++ b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py @@ -0,0 +1,115 @@ +import argparse +from turbine_models.model_runner import vmfbRunner +from turbine_models.custom_models.sd_inference import utils, schedulers +from iree import runtime as ireert +import torch +import numpy as np +from tqdm.auto import tqdm +from shark_turbine.ops.iree import trace_tensor + +torch.random.manual_seed(0) + + +def run_mmdit_turbine( + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep, + lora_scale, + args, +): + torch_dtype = torch.float16 if args.precision == "fp16" else torch.float32 + mmdit_runner = vmfbRunner( + args.device, + args.vmfb_path, + args.external_weight_path, + ) + iree_inputs = [ + ireert.asdevicearray(mmdit_runner.config.device, hidden_states), + ireert.asdevicearray(mmdit_runner.config.device, encoder_hidden_states), + ireert.asdevicearray(mmdit_runner.config.device, pooled_projections), + ireert.asdevicearray(mmdit_runner.config.device, timestep), + ireert.asdevicearray(mmdit_runner.config.device, lora_scale), + ] + noise_pred = mmdit_runner.ctx.modules.compiled_mmdit["run_forward"](*iree_inputs).to_host() + return noise_pred + + +@torch.no_grad() +def run_diffusers_mmdit( + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep, + lora_scale, + args, +): + from turbine_models.custom_models.sd3_inference.turbine_mmdit import MMDiTModel + mmdit_model = MMDiTModel( + args.hf_model_name, + dtype=torch.float32, + ) + noise_pred = mmdit_model.forward( + hidden_states.float(), encoder_hidden_states.float(), pooled_projections.float(), timestep.float(), lora_scale.float() + ) + + return noise_pred.numpy() + + +if __name__ == "__main__": + from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args + import numpy as np + import os + + torch.random.manual_seed(0) + + if args.precision == "fp16": + dtype = torch.float16 + else: + dtype = torch.float32 + + hidden_states = torch.randn( + (args.batch_size, 16, args.height // 8, args.width // 8), dtype=dtype + ) + encoder_hidden_states = torch.randn( + (args.batch_size, args.max_length, 4096), dtype=dtype + ) + pooled_projections = torch.randn((args.batch_size, 2048), dtype=dtype) + timestep = torch.tensor([0], dtype=dtype) + lora_scale = torch.tensor([1.0], dtype=dtype) + + turbine_output = run_mmdit_turbine( + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep, + lora_scale, + args, + ) + print( + "TURBINE SPLIT OUTPUT:", + turbine_output, + turbine_output.shape, + turbine_output.dtype, + ) + turbine_output = turbine_output + + if args.compare_vs_torch: + print("generating torch output: ") + torch_output = run_diffusers_mmdit( + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep, + lora_scale, + args, + ) + print("torch OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) + + print("\n(torch (comfy) image latents to iree image latents): ") + + np.testing.assert_allclose( + turbine_output, torch_output, rtol=4e-2, atol=4e-2 + ) + print("passed!") + diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py b/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py new file mode 100644 index 000000000..87492a701 --- /dev/null +++ b/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py @@ -0,0 +1,322 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +from typing import List + +import torch +from shark_turbine.aot import * +import shark_turbine.ops.iree as ops +from iree.compiler.ir import Context +import iree.runtime as ireert +import numpy as np + +from diffusers import ( + FlowMatchEulerDiscreteScheduler, +) + +from turbine_models.turbine_tank import turbine_tank +from turbine_models.custom_models.sd_inference import utils +from turbine_models.model_runner import vmfbRunner + + +class SharkSchedulerWrapper: + def __init__(self, rt_device, vmfb): + self.runner = vmfbRunner(rt_device, vmfb, None) + + def initialize(self, sample): + sample, time_ids, steps, timesteps = self.runner.ctx.modules.compiled_scheduler[ + "run_init" + ](sample) + return sample, steps.to_host(), timesteps + + def prepare_model_input(self, sample, t, timesteps): + return self.runner.ctx.modules.compiled_scheduler["run_prep"]( + sample, t, timesteps + ) + + def step(self, noise_pred, t, sample, guidance_scale, step_index): + return self.runner.ctx.modules.compiled_scheduler["run_step"]( + noise_pred, t, sample, guidance_scale, step_index + ) + + +class FlowSchedulingModel(torch.nn.Module): + def __init__( + self, + hf_model_name, + num_inference_steps, + dtype, + ): + super().__init__() + # For now, assumes SDXL implementation. May not need parametrization for other models, + # but keeping hf_model_name in case. + self.model = FlowMatchEulerDiscreteScheduler.from_pretrained(hf_model_name, subfolder="scheduler") + self.do_classifier_free_guidance = True + self.model.set_timesteps(num_inference_steps) + self.timesteps = self.model.timesteps + self.dtype = dtype + + # TODO: Make steps dynamic here + def initialize(self, sample): + step_count = torch.tensor(len(self.timesteps)) + timesteps = self.model.timesteps + # ops.trace_tensor("timesteps", self.timesteps) + return ( + sample.type(self.dtype), + step_count, + timesteps.type(torch.float32), + ) + + def prepare_model_input(self, sample, t, timesteps): + t = timesteps[t] + t = t.expand(sample.shape[0]) + if self.do_classifier_free_guidance: + latent_model_input = torch.cat([sample] * 2) + else: + latent_model_input = sample + return latent_model_input.type(self.dtype), t.type(self.dtype) + + def step(self, noise_pred, t, sample, guidance_scale, i): + self.model._step_index = i + + if self.do_classifier_free_guidance: + noise_preds = noise_pred.chunk(2) + noise_pred = noise_preds[0] + guidance_scale * ( + noise_preds[1] - noise_preds[0] + ) + sample = self.model.step(noise_pred, t, sample, return_dict=False)[0] + return sample.type(self.dtype) + + +class SharkSchedulerCPUWrapper: + @torch.no_grad() + def __init__( + self, scheduler, batch_size, num_inference_steps, dest_device, latents_dtype + ): + self.do_classifier_free_guidance = True + self.module = scheduler + self.dest = dest_device + self.dtype = latents_dtype + self.batch_size = batch_size + self.module.set_timesteps(num_inference_steps) + self.timesteps = self.module.timesteps + self.torch_dtype = ( + torch.float32 if latents_dtype == "float32" else torch.float16 + ) + + def initialize(self, sample): + if isinstance(sample, ireert.DeviceArray): + sample = torch.tensor(sample.to_host(), dtype=torch.float32) + step_indexes = torch.tensor(len(self.module.timesteps)) + timesteps = self.timesteps + return sample, step_indexes, timesteps + + def scale_model_input(self, sample, t, timesteps): + if self.do_classifier_free_guidance: + sample = torch.cat([sample] * 2) + t = timesteps[t] + t = t.expand(sample.shape[0]) + t = ireert.asdevicearray(self.dest, [t], self.dtype) + sample = ireert.asdevicearray(self.dest, sample, self.dtype) + return sample, t + + def step(self, noise_pred, t, latents, guidance_scale, i): + if isinstance(t, ireert.DeviceArray): + t = torch.tensor(t.to_host()) + if isinstance(guidance_scale, ireert.DeviceArray): + guidance_scale = torch.tensor(guidance_scale.to_host()) + noise_pred = torch.tensor(noise_pred.to_host()) + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + return self.module.step( + noise_pred, + t, + latents, + return_dict=False, + )[0] + + +@torch.no_grad() +def export_scheduler_model( + hf_model_name: str, + batch_size: int = 1, + height: int = 512, + width: int = 512, + num_inference_steps: int = 30, + precision: str = "fp16", + compile_to: str = "torch", + device: str = None, + target_triple: str = None, + ireec_flags: str = None, + exit_on_vmfb: bool = False, + pipeline_dir: str = None, + input_mlir: str = None, + upload_ir=False, +): + dtype = torch.float16 if precision == "fp16" else torch.float32 + scheduler_module = FlowSchedulingModel( + hf_model_name, num_inference_steps, dtype + ) + if pipeline_dir: + vmfb_names = [ + "EulerFlowScheduler", + str(num_inference_steps), + ] + vmfb_name = "_".join(vmfb_names) + safe_name = os.path.join(pipeline_dir, vmfb_name) + else: + vmfb_names = [ + "EulerFlowScheduler", + f"bs{args.batch_size}_{args.height}x{args.width}", + precision, + str(num_inference_steps), + target_triple, + ] + vmfb_name = "_".join(vmfb_names) + safe_name = utils.create_safe_name(hf_model_name, "_" + vmfb_name) + + if input_mlir: + vmfb_path = utils.compile_to_vmfb( + input_mlir, + device, + target_triple, + ireec_flags, + safe_name, + mlir_source="file", + return_path=not exit_on_vmfb, + ) + return vmfb_path + + do_classifier_free_guidance = True + if do_classifier_free_guidance: + init_batch_dim = 2 + else: + init_batch_dim = 1 + + sample = ( + batch_size, + 16, + height // 8, + width // 8, + ) + noise_pred_shape = ( + batch_size * init_batch_dim, + 16, + height // 8, + width // 8, + ) + example_init_args = [torch.empty(sample, dtype=dtype)] + example_prep_args = ( + torch.empty(sample, dtype=dtype), + torch.empty(1, dtype=torch.int64), + torch.empty([19], dtype=torch.float32), + ) + timesteps = torch.export.Dim("timesteps") + prep_dynamic_args = { + "sample": {}, + "t": {}, + "timesteps": {0: timesteps}, + } + example_step_args = [ + torch.empty(noise_pred_shape, dtype=dtype), + torch.empty(1, dtype=dtype), + torch.empty(sample, dtype=dtype), + torch.empty(1, dtype=dtype), + torch.empty(1, dtype=torch.int64), + ] + + fxb = FxProgramsBuilder(scheduler_module) + + @fxb.export_program( + args=(example_init_args,), + ) + def _initialize(module, sample): + return module.initialize(*sample) + + @fxb.export_program( + args=example_prep_args, + dynamic_shapes=prep_dynamic_args, + ) + def _prep(module, sample, t, timesteps): + return module.prepare_model_input(sample, t, timesteps) + + @fxb.export_program( + args=(example_step_args,), + ) + def _step(module, inputs): + return module.step(*inputs) + + decomp_list = [] + # if decomp_attn == True: + # decomp_list.extend( + # [ + # torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + # torch.ops.aten._scaled_dot_product_flash_attention.default, + # ] + # ) + with decompositions.extend_aot_decompositions( + from_current=True, + add_ops=decomp_list, + ): + + class CompiledScheduler(CompiledModule): + run_init = _initialize + run_prep = _prep + run_step = _step + + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" + inst = CompiledScheduler(context=Context(), import_to=import_to) + + module_str = str(CompiledModule.get_mlir_module(inst)) + + if compile_to != "vmfb": + return module_str + elif compile_to == "vmfb": + vmfb = utils.compile_to_vmfb( + module_str, + device, + target_triple, + ireec_flags, + safe_name, + return_path=True, + ) + if exit_on_vmfb: + exit() + return vmfb + +if __name__ == "__main__": + from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args + + mod_str = export_scheduler_model( + args.hf_model_name, + args.batch_size, + args.height, + args.width, + args.num_inference_steps, + args.precision, + args.compile_to, + args.device, + args.iree_target_triple, + args.ireec_flags, + exit_on_vmfb=False, + input_mlir=args.input_mlir, + ) + vmfb_names = [ + "EulerFlowScheduler", + f"bs{args.batch_size}_{args.height}x{args.width}", + args.precision, + str(args.num_inference_steps), + args.iree_target_triple, + ] + safe_name = "_".join(vmfb_names) + if args.compile_to != "vmfb": + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py b/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py new file mode 100644 index 000000000..895f27bf7 --- /dev/null +++ b/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py @@ -0,0 +1,227 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import sys + +import safetensors +from iree import runtime as ireert +import iree.compiler as ireec +from iree.compiler.ir import Context +import numpy as np +from shark_turbine.aot import * +from turbine_models.custom_models.sd_inference import utils +import torch +from turbine_models.custom_models.sd3_inference.text_encoder_impls import SDClipModel, SDXLClipG, T5XXLModel, load_into +from huggingface_hub import hf_hub_download +from safetensors import safe_open + +CLIPG_CONFIG = { + "hidden_act": "gelu", + "hidden_size": 1280, + "intermediate_size": 5120, + "num_attention_heads": 20, + "num_hidden_layers": 32 +} + +CLIPL_CONFIG = { + "hidden_act": "quick_gelu", + "hidden_size": 768, + "intermediate_size": 3072, + "num_attention_heads": 12, + "num_hidden_layers": 12 +} + +T5_CONFIG = { + "d_ff": 10240, + "d_model": 4096, + "num_heads": 64, + "num_layers": 24, + "vocab_size": 32128 +} + +class TextEncoderModule(torch.nn.Module): + @torch.no_grad() + def __init__( + self, + batch_size=1, + ): + super().__init__() + self.dtype = torch.float16 + self.clip_l = SDClipModel( + layer="hidden", + layer_idx=-2, + device="cpu", + dtype=self.dtype, + layer_norm_hidden_state=False, + return_projected_pooled=False, + textmodel_json_config=CLIPL_CONFIG + ).half() + clip_l_weights = hf_hub_download( + repo_id="stabilityai/stable-diffusion-3-medium", + filename="text_encoders/clip_l.safetensors" + ) + with safe_open(clip_l_weights, framework="pt", device="cpu") as f: + load_into(f, self.clip_l.transformer, "", "cpu", self.dtype) + self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=self.dtype).half() + clip_g_weights = hf_hub_download( + repo_id="stabilityai/stable-diffusion-3-medium", + filename="text_encoders/clip_g.safetensors" + ) + with safe_open(clip_g_weights, framework="pt", device="cpu") as f: + load_into(f, self.clip_g.transformer, "", "cpu", self.dtype) + self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=self.dtype).half() + t5_weights = hf_hub_download( + repo_id="stabilityai/stable-diffusion-3-medium", + filename="text_encoders/t5xxl_fp16.safetensors" + ) + with safe_open(t5_weights, framework="pt", device="cpu") as f: + load_into(f, self.t5xxl.transformer, "", "cpu", self.dtype) + + self.do_classifier_free_guidance = True + self.batch_size = batch_size + + def get_cond(self, tokens_l, tokens_g, tokens_t5xxl): + l_out, l_pooled = self.clip_l.forward(tokens_l) + g_out, g_pooled = self.clip_g.forward(tokens_g) + t5_out, _ = self.t5xxl.forward(tokens_t5xxl) + lg_out = torch.cat([l_out, g_out], dim=-1) + lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) + return torch.cat([lg_out, t5_out], dim=-2), torch.cat((l_pooled, g_pooled), dim=-1) + + def forward(self, tokens_g, tokens_l, tokens_t5xxl, neg_g, neg_l, neg_t5): + conditioning, cond_pool = self.get_cond(tokens_l, tokens_g, tokens_t5xxl) + neg_cond, neg_cond_pool = self.get_cond(neg_l, neg_g, neg_t5) + + prompt_embeds = torch.cat([neg_cond, conditioning], dim=0) + pooled_prompt_embeds = torch.cat([cond_pool, neg_cond_pool], dim=0) + + return prompt_embeds, pooled_prompt_embeds + +@torch.no_grad() +def export_text_encoder( + hf_model_name, + hf_auth_token=None, + max_length=64, + precision="fp16", + compile_to="torch", + external_weights=None, + external_weight_path=None, + device=None, + target_triple=None, + ireec_flags=None, + exit_on_vmfb=True, + pipeline_dir=None, + input_mlir=None, + attn_spec=None, + output_batchsize=1, + decomp_attn=True, +): + if pipeline_dir not in [None, ""]: + safe_name = os.path.join(pipeline_dir, "text_encoders") + else: + safe_name = utils.create_safe_name( + hf_model_name, f"_{str(max_length)}_{precision}_text_encoders-{device}" + ) + if input_mlir: + vmfb_path = utils.compile_to_vmfb( + input_mlir, + device, + target_triple, + ireec_flags, + safe_name, + mlir_source="file", + return_path=not exit_on_vmfb, + const_expr_hoisting=True, + attn_spec=attn_spec, + ) + return vmfb_path + model = TextEncoderModule( + batch_size=output_batchsize, + ) + mapper = {} + + assert ".safetensors" not in external_weight_path, "Original parameters format incompatible with IREE safetensors parser. Use '.irpa' instead." + + input_args = [torch.empty([1,77,2], dtype=torch.int64) for x in range(6)] + + decomp_list = [] + if decomp_attn == True: + decomp_list = [ + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten.scaled_dot_product_attention, + ] + with decompositions.extend_aot_decompositions( + from_current=True, + add_ops=decomp_list, + ): + fxb = FxProgramsBuilder(model) + + @fxb.export_program( + args=(input_args,), + ) + def _forward( + module, + inputs, + ): + return module.forward(*inputs) + + class CompiledTextEncoder(CompiledModule): + encode_tokens = _forward + + if external_weights: + externalize_module_parameters(model) + save_module_parameters(external_weight_path, model) + + inst = CompiledTextEncoder(context=Context(), import_to="IMPORT") + + module_str = str(CompiledModule.get_mlir_module(inst)) + + if compile_to != "vmfb": + return module_str + else: + vmfb_path = utils.compile_to_vmfb( + module_str, + device, + target_triple, + ireec_flags, + safe_name, + return_path=not exit_on_vmfb, + const_expr_hoisting=True, + attn_spec=attn_spec, + ) + return module_str, vmfb_path + + +if __name__ == "__main__": + from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args + + mod_str, _ = export_text_encoder( + args.hf_model_name, + args.hf_auth_token, + args.max_length, + args.precision, + args.compile_to, + args.external_weights, + args.external_weight_path, + args.device, + args.iree_target_triple, + args.ireec_flags + args.clip_flags, + exit_on_vmfb=True, + pipeline_dir=args.pipeline_dir, + input_mlir=args.input_mlir, + attn_spec=args.attn_spec, + output_batchsize=args.batch_size, + ) + if args.input_mlir or args.weights_only or args.compile_to=="vmfb": + exit() + safe_name = utils.create_safe_name( + args.hf_model_name, f"_{str(args.max_length)}_{args.precision}_text_encoders" + ) + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders_runner.py b/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders_runner.py new file mode 100644 index 000000000..1093f4b27 --- /dev/null +++ b/models/turbine_models/custom_models/sd3_inference/sd3_text_encoders_runner.py @@ -0,0 +1,116 @@ +from turbine_models.model_runner import vmfbRunner +from text_encoder_impls import SD3Tokenizer, T5XXLTokenizer, SDXLClipGTokenizer +from iree import runtime as ireert +import torch +import numpy as np + + +def run_prompt_encoder( + vmfb_path, + device, + external_weight_path, + input_ids, + uncond_input_ids, +): + prompt_encoder_runner = vmfbRunner(device, vmfb_path, external_weight_path) + # np.save("input0.npy", input_ids[0].numpy()) + # np.save("input1.npy", input_ids[1].numpy()) + # np.save("input2.npy", input_ids[2].numpy()) + # np.save("input3.npy", uncond_input_ids[0].numpy()) + # np.save("input4.npy", uncond_input_ids[1].numpy()) + # np.save("input5.npy", uncond_input_ids[2].numpy()) + prompt_encoder_inputs = [ + ireert.asdevicearray(prompt_encoder_runner.config.device, input_ids[0]), + ireert.asdevicearray(prompt_encoder_runner.config.device, input_ids[1]), + ireert.asdevicearray(prompt_encoder_runner.config.device, input_ids[2]), + ireert.asdevicearray(prompt_encoder_runner.config.device, uncond_input_ids[0]), + ireert.asdevicearray(prompt_encoder_runner.config.device, uncond_input_ids[1]), + ireert.asdevicearray(prompt_encoder_runner.config.device, uncond_input_ids[2]), + + ] + encoded_outputs = prompt_encoder_runner.ctx.modules.compiled_text_encoder["encode_tokens"]( + *prompt_encoder_inputs + ) + for i in encoded_outputs: + i = i.to_host() + del prompt_encoder_inputs + return encoded_outputs + + +def run_tokenize( + tokenizer, + prompt, + negative_prompt, +): + + prompt_tokens_dict = tokenizer.tokenize_with_weights(prompt) + neg_prompt_tokens_dict = tokenizer.tokenize_with_weights(negative_prompt) + text_input_ids_list = list(prompt_tokens_dict.values()) + uncond_input_ids_list = list(neg_prompt_tokens_dict.values()) + return text_input_ids_list, uncond_input_ids_list + +if __name__ == "__main__": + from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args + + tokenizer = SD3Tokenizer() + + text_input_ids_list, uncond_input_ids_list = run_tokenize( + tokenizer, + args.prompt, + args.negative_prompt, + ) + turbine_output1, turbine_output2 = run_prompt_encoder( + args.vmfb_path, + args.rt_device, + args.external_weight_path, + text_input_ids_list, + uncond_input_ids_list, + ) + print( + "TURBINE OUTPUT 1:", + turbine_output1.to_host(), + turbine_output1.shape, + turbine_output1.dtype, + ) + + print( + "TURBINE OUTPUT 2:", + turbine_output2.to_host(), + turbine_output2.shape, + turbine_output2.dtype, + ) + + if args.compare_vs_torch: + print("generating torch output: ") + from turbine_models.custom_models.sd_inference import utils + from turbine_models.custom_models.sd3_inference.sd3_text_encoders import ( + TextEncoderModule, + ) + + torch_encoder_model = TextEncoderModule( + args.batch_size, + ) + torch_output1, torch_output2 = torch_encoder_model.forward( + *text_input_ids_list, *uncond_input_ids_list + ) + np.save("torch_output1.npy", torch_output1) + np.save("torch_output2.npy", torch_output2) + print( + "TORCH OUTPUT 1:", torch_output1, torch_output1.shape, torch_output1.dtype + ) + + print( + "TORCH OUTPUT 2:", torch_output2, torch_output2.shape, torch_output2.dtype + ) + rtol = 4e-2 + atol = 4e-2 + + np.testing.assert_allclose( + torch_output1, turbine_output1.to_host(), rtol, atol, verbose=True + ) + np.testing.assert_allclose( + torch_output2, turbine_output2.to_host(), rtol, atol, verbose=True + ) + print("Passed!") + # TODO: Figure out why we occasionally segfault without unlinking output variables + turbine_output1, turbine_output2 = (None, None) diff --git a/models/turbine_models/custom_models/sd3_inference/text_encoder_impls.py b/models/turbine_models/custom_models/sd3_inference/text_encoder_impls.py new file mode 100644 index 000000000..09a69fe9c --- /dev/null +++ b/models/turbine_models/custom_models/sd3_inference/text_encoder_impls.py @@ -0,0 +1,537 @@ +### This file contains impls for underlying related models (CLIP, T5, etc) + +import torch, math +from torch import nn +from transformers import CLIPTokenizer, T5TokenizerFast +from shark_turbine import ops + +################################################################################################# +### Core/Utility +################################################################################################# + + +def attention(q, k, v, heads, mask=None): + """Convenience wrapper around a basic attention operation""" + b, _, dim_head = q.shape + #ops.iree.trace_tensor("attention_q", q[0,0,:5]) + #ops.iree.trace_tensor("attention_k", k[0,0,:5]) + #ops.iree.trace_tensor("attention_v", v[0,0,:5]) + dim_head //= heads + q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v)) + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + #ops.iree.trace_tensor("attention_out", out[0,0,:5]) + return out.transpose(1, 2).reshape(b, -1, heads * dim_head) + + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks""" + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, dtype=None, device=None): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device) + self.act = act_layer + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device) + + def forward(self, x): + x = self.fc1(x) + #ops.iree.trace_tensor("mlpfx", x[0,0,:5]) + x = self.act(x) + #ops.iree.trace_tensor("mlpact", x[0,0,:5]) + x = self.fc2(x) + #ops.iree.trace_tensor("mlpanotherfc", x[0,0,:5]) + return x + +def load_into(f, model, prefix, device, dtype=None): + """Just a debugging-friendly hack to apply the weights in a safetensors file to the pytorch module.""" + for key in f.keys(): + if key.startswith(prefix) and not key.startswith("loss."): + path = key[len(prefix):].split(".") + obj = model + for p in path: + if obj is list: + obj = obj[int(p)] + else: + obj = getattr(obj, p, None) + if obj is None: + print(f"Skipping key '{key}' in safetensors file as '{p}' does not exist in python model") + break + if obj is None: + continue + try: + tensor = f.get_tensor(key).to(device=device) + if dtype is not None: + tensor = tensor.to(dtype=dtype) + obj.requires_grad_(False) + obj.set_(tensor) + except Exception as e: + print(f"Failed to load key '{key}' in safetensors file: {e}") + raise e + +################################################################################################# +### CLIP +################################################################################################# + + +class CLIPAttention(torch.nn.Module): + def __init__(self, embed_dim, heads, dtype, device): + super().__init__() + self.heads = heads + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + + def forward(self, x, mask=None): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + out = attention(q, k, v, self.heads, mask) + return self.out_proj(out) + + +ACTIVATIONS = { + "quick_gelu": lambda a: a * torch.sigmoid(1.702 * a), + "gelu": torch.nn.functional.gelu, +} + +class CLIPLayer(torch.nn.Module): + def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device): + super().__init__() + self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + self.self_attn = CLIPAttention(embed_dim, heads, dtype, device) + self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + #self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device) + self.mlp = Mlp(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device) + + def forward(self, x, mask=None): + x += self.self_attn(self.layer_norm1(x), mask) + x += self.mlp(self.layer_norm2(x)) + return x + + +class CLIPEncoder(torch.nn.Module): + def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device): + super().__init__() + self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) for i in range(num_layers)]) + + def forward(self, x, mask=None, intermediate_output=None): + if intermediate_output is not None: + if intermediate_output < 0: + intermediate_output = len(self.layers) + intermediate_output + intermediate = None + for i, l in enumerate(self.layers): + x = l(x, mask) + if i == intermediate_output: + intermediate = x.clone() + return x, intermediate + + +class CLIPEmbeddings(torch.nn.Module): + def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None): + super().__init__() + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device) + self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device) + + def forward(self, input_tokens): + return self.token_embedding(input_tokens) + self.position_embedding.weight + + +class CLIPTextModel_(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + num_layers = config_dict["num_hidden_layers"] + embed_dim = config_dict["hidden_size"] + heads = config_dict["num_attention_heads"] + intermediate_size = config_dict["intermediate_size"] + intermediate_activation = config_dict["hidden_act"] + super().__init__() + self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device) + self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) + self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + + def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True): + x = self.embeddings(input_tokens) + causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) + x, i = self.encoder(x, mask=causal_mask, intermediate_output=intermediate_output) + x = self.final_layer_norm(x) + if i is not None and final_layer_norm_intermediate: + i = self.final_layer_norm(i) + pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),] + return x, i, pooled_output + + +class CLIPTextModel(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + super().__init__() + self.num_layers = config_dict["num_hidden_layers"] + self.text_model = CLIPTextModel_(config_dict, dtype, device) + embed_dim = config_dict["hidden_size"] + self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device) + self.text_projection.weight.copy_(torch.eye(embed_dim)) + self.dtype = dtype + + def get_input_embeddings(self): + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, embeddings): + self.text_model.embeddings.token_embedding = embeddings + + def forward(self, *args, **kwargs): + x = self.text_model(*args, **kwargs) + out = self.text_projection(x[2]) + return (x[0], x[1], out, x[2]) + +class SDTokenizer: + def __init__(self, max_length=77, pad_with_end=True, tokenizer=None, has_start_token=True, pad_to_max_length=True, min_length=None): + self.tokenizer = tokenizer + self.max_length = max_length + self.min_length = min_length + empty = self.tokenizer('')["input_ids"] + if has_start_token: + self.tokens_start = 1 + self.start_token = empty[0] + self.end_token = empty[1] + else: + self.tokens_start = 0 + self.start_token = None + self.end_token = empty[0] + self.pad_with_end = pad_with_end + self.pad_to_max_length = pad_to_max_length + vocab = self.tokenizer.get_vocab() + self.inv_vocab = {v: k for k, v in vocab.items()} + self.max_word_length = 8 + + + def tokenize_with_weights(self, text:str): + """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.""" + if self.pad_with_end: + pad_token = self.end_token + else: + pad_token = 0 + batch = [] + if self.start_token is not None: + batch.append((self.start_token, 1.0)) + to_tokenize = text.replace("\n", " ").split(' ') + to_tokenize = [x for x in to_tokenize if x != ""] + for word in to_tokenize: + batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]]) + batch.append((self.end_token, 1.0)) + if self.pad_to_max_length: + batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch))) + if self.min_length is not None and len(batch) < self.min_length: + batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch))) + return [batch] + + +class SDXLClipGTokenizer(SDTokenizer): + def __init__(self, tokenizer): + super().__init__(pad_with_end=False, tokenizer=tokenizer) + + +class SD3Tokenizer: + def __init__(self): + clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + self.clip_l = SDTokenizer(tokenizer=clip_tokenizer) + self.clip_g = SDXLClipGTokenizer(clip_tokenizer) + self.t5xxl = T5XXLTokenizer() + + def tokenize_with_weights(self, text:str): + out = {} + out["g"] = self.clip_g.tokenize_with_weights(text) + out["l"] = self.clip_l.tokenize_with_weights(text) + out["t5xxl"] = self.t5xxl.tokenize_with_weights(text) + for k, v in out.items(): + out[k] = torch.tensor(v, dtype=torch.int64, device="cpu") + return out + + +class ClipTokenWeightEncoder: + def encode_token_weights(self, token_weight_pairs): + #tokens = list(map(lambda a: a[0], token_weight_pairs[0])) + tokens = token_weight_pairs[:,:,0] + out, pooled = self(tokens) + if pooled is not None: + first_pooled = pooled[0:1].cpu() + else: + first_pooled = pooled + output = [out[0:1]] + return torch.cat(output, dim=-2).cpu(), first_pooled + + +class SDClipModel(torch.nn.Module): + """Uses the CLIP transformer encoder for text (from huggingface)""" + LAYERS = ["last", "pooled", "hidden"] + def __init__(self, device="cpu", max_length=77, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=CLIPTextModel, + special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, return_projected_pooled=True): + super().__init__() + assert layer in self.LAYERS + self.transformer = model_class(textmodel_json_config, dtype, device) + self.num_layers = self.transformer.num_layers + self.max_length = max_length + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + self.layer = layer + self.layer_idx = None + self.special_tokens = special_tokens + self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) + self.layer_norm_hidden_state = layer_norm_hidden_state + self.return_projected_pooled = return_projected_pooled + if layer == "hidden": + assert layer_idx is not None + assert abs(layer_idx) < self.num_layers + self.set_clip_options({"layer": layer_idx}) + self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled) + + def encode_token_weights(self, token_weight_pairs): + pass + + def set_clip_options(self, options): + layer_idx = options.get("layer", self.layer_idx) + self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) + if layer_idx is None or abs(layer_idx) > self.num_layers: + self.layer = "last" + else: + self.layer = "hidden" + self.layer_idx = layer_idx + + def forward(self, token_weight_pairs): + #tokens = list(map(lambda a: a[0], token_weight_pairs[0])) + tokens = token_weight_pairs[:,:,0] + #backup_embeds = self.transformer.get_input_embeddings() + #device = backup_embeds.weight.device + #tokens = torch.LongTensor(tokens).to(device) + outputs = self.transformer(tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state) + #self.transformer.set_input_embeddings(backup_embeds) + if self.layer == "last": + z = outputs[0] + else: + z = outputs[1] + pooled_output = None + if len(outputs) >= 3: + if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None: + pooled_output = outputs[3].float() + elif outputs[2] is not None: + pooled_output = outputs[2].float() + out, pooled = z.float(), pooled_output + if pooled is not None: + first_pooled = pooled[0:1].cpu() + else: + first_pooled = pooled + output = [out[0:1]] + return torch.cat(output, dim=-2).cpu(), first_pooled + + +class SDXLClipG(SDClipModel): + """Wraps the CLIP-G model into the SD-CLIP-Model interface""" + def __init__(self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None): + if layer == "penultimate": + layer="hidden" + layer_idx=-2 + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False) + + +class T5XXLModel(SDClipModel): + """Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience""" + def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5) + + +################################################################################################# +### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl +################################################################################################# + + +class T5XXLTokenizer(SDTokenizer): + """Wraps the T5 Tokenizer from HF into the SDTokenizer interface""" + def __init__(self): + super().__init__(pad_with_end=False, tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77) + + +class T5LayerNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device)) + self.variance_epsilon = eps + + def forward(self, x): + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + return self.weight.to(device=x.device, dtype=x.dtype) * x + + +class T5DenseGatedActDense(torch.nn.Module): + def __init__(self, model_dim, ff_dim, dtype, device): + super().__init__() + self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) + self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) + self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device) + + def forward(self, x): + hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh") + hidden_linear = self.wi_1(x) + x = hidden_gelu * hidden_linear + x = self.wo(x) + return x + + +class T5LayerFF(torch.nn.Module): + def __init__(self, model_dim, ff_dim, dtype, device): + super().__init__() + self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device) + self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, x): + forwarded_states = self.layer_norm(x) + forwarded_states = self.DenseReluDense(forwarded_states) + x += forwarded_states + return x + + +class T5Attention(torch.nn.Module): + def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device): + super().__init__() + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device) + self.num_heads = num_heads + self.relative_attention_bias = None + if relative_attention_bias: + self.relative_attention_num_buckets = 32 + self.relative_attention_max_distance = 128 + self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min(relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)) + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length, device): + """Compute binned relative position bias""" + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=True, + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + def forward(self, x, past_bias=None): + q = self.q(x) + k = self.k(x) + v = self.v(x) + if self.relative_attention_bias is not None: + past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device) + if past_bias is not None: + mask = past_bias + out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask) + return self.o(out), past_bias + + +class T5LayerSelfAttention(torch.nn.Module): + def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device): + super().__init__() + self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device) + self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, x, past_bias=None): + output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias) + x += output + return x, past_bias + + +class T5Block(torch.nn.Module): + def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device): + super().__init__() + self.layer = torch.nn.ModuleList() + self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device)) + self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device)) + + def forward(self, x, past_bias=None): + x, past_bias = self.layer[0](x, past_bias) + x = self.layer[-1](x) + return x, past_bias + + +class T5Stack(torch.nn.Module): + def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device): + super().__init__() + self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device) + self.block = torch.nn.ModuleList([T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) for i in range(num_layers)]) + self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True): + intermediate = None + x = self.embed_tokens(input_ids) + past_bias = None + for i, l in enumerate(self.block): + x, past_bias = l(x, past_bias) + if i == intermediate_output: + intermediate = x.clone() + x = self.final_layer_norm(x) + if intermediate is not None and final_layer_norm_intermediate: + intermediate = self.final_layer_norm(intermediate) + return x, intermediate + + +class T5(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + super().__init__() + self.num_layers = config_dict["num_layers"] + self.encoder = T5Stack(self.num_layers, config_dict["d_model"], config_dict["d_model"], config_dict["d_ff"], config_dict["num_heads"], config_dict["vocab_size"], dtype, device) + self.dtype = dtype + + def get_input_embeddings(self): + return self.encoder.embed_tokens + + def set_input_embeddings(self, embeddings): + self.encoder.embed_tokens = embeddings + + def forward(self, *args, **kwargs): + return self.encoder(*args, **kwargs) diff --git a/models/turbine_models/custom_models/sd3_inference/turbine_mmdit.py b/models/turbine_models/custom_models/sd3_inference/turbine_mmdit.py new file mode 100644 index 000000000..1cdebc076 --- /dev/null +++ b/models/turbine_models/custom_models/sd3_inference/turbine_mmdit.py @@ -0,0 +1,217 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import copy +import os +import sys +import math + +from safetensors import safe_open +from iree import runtime as ireert +from iree.compiler.ir import Context +import numpy as np +from shark_turbine.aot import * +from shark_turbine.dynamo.passes import ( + DEFAULT_DECOMPOSITIONS, +) +from turbine_models.custom_models.sd_inference import utils +import torch +import torch._dynamo as dynamo +from diffusers import SD3Transformer2DModel + + +class MMDiTModel(torch.nn.Module): + def __init__( + self, + hf_model_name = "stabilityai/stable-diffusion-3-medium-diffusers", + dtype=torch.float16, + ): + super().__init__() + self.mmdit = SD3Transformer2DModel.from_pretrained( + hf_model_name, + subfolder="transformer", + torch_dtype=dtype, + low_cpu_mem_usage=False, + ) + + + def forward( + self, hidden_states, encoder_hidden_states, pooled_projections, timestep, lora_scale, + ): + joint_attention_kwargs = { + "scale": lora_scale, + } + noise_pred = self.mmdit(hidden_states, encoder_hidden_states, pooled_projections, timestep,joint_attention_kwargs, return_dict=False)[0] + return noise_pred + + +@torch.no_grad() +def export_mmdit_model( + mmdit_model, + hf_model_name, + batch_size, + height, + width, + precision="fp32", + max_length=77, + hf_auth_token=None, + compile_to="torch", + external_weights=None, + external_weight_path=None, + device=None, + target_triple=None, + ireec_flags=None, + decomp_attn=False, + exit_on_vmfb=False, + pipeline_dir=None, + attn_spec=None, + input_mlir=None, + weights_only=False, +): + dtype = torch.float16 if args.precision == "fp16" else torch.float32 + if pipeline_dir: + safe_name = os.path.join(pipeline_dir, f"mmdit") + else: + safe_name = utils.create_safe_name( + hf_model_name, + f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_mmdit", + ) + if decomp_attn == True: + ireec_flags += ",--iree-opt-aggressively-propagate-transposes=False" + + if input_mlir: + vmfb_path = utils.compile_to_vmfb( + input_mlir, + device, + target_triple, + ireec_flags, + safe_name + "_" + target_triple, + mlir_source="file", + return_path=not exit_on_vmfb, + attn_spec=attn_spec, + ) + return vmfb_path + + mapper = {} + + utils.save_external_weights( + mapper, mmdit_model, external_weights, external_weight_path + ) + + if weights_only: + return external_weight_path + + do_classifier_free_guidance = True + init_batch_dim = 2 if do_classifier_free_guidance else 1 + + hidden_states_shape = ( + batch_size, + 16, + height // 8, + width // 8, + ) + encoder_hidden_states_shape = (batch_size, 77, 4096) + pooled_projections_shape = (batch_size, 2048) + example_forward_args = [ + torch.empty(hidden_states_shape, dtype=dtype), + torch.empty(encoder_hidden_states_shape, dtype=dtype), + torch.empty(pooled_projections_shape, dtype=dtype), + torch.empty(1, dtype=dtype), + torch.empty(1, dtype=dtype), + ] + + decomp_list = [] + if decomp_attn == True: + decomp_list = [ + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten.scaled_dot_product_attention, + ] + with decompositions.extend_aot_decompositions( + from_current=True, + add_ops=decomp_list, + ): + fxb = FxProgramsBuilder(mmdit_model) + + @fxb.export_program( + args=(example_forward_args,), + ) + def _forward( + module, + inputs, + ): + return module.forward(*inputs) + + class CompiledMmdit(CompiledModule): + run_forward = _forward + + if external_weights: + externalize_module_parameters(mmdit_model) + + inst = CompiledMmdit(context=Context(), import_to="IMPORT") + + module_str = str(CompiledModule.get_mlir_module(inst)) + + if compile_to != "vmfb": + return module_str + else: + vmfb_path = utils.compile_to_vmfb( + module_str, + device, + target_triple, + ireec_flags, + safe_name, + return_path=True, + attn_spec=attn_spec, + ) + if exit_on_vmfb: + exit() + return vmfb_path + + +if __name__ == "__main__": + import logging + + logging.basicConfig(level=logging.DEBUG) + from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args + + if args.input_mlir: + mmdit_model = None + else: + mmdit_model = MMDiTModel( + args.hf_model_name, + dtype=torch.float16 if args.precision == "fp16" else torch.float32 + ) + mod_str = export_mmdit_model( + mmdit_model, + args.hf_model_name, + args.batch_size, + args.height, + args.width, + args.precision, + args.max_length, + args.hf_auth_token, + args.compile_to, + args.external_weights, + args.external_weight_path, + args.device, + args.iree_target_triple, + args.ireec_flags + args.attn_flags + args.unet_flags, + args.decomp_attn, + attn_spec=args.attn_spec, + input_mlir=args.input_mlir, + weights_only=args.weights_only, + ) + if args.input_mlir: + exit() + safe_name = utils.create_safe_name( + args.hf_model_name, + f"_bs{args.batch_size}_{args.max_length}_{args.height}x{args.width}_{args.precision}_mmdit", + ) + if args.compile_to != "vmfb": + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 840c8bd1a..ce48dff33 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -3,6 +3,7 @@ import numpy as np import os import safetensors +import safetensors.numpy as safe_numpy import re from diffusers import ( PNDMScheduler, @@ -270,14 +271,22 @@ def save_external_weights( model, external_weights=None, external_weight_file=None, + force_format=False, ): if external_weights is not None: if external_weights in ["safetensors", "irpa"]: mod_params = dict(model.named_parameters()) + mod_buffers = dict(model.named_buffers()) + mod_params.update(mod_buffers) for name in mod_params: mapper["params." + name] = name if external_weight_file and not os.path.isfile(external_weight_file): - safetensors.torch.save_file(mod_params, external_weight_file) + if not force_format: + safetensors.torch.save_file(mod_params, external_weight_file) + else: + for x in mod_params.keys(): + mod_params[x] = mod_params[x].numpy() + safe_numpy.save_file(mod_params, external_weight_file) print("Saved params to", external_weight_file)