diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py b/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py index ce57dbe15..256e4d21b 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py @@ -65,6 +65,7 @@ def __init__( vae_precision: str = "fp32", scheduler_id: str = None, # compatibility only, always uses EulerFlowScheduler shift: float = 1.0, + custom_vae: str = None, ): self.hf_model_name = hf_model_name # self.scheduler_id = scheduler_id @@ -122,7 +123,7 @@ def __init__( self.external_weights_dir = external_weights_dir self.external_weights = external_weights self.vae_decomp_attn = vae_decomp_attn - self.custom_vae = None + self.custom_vae = custom_vae self.cpu_scheduling = cpu_scheduling self.torch_dtype = torch.float32 if self.precision == "fp32" else torch.float16 self.vae_precision = vae_precision if vae_precision else self.precision diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py index 8d6ce4ed1..51759d47d 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py @@ -32,8 +32,8 @@ "vae_decode": None, "prompt_encoder": None, "scheduled_unet": None, - "pipeline": None, - "full_pipeline": None, + "unetloop": None, + "fullpipeline": None, } EMPTY_FLAGS = { @@ -117,6 +117,12 @@ def __init__( self.vae_precision = vae_precision self.vae_dtype = "float32" if vae_precision == "fp32" else "float16" self.custom_vae = custom_vae + if self.custom_vae: + self.vae_dir = os.path.join( + self.pipeline_dir, utils.create_safe_name(custom_vae, "") + ) + if not os.path.exists(self.vae_dir): + os.makedirs(self.vae_dir) self.cpu_scheduling = cpu_scheduling self.compiled_pipeline = False self.split_scheduler = False @@ -173,6 +179,7 @@ def check_prepared( def is_prepared(self, vmfbs, weights): missing = [] dims = f"{str(self.width)}x{str(self.height)}" + pipeline_dir = self.pipeline_dir for key in vmfbs: if key == "scheduled_unet": keywords = [ @@ -189,6 +196,8 @@ def is_prepared(self, vmfbs, weights): elif key == "vae_decode": keywords = ["vae", self.vae_precision, dims] device_key = "vae" + if self.custom_vae: + pipeline_dir = self.vae_dir elif key == "prompt_encoder": keywords = ["prompt_encoder", self.precision, self.max_length] device_key = "clip" @@ -203,10 +212,10 @@ def is_prepared(self, vmfbs, weights): self.devices[device_key]["target"], ] ) - avail_files = os.listdir(self.pipeline_dir) + avail_files = os.listdir(pipeline_dir) for filename in avail_files: if all(str(x) in filename for x in keywords): - vmfbs[key] = os.path.join(self.pipeline_dir, filename) + vmfbs[key] = os.path.join(pipeline_dir, filename) if not vmfbs[key]: missing.append(key + " vmfb") @@ -432,6 +441,14 @@ def export_submodel( vae_torch = self.get_torch_models("vae_decode") else: vae_torch = None + if self.custom_vae: + vae_external_weight_path = os.path.join( + self.vae_dir, + f"vae_decode_{self.vae_precision}." + self.external_weights, + ) + vae_dir = self.vae_dir + else: + vae_dir = self.pipeline_dir vae_decode_vmfb = vae.export_vae_model( vae_torch, self.hf_model_name, @@ -448,7 +465,7 @@ def export_submodel( "decode", self.vae_decomp_attn, exit_on_vmfb=False, - pipeline_dir=self.pipeline_dir, + pipeline_dir=vae_dir, attn_spec=self.attn_spec, input_mlir=input_mlir["vae_decode"], weights_only=weights_only, @@ -468,7 +485,9 @@ def export_submodel( self.devices["clip"]["target"], self.ireec_flags["clip"], exit_on_vmfb=False, - pipeline_dir=self.pipeline_dir, + pipeline_dir=( + self.pipeline_dir if not self.custom_vae else self.vae_dir + ), input_mlir=input_mlir["prompt_encoder"], attn_spec=self.attn_spec, weights_only=weights_only, @@ -490,7 +509,7 @@ def export_submodel( f"{str(self.width)}x{str(self.height)}", self.precision, str(self.max_length), - "pipeline", + "unetloop", ] pipeline_vmfb = utils.compile_to_vmfb( pipeline_file, @@ -517,7 +536,7 @@ def export_submodel( f"{str(self.width)}x{str(self.height)}", self.precision, str(self.max_length), - "full_pipeline", + "fullpipeline", ] pipeline_vmfb = utils.compile_to_vmfb( pipeline_file, diff --git a/models/turbine_models/custom_models/sdxl_inference/vae.py b/models/turbine_models/custom_models/sdxl_inference/vae.py index ed474256e..753cbb9e7 100644 --- a/models/turbine_models/custom_models/sdxl_inference/vae.py +++ b/models/turbine_models/custom_models/sdxl_inference/vae.py @@ -19,6 +19,7 @@ import torch import torch._dynamo as dynamo from diffusers import AutoencoderKL +import safetensors class VaeModel(torch.nn.Module): @@ -34,6 +35,14 @@ def __init__( hf_model_name, subfolder="vae", ) + elif "safetensors" in custom_vae: + custom_vae = safetensors.torch.load_file(custom_vae) + # custom vae as a HF state dict + self.vae = AutoencoderKL.from_pretrained( + hf_model_name, + subfolder="vae", + ) + self.vae.load_state_dict(custom_vae) elif not isinstance(custom_vae, dict): try: # custom HF repo with no vae subfolder @@ -46,13 +55,6 @@ def __init__( custom_vae, subfolder="vae", ) - else: - # custom vae as a HF state dict - self.vae = AutoencoderKL.from_pretrained( - hf_model_name, - subfolder="vae", - ) - self.vae.load_state_dict(custom_vae) def decode(self, inp): img = 1 / 0.13025 * inp @@ -104,10 +106,10 @@ def export_vae_model( attn_spec=attn_spec, ) return vmfb_path - if precision == "fp32" and device == "rocm": - decomp_attn = True - external_weights = None - print("Decomposing attention and inlining weights for fp32 VAE on ROCm") + # if precision == "fp32" and device == "rocm": + # decomp_attn = True + # external_weights = None + # print("Decomposing attention and inlining weights for fp32 VAE on ROCm") if device == "cpu": decomp_attn = True