From fabd52c50439e59f9ad204f4945eae2cf85a0aa2 Mon Sep 17 00:00:00 2001 From: Avinash Sharma Date: Fri, 16 Feb 2024 16:00:20 -0800 Subject: [PATCH] SD PNDMScheduler + Unet example through Turbine (#403) TODO: Need to update the rest of the schedulers in diffusers upstream for e2e test to work. Xfailed for now. --- core/shark_turbine/dynamo/passes.py | 1 + .../custom_models/sd_inference/schedulers.py | 178 ++++++++++++++++++ .../sd_inference/schedulers_runner.py | 172 +++++++++++++++++ .../custom_models/sd_inference/utils.py | 22 +++ .../custom_models/sd_inference/vae_runner.py | 2 +- models/turbine_models/tests/sd_test.py | 65 ++++++- 6 files changed, 438 insertions(+), 2 deletions(-) create mode 100644 models/turbine_models/custom_models/sd_inference/schedulers.py create mode 100644 models/turbine_models/custom_models/sd_inference/schedulers_runner.py diff --git a/core/shark_turbine/dynamo/passes.py b/core/shark_turbine/dynamo/passes.py index 88c08f6ad..68261f50c 100644 --- a/core/shark_turbine/dynamo/passes.py +++ b/core/shark_turbine/dynamo/passes.py @@ -48,6 +48,7 @@ torch.ops.aten._log_softmax_backward_data, torch.ops.aten.lift_fresh_copy.default, torch.ops.aten._unsafe_index.Tensor, + torch.ops.aten.unbind.int, # decompositions added manually in this file torch.ops.aten._scaled_dot_product_flash_attention.default, ] diff --git a/models/turbine_models/custom_models/sd_inference/schedulers.py b/models/turbine_models/custom_models/sd_inference/schedulers.py new file mode 100644 index 000000000..97bd2418f --- /dev/null +++ b/models/turbine_models/custom_models/sd_inference/schedulers.py @@ -0,0 +1,178 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import sys + +import torch +from torch.fx.experimental.proxy_tensor import make_fx +from shark_turbine.aot import * +from iree import runtime as ireert +import iree.compiler as ireec +from iree.compiler.ir import Context +import numpy as np + +from turbine_models.custom_models.sd_inference import utils +from diffusers import ( + UNet2DConditionModel, +) + +import safetensors +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument( + "--hf_auth_token", type=str, help="The Hugging Face auth token, required" +) +parser.add_argument( + "--hf_model_name", + type=str, + help="HF model name", + default="CompVis/stable-diffusion-v1-4", +) +parser.add_argument( + "--scheduler_id", + type=str, + help="Scheduler ID", + default="PNDM", +) +parser.add_argument( + "--num_inference_steps", type=int, default=50, help="Number of inference steps" +) +parser.add_argument( + "--batch_size", type=int, default=1, help="Batch size for inference" +) +parser.add_argument( + "--height", type=int, default=512, help="Height of Stable Diffusion" +) +parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion") +parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb") +parser.add_argument("--external_weight_path", type=str, default="") +parser.add_argument( + "--external_weights", + type=str, + default=None, + help="saves ir/vmfb without global weights for size and readability, options [safetensors]", +) +parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm") +# TODO: Bring in detection for target triple +parser.add_argument( + "--iree_target_triple", + type=str, + default="", + help="Specify vulkan target triple or rocm/cuda target device.", +) +parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") + + +class Scheduler(torch.nn.Module): + def __init__(self, hf_model_name, num_inference_steps, scheduler): + super().__init__() + self.scheduler = scheduler + self.scheduler.set_timesteps(num_inference_steps) + self.unet = UNet2DConditionModel.from_pretrained( + hf_model_name, + subfolder="unet", + ) + self.guidance_scale = 7.5 + + def forward(self, latents, encoder_hidden_states) -> torch.FloatTensor: + latents = latents * self.scheduler.init_noise_sigma + for t in self.scheduler.timesteps: + latent_model_input = torch.cat([latents] * 2) + t = t.unsqueeze(0) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, timestep=t + ) + unet_out = self.unet.forward( + latent_model_input, t, encoder_hidden_states, return_dict=False + )[0] + noise_pred_uncond, noise_pred_text = unet_out.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + return latents + + +def export_scheduler( + scheduler, + hf_model_name, + batch_size, + height, + width, + hf_auth_token=None, + compile_to="torch", + external_weights=None, + external_weight_path=None, + device=None, + target_triple=None, + max_alloc=None, +): + mapper = {} + utils.save_external_weights( + mapper, scheduler, external_weights, external_weight_path + ) + + encoder_hidden_states_sizes = (2, 77, 768) + if hf_model_name == "stabilityai/stable-diffusion-2-1-base": + encoder_hidden_states_sizes = (2, 77, 1024) + + sample = (batch_size, 4, height // 8, width // 8) + + class CompiledScheduler(CompiledModule): + if external_weights: + params = export_parameters( + scheduler, external=True, external_scope="", name_mapper=mapper.get + ) + else: + params = export_parameters(scheduler) + + def main( + self, + sample=AbstractTensor(*sample, dtype=torch.float32), + encoder_hidden_states=AbstractTensor( + *encoder_hidden_states_sizes, dtype=torch.float32 + ), + ): + return jittable(scheduler.forward)(sample, encoder_hidden_states) + + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" + inst = CompiledScheduler(context=Context(), import_to=import_to) + + module_str = str(CompiledModule.get_mlir_module(inst)) + safe_name = utils.create_safe_name(hf_model_name, "-scheduler") + if compile_to != "vmfb": + return module_str + else: + utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name) + + +if __name__ == "__main__": + args = parser.parse_args() + schedulers = utils.get_schedulers(args.hf_model_name) + scheduler = schedulers[args.scheduler_id] + scheduler_module = Scheduler( + args.hf_model_name, args.num_inference_steps, scheduler + ) + mod_str = export_scheduler( + scheduler_module, + args.hf_model_name, + args.batch_size, + args.height, + args.width, + args.hf_auth_token, + args.compile_to, + args.external_weights, + args.external_weight_path, + args.device, + args.iree_target_triple, + args.vulkan_max_allocation, + ) + safe_name = utils.create_safe_name(args.hf_model_name, "-scheduler") + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") diff --git a/models/turbine_models/custom_models/sd_inference/schedulers_runner.py b/models/turbine_models/custom_models/sd_inference/schedulers_runner.py new file mode 100644 index 000000000..2490f8ebf --- /dev/null +++ b/models/turbine_models/custom_models/sd_inference/schedulers_runner.py @@ -0,0 +1,172 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import argparse +from turbine_models.model_runner import vmfbRunner +from iree import runtime as ireert +import torch +from diffusers import ( + PNDMScheduler, + UNet2DConditionModel, +) + +parser = argparse.ArgumentParser() + +# TODO move common runner flags to generic flag file +parser.add_argument( + "--scheduler_id", + type=str, + help="Scheduler ID", + default="PNDM", +) +parser.add_argument( + "--num_inference_steps", type=int, default=50, help="Number of inference steps" +) +parser.add_argument( + "--vmfb_path", type=str, default="", help="path to vmfb containing compiled module" +) +parser.add_argument( + "--external_weight_path", + type=str, + default="", + help="path to external weight parameters if model compiled without them", +) +parser.add_argument( + "--compare_vs_torch", + action="store_true", + help="Runs both turbine vmfb and a torch model to compare results", +) +parser.add_argument( + "--hf_model_name", + type=str, + help="HF model name", + default="CompVis/stable-diffusion-v1-4", +) +parser.add_argument( + "--hf_auth_token", + type=str, + help="The Hugging face auth token, required for some models", +) +parser.add_argument( + "--device", + type=str, + default="local-task", + help="local-sync, local-task, cuda, vulkan, rocm", +) +parser.add_argument( + "--batch_size", type=int, default=1, help="Batch size for inference" +) +parser.add_argument( + "--height", type=int, default=512, help="Height of Stable Diffusion" +) +parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion") + + +def run_scheduler( + device, + sample, + encoder_hidden_states, + vmfb_path, + hf_model_name, + hf_auth_token, + external_weight_path, +): + runner = vmfbRunner(device, vmfb_path, external_weight_path) + + inputs = [ + ireert.asdevicearray(runner.config.device, sample), + ireert.asdevicearray(runner.config.device, encoder_hidden_states), + ] + results = runner.ctx.modules.compiled_scheduler["main"](*inputs) + return results + + +def run_torch_scheduler( + hf_model_name, scheduler, num_inference_steps, sample, encoder_hidden_states +): + class Scheduler(torch.nn.Module): + def __init__(self, hf_model_name, num_inference_steps, scheduler): + super().__init__() + self.scheduler = scheduler + self.scheduler.set_timesteps(num_inference_steps) + self.unet = UNet2DConditionModel.from_pretrained( + hf_model_name, + subfolder="unet", + ) + self.guidance_scale = 7.5 + + def forward(self, latents, encoder_hidden_states) -> torch.FloatTensor: + latents = latents * self.scheduler.init_noise_sigma + for t in self.scheduler.timesteps: + latent_model_input = torch.cat([latents] * 2) + t = t.unsqueeze(0) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, timestep=t + ) + unet_out = self.unet.forward( + latent_model_input, t, encoder_hidden_states, return_dict=False + )[0] + noise_pred_uncond, noise_pred_text = unet_out.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + latents = self.scheduler.step( + noise_pred, t, latents, return_dict=False + )[0] + return latents + + scheduler_module = Scheduler(hf_model_name, num_inference_steps, scheduler) + results = scheduler_module.forward(sample, encoder_hidden_states) + np_torch_output = results.detach().cpu().numpy() + return np_torch_output + + +if __name__ == "__main__": + args = parser.parse_args() + sample = torch.rand( + args.batch_size, 4, args.height // 8, args.width // 8, dtype=torch.float32 + ) + if args.hf_model_name == "CompVis/stable-diffusion-v1-4": + encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32) + elif args.hf_model_name == "stabilityai/stable-diffusion-2-1-base": + encoder_hidden_states = torch.rand(2, 77, 1024, dtype=torch.float32) + + turbine_output = run_scheduler( + args.device, + sample, + encoder_hidden_states, + args.vmfb_path, + args.hf_model_name, + args.hf_auth_token, + args.external_weight_path, + ) + print( + "TURBINE OUTPUT:", + turbine_output.to_host(), + turbine_output.to_host().shape, + turbine_output.to_host().dtype, + ) + + if args.compare_vs_torch: + print("generating torch output: ") + from turbine_models.custom_models.sd_inference import utils + + schedulers = utils.get_schedulers(args.hf_model_name) + scheduler = schedulers[args.scheduler_id] + torch_output = run_torch_scheduler( + args.hf_model_name, + scheduler, + args.num_inference_steps, + sample, + encoder_hidden_states, + ) + print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) + err = utils.largest_error(torch_output, turbine_output) + print("Largest Error: ", err) + assert err < 9e-3 + + # TODO: Figure out why we occasionally segfault without unlinking output variables + turbine_output = None diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 37787fd3a..8f509a476 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -2,6 +2,9 @@ import numpy as np import safetensors import re +from diffusers import ( + PNDMScheduler, +) def save_external_weights( @@ -35,6 +38,7 @@ def compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name): "--iree-llvmcpu-target-triple=x86_64-linux-gnu", "--iree-stream-resource-index-bits=64", "--iree-vm-target-index-bits=64", + "--iree-flow-inline-constants-max-byte-length=1", ] if device == "cpu": flags.append("--iree-llvmcpu-enable-ukernels=all") @@ -86,3 +90,21 @@ def create_safe_name(hf_model_name, model_name_str): safe_name = hf_model_name.split("/")[-1].strip() + model_name_str safe_name = re.sub("-", "_", safe_name) return safe_name + + +def get_schedulers(model_id): + # TODO: Robust scheduler setup on pipeline creation -- if we don't + # set batch_size here, the SHARK schedulers will + # compile with batch size = 1 regardless of whether the model + # outputs latents of a larger batch size, e.g. SDXL. + # However, obviously, searching for whether the base model ID + # contains "xl" is not very robust. + + batch_size = 2 if "xl" in model_id.lower() else 1 + + schedulers = dict() + schedulers["PNDM"] = PNDMScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + return schedulers diff --git a/models/turbine_models/custom_models/sd_inference/vae_runner.py b/models/turbine_models/custom_models/sd_inference/vae_runner.py index 77acaedcb..fa5e430ac 100644 --- a/models/turbine_models/custom_models/sd_inference/vae_runner.py +++ b/models/turbine_models/custom_models/sd_inference/vae_runner.py @@ -127,7 +127,7 @@ def encode_inp(self, inp): print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) err = utils.largest_error(torch_output, turbine_results) print("Largest Error: ", err) - assert err < 2e-3 + assert err < 3e-3 # TODO: Figure out why we occasionally segfault without unlinking output variables turbine_results = None diff --git a/models/turbine_models/tests/sd_test.py b/models/turbine_models/tests/sd_test.py index 125f97d82..961b920a0 100644 --- a/models/turbine_models/tests/sd_test.py +++ b/models/turbine_models/tests/sd_test.py @@ -13,6 +13,8 @@ unet_runner, vae, vae_runner, + schedulers, + schedulers_runner, ) from transformers import CLIPTextModel from turbine_models.custom_models.sd_inference import utils @@ -24,6 +26,8 @@ arguments = { "hf_auth_token": None, "hf_model_name": "CompVis/stable-diffusion-v1-4", + "scheduler_id": "PNDM", + "num_inference_steps": 5, "batch_size": 1, "height": 512, "width": 512, @@ -52,6 +56,15 @@ None, ) +schedulers_dict = utils.get_schedulers( + # This is a public model, so no auth required + "CompVis/stable-diffusion-v1-4", +) +scheduler = schedulers_dict[arguments["scheduler_id"]] +scheduler_module = schedulers.Scheduler( + "CompVis/stable-diffusion-v1-4", arguments["num_inference_steps"], scheduler +) + class StableDiffusionTest(unittest.TestCase): def testExportClipModel(self): @@ -220,10 +233,60 @@ def testExportVaeModelEncode(self): example_input, ) err = utils.largest_error(torch_output, turbine) - assert err < 2e-3 + assert err < 3e-3 os.remove("stable_diffusion_v1_4_vae.safetensors") os.remove("stable_diffusion_v1_4_vae.vmfb") + @unittest.expectedFailure + def testExportPNDMScheduler(self): + with self.assertRaises(SystemExit) as cm: + schedulers.export_scheduler( + scheduler_module, + # This is a public model, so no auth required + "CompVis/stable-diffusion-v1-4", + arguments["batch_size"], + arguments["height"], + arguments["width"], + None, + "vmfb", + "safetensors", + "stable_diffusion_v1_4_scheduler.safetensors", + "cpu", + ) + self.assertEqual(cm.exception.code, None) + arguments[ + "external_weight_path" + ] = "stable_diffusion_v1_4_scheduler.safetensors" + arguments["vmfb_path"] = "stable_diffusion_v1_4_scheduler.vmfb" + sample = torch.rand( + arguments["batch_size"], + 4, + arguments["height"] // 8, + arguments["width"] // 8, + dtype=torch.float32, + ) + encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32) + turbine = schedulers_runner.run_scheduler( + arguments["device"], + sample, + encoder_hidden_states, + arguments["vmfb_path"], + arguments["hf_model_name"], + arguments["hf_auth_token"], + arguments["external_weight_path"], + ) + torch_output = schedulers_runner.run_torch_scheduler( + arguments["hf_model_name"], + scheduler, + arguments["num_inference_steps"], + sample, + encoder_hidden_states, + ) + err = utils.largest_error(torch_output, turbine) + assert err < 9e-3 + os.remove("stable_diffusion_v1_4_scheduler.safetensors") + os.remove("stable_diffusion_v1_4_scheduler.vmfb") + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG)