diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 26a55698c..e0239e4d5 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -261,10 +261,7 @@ def export_submodel( return clip_vmfb, clip_external_weight_path case "scheduler": if self.cpu_scheduling: - return ( - schedulers.get_scheduler(self.hf_model_name, self.scheduler_id), - None, - ) + return (None, None) scheduler = schedulers.export_scheduler( self.hf_model_name, self.scheduler_id, @@ -368,7 +365,7 @@ def load_pipeline( runners["clip"] = vmfbRunner(rt_device, vmfbs["clip"], weights["clip"]) if self.cpu_scheduling: self.scheduler = schedulers.SchedulingModel( - vmfbs["scheduler"], + schedulers.get_scheduler(self.hf_model_name, self.scheduler_id), self.height, self.width, self.num_inference_steps, @@ -449,14 +446,15 @@ def generate_images( sample, add_time_ids, timesteps = self.scheduler.initialize(samples[i]) if self.is_img2img: + strength = 0.5 # should be user-facing init_timestep = min( - int(num_inference_steps * strength), num_inference_steps + int(self.num_inference_steps * strength), self.num_inference_steps ) - t_start = max(num_inference_steps - init_timestep, 0) + t_start = max(self.num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start:] latents = self.encode_image(image) latents = self.scheduler.add_noise( - latents, noise, timesteps[0].repeat(1) + latents, sample, timesteps[0].repeat(1) ) return latents, [timesteps] diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 3a12c40e8..1c77541c7 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -69,7 +69,7 @@ def compile_to_vmfb( ] ) device = "llvm-cpu" - elif device == "vulkan": + elif device in ["vulkan", "vulkan-spirv"]: flags.extend( [ "--iree-hal-target-backends=vulkan-spirv", diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py index ef25ee0ec..f70b82cbc 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py @@ -81,6 +81,7 @@ def __init__( external_weights_dir: str = "./shark_weights", external_weights: str = "safetensors", vae_decomp_attn: bool = True, + custom_vae: str = "", ): self.hf_model_name = hf_model_name self.scheduler_id = scheduler_id @@ -99,6 +100,7 @@ def __init__( self.external_weights_dir = external_weights_dir self.external_weights = external_weights self.vae_decomp_attn = vae_decomp_attn + self.custom_vae = custom_vae # FILE MANAGEMENT AND PIPELINE SETUP