Skip to content

Commit

Permalink
Small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed May 23, 2024
1 parent 4dd179a commit 07e21e4
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
14 changes: 6 additions & 8 deletions models/turbine_models/custom_models/sd_inference/sd_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,10 +261,7 @@ def export_submodel(
return clip_vmfb, clip_external_weight_path
case "scheduler":
if self.cpu_scheduling:
return (
schedulers.get_scheduler(self.hf_model_name, self.scheduler_id),
None,
)
return (None, None)
scheduler = schedulers.export_scheduler(
self.hf_model_name,
self.scheduler_id,
Expand Down Expand Up @@ -368,7 +365,7 @@ def load_pipeline(
runners["clip"] = vmfbRunner(rt_device, vmfbs["clip"], weights["clip"])
if self.cpu_scheduling:
self.scheduler = schedulers.SchedulingModel(
vmfbs["scheduler"],
schedulers.get_scheduler(self.hf_model_name, self.scheduler_id),
self.height,
self.width,
self.num_inference_steps,
Expand Down Expand Up @@ -449,14 +446,15 @@ def generate_images(
sample, add_time_ids, timesteps = self.scheduler.initialize(samples[i])

if self.is_img2img:
strength = 0.5 # should be user-facing
init_timestep = min(
int(num_inference_steps * strength), num_inference_steps
int(self.num_inference_steps * strength), self.num_inference_steps
)
t_start = max(num_inference_steps - init_timestep, 0)
t_start = max(self.num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start:]
latents = self.encode_image(image)
latents = self.scheduler.add_noise(
latents, noise, timesteps[0].repeat(1)
latents, sample, timesteps[0].repeat(1)
)
return latents, [timesteps]

Expand Down
2 changes: 1 addition & 1 deletion models/turbine_models/custom_models/sd_inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def compile_to_vmfb(
]
)
device = "llvm-cpu"
elif device == "vulkan":
elif device in ["vulkan", "vulkan-spirv"]:
flags.extend(
[
"--iree-hal-target-backends=vulkan-spirv",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __init__(
external_weights_dir: str = "./shark_weights",
external_weights: str = "safetensors",
vae_decomp_attn: bool = True,
custom_vae: str = "",
):
self.hf_model_name = hf_model_name
self.scheduler_id = scheduler_id
Expand All @@ -99,6 +100,7 @@ def __init__(
self.external_weights_dir = external_weights_dir
self.external_weights = external_weights
self.vae_decomp_attn = vae_decomp_attn
self.custom_vae = custom_vae

# FILE MANAGEMENT AND PIPELINE SETUP

Expand Down

0 comments on commit 07e21e4

Please sign in to comment.