Skip to content

Commit

Permalink
Merge branch 'nod-ai:ean-unify-sd' into ean-unify-sd
Browse files Browse the repository at this point in the history
  • Loading branch information
gpetters94 authored Jun 19, 2024
2 parents b45a6c5 + 618d01f commit 5a9aaa0
Show file tree
Hide file tree
Showing 11 changed files with 192 additions and 68 deletions.
1 change: 0 additions & 1 deletion models/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
protobuf
sentencepiece
shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main
transformers==4.37.1
torchsde
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,12 @@ def is_valid_file(arg):
default="fp16",
help="Precision of Stable Diffusion weights and graph.",
)
p.add_argument(
"--vae_precision",
type=str,
default=None,
help="Precision of Stable Diffusion VAE weights and graph.",
)
p.add_argument(
"--max_length", type=int, default=77, help="Sequence Length of Stable Diffusion"
)
Expand All @@ -257,7 +263,7 @@ def is_valid_file(arg):
p.add_argument(
"--vae_decomp_attn",
type=bool,
default=True,
default=False,
help="Decompose attention for VAE decode only at fx graph level",
)
p.add_argument(
Expand Down Expand Up @@ -340,6 +346,12 @@ def is_valid_file(arg):
action="store_true",
help="Just compile attention reproducer for mmdit.",
)
p.add_argument(
"--vae_input_path",
type=str,
default=None,
help="Path to input latents for VAE inference numerics validation.",
)


##############################################################################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def export_mmdit_model(
torch.empty(hidden_states_shape, dtype=dtype),
torch.empty(encoder_hidden_states_shape, dtype=dtype),
torch.empty(pooled_projections_shape, dtype=dtype),
torch.empty(1, dtype=dtype),
torch.empty(init_batch_dim, dtype=dtype),
]

decomp_list = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def find_errs(turbine_output, torch_output, dim=[], failed_dims=[], errs=[]):
(batch_size, args.max_length * 2, 4096), dtype=dtype
)
pooled_projections = torch.randn((batch_size, 2048), dtype=dtype)
timestep = torch.tensor([0], dtype=dtype)
timestep = torch.tensor([0, 0], dtype=dtype)

turbine_output = run_mmdit_turbine(
hidden_states,
Expand All @@ -180,6 +180,7 @@ def find_errs(turbine_output, torch_output, dim=[], failed_dims=[], errs=[]):
timestep,
args,
)
np.save("torch_mmdit_output.npy", torch_output.astype(np.float16))
print("torch OUTPUT:", torch_output, torch_output.shape, torch_output.dtype)

print("\n(torch (comfy) image latents to iree image latents): ")
Expand Down
141 changes: 98 additions & 43 deletions models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from turbine_models.custom_models.sd_inference import utils
from turbine_models.model_runner import vmfbRunner
from transformers import CLIPTokenizer
from diffusers import FlowMatchEulerDiscreteScheduler

from PIL import Image
import os
Expand Down Expand Up @@ -44,10 +45,8 @@ class SharkSD3Pipeline:
def __init__(
self,
hf_model_name: str,
# scheduler_id: str,
height: int,
width: int,
shift: float,
precision: str,
max_length: int,
batch_size: int,
Expand All @@ -60,9 +59,12 @@ def __init__(
pipeline_dir: str = "./shark_vmfbs",
external_weights_dir: str = "./shark_weights",
external_weights: str = "safetensors",
vae_decomp_attn: bool = True,
custom_vae: str = "",
vae_decomp_attn: bool = False,
cpu_scheduling: bool = False,
vae_precision: str = "fp32",
scheduler_id: str = None, #compatibility only, always uses EulerFlowScheduler
shift: float = 1.0,

):
self.hf_model_name = hf_model_name
# self.scheduler_id = scheduler_id
Expand Down Expand Up @@ -120,10 +122,11 @@ def __init__(
self.external_weights_dir = external_weights_dir
self.external_weights = external_weights
self.vae_decomp_attn = vae_decomp_attn
self.custom_vae = custom_vae
self.custom_vae = None
self.cpu_scheduling = cpu_scheduling
self.torch_dtype = torch.float32 if self.precision == "fp32" else torch.float16
self.vae_dtype = torch.float32
self.vae_precision = vae_precision if vae_precision else self.precision
self.vae_dtype = torch.float32 if vae_precision == "fp32" else torch.float16
# TODO: set this based on user-inputted guidance scale and negative prompt.
self.do_classifier_free_guidance = True # False if any(x in hf_model_name for x in ["turbo", "lightning"]) else True

Expand Down Expand Up @@ -206,7 +209,12 @@ def is_prepared(self, vmfbs, weights):
)
if w_key == "clip":
default_name = os.path.join(
self.external_weights_dir, f"sd3_clip_fp16.irpa"
self.external_weights_dir, f"sd3_text_encoders_{self.precision}.irpa"
)
if w_key == "mmdit":
default_name = os.path.join(
self.external_weights_dir,
f"sd3_mmdit_{self.precision}." + self.external_weights,
)
if weights[w_key] is None and os.path.exists(default_name):
weights[w_key] = os.path.join(default_name)
Expand Down Expand Up @@ -357,7 +365,7 @@ def export_submodel(
self.batch_size,
self.height,
self.width,
"fp32",
self.vae_precision,
"vmfb",
self.external_weights,
vae_external_weight_path,
Expand Down Expand Up @@ -419,10 +427,16 @@ def load_pipeline(
unet_loaded = time.time()
print("\n[LOG] MMDiT loaded in ", unet_loaded - load_start, "sec")

runners["scheduler"] = sd3_schedulers.SharkSchedulerWrapper(
self.devices["mmdit"]["driver"],
vmfbs["scheduler"],
)
if not self.cpu_scheduling:
runners["scheduler"] = sd3_schedulers.SharkSchedulerWrapper(
self.devices["mmdit"]["driver"],
vmfbs["scheduler"],
)
else:
print("Using torch CPU scheduler.")
runners["scheduler"] = FlowMatchEulerDiscreteScheduler.from_pretrained(
self.hf_model_name, subfolder="scheduler"
)

sched_loaded = time.time()
print("\n[LOG] Scheduler loaded in ", sched_loaded - unet_loaded, "sec")
Expand Down Expand Up @@ -495,11 +509,12 @@ def generate_images(
)
)

guidance_scale = ireert.asdevicearray(
self.runners["pipe"].config.device,
np.asarray([guidance_scale]),
dtype=iree_dtype,
)
if not self.cpu_scheduling:
guidance_scale = ireert.asdevicearray(
self.runners["pipe"].config.device,
np.asarray([guidance_scale]),
dtype=iree_dtype,
)

tokenize_start = time.time()
text_input_ids_dict = self.tokenizer.tokenize_with_weights(prompt)
Expand Down Expand Up @@ -533,12 +548,23 @@ def generate_images(
"clip"
].ctx.modules.compiled_text_encoder["encode_tokens"](*text_encoders_inputs)
encode_prompts_end = time.time()
if self.cpu_scheduling:
timesteps, num_inference_steps = sd3_schedulers.retrieve_timesteps(
self.runners["scheduler"],
num_inference_steps=self.num_inference_steps,
timesteps=None,
)
steps = num_inference_steps


for i in range(batch_count):
unet_start = time.time()
sample, steps, timesteps = self.runners["scheduler"].initialize(samples[i])
if not self.cpu_scheduling:
latents, steps, timesteps = self.runners["scheduler"].initialize(samples[i])
else:
latents = torch.tensor(samples[i].to_host(), dtype=self.torch_dtype)
iree_inputs = [
sample,
latents,
ireert.asdevicearray(
self.runners["pipe"].config.device, prompt_embeds, dtype=iree_dtype
),
Expand All @@ -553,40 +579,71 @@ def generate_images(
# print(f"step {s}")
if self.cpu_scheduling:
step_index = s
t = timesteps[s]
if self.do_classifier_free_guidance:
latent_model_input = torch.cat([latents] * 2)
timestep = ireert.asdevicearray(
self.runners["pipe"].config.device,
t.expand(latent_model_input.shape[0]),
dtype=iree_dtype,
)
latent_model_input = ireert.asdevicearray(
self.runners["pipe"].config.device,
latent_model_input,
dtype=iree_dtype,
)
else:
step_index = ireert.asdevicearray(
self.runners["scheduler"].runner.config.device,
torch.tensor([s]),
"int64",
)
latents, t = self.runners["scheduler"].prep(
sample,
step_index,
timesteps,
)
latent_model_input, timestep = self.runners["scheduler"].prep(
latents,
step_index,
timesteps,
)
t = ireert.asdevicearray(
self.runners["scheduler"].runner.config.device,
timestep.to_host()[0]
)
noise_pred = self.runners["pipe"].ctx.modules.compiled_mmdit[
"run_forward"
](
latents,
latent_model_input,
iree_inputs[1],
iree_inputs[2],
t,
timestep,
)
sample = self.runners["scheduler"].step(
noise_pred,
t,
sample,
guidance_scale,
step_index,
)
if isinstance(sample, torch.Tensor):
if not self.cpu_scheduling:
latents = self.runners["scheduler"].step(
noise_pred,
t,
latents,
guidance_scale,
step_index,
)
else:
noise_pred = torch.tensor(noise_pred.to_host(), dtype=self.torch_dtype)
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = self.runners["scheduler"].step(
noise_pred,
t,
latents,
return_dict=False,
)[0]

if isinstance(latents, torch.Tensor):
latents = latents.type(self.vae_dtype)
latents = ireert.asdevicearray(
self.runners["vae"].config.device,
sample,
dtype=self.vae_dtype,
latents,
)
else:
latents = sample.astype("float32")
vae_numpy_dtype = np.float32 if self.vae_precision == "fp32" else np.float16
latents = latents.astype(vae_numpy_dtype)

vae_start = time.time()
vae_out = self.runners["vae"].ctx.modules.compiled_vae["decode"](latents)
Expand Down Expand Up @@ -634,7 +691,7 @@ def generate_images(
out_image = Image.fromarray(image)
images.extend([[out_image]])
if return_imgs:
return images
return images[0]
for idx_batch, image_batch in enumerate(images):
for idx, image in enumerate(image_batch):
img_path = (
Expand Down Expand Up @@ -767,7 +824,6 @@ def run_diffusers_cpu(
args.hf_model_name,
args.height,
args.width,
args.shift,
args.precision,
args.max_length,
args.batch_size,
Expand All @@ -779,16 +835,15 @@ def run_diffusers_cpu(
args.decomp_attn,
args.pipeline_dir,
args.external_weights_dir,
args.external_weights,
args.vae_decomp_attn,
custom_vae=None,
external_weights=args.external_weights,
vae_decomp_attn=args.vae_decomp_attn,
cpu_scheduling=args.cpu_scheduling,
vae_precision=args.vae_precision,
)
vmfbs, weights = sd3_pipe.check_prepared(mlirs, vmfbs, weights)
if args.cpu_scheduling:
vmfbs.pop("scheduler")
weights.pop("scheduler")
vmfbs, weights = sd3_pipe.check_prepared(mlirs, vmfbs, weights)
if args.npu_delegate_path:
extra_device_args = {"npu_delegate_path": args.npu_delegate_path}
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import os
import inspect
from typing import List

import torch
from typing import Any, Callable, Dict, List, Optional, Union
from shark_turbine.aot import *
import shark_turbine.ops.iree as ops
from iree.compiler.ir import Context
Expand Down Expand Up @@ -75,11 +77,12 @@ def initialize(self, sample):

def prepare_model_input(self, sample, t, timesteps):
t = timesteps[t]
t = t.expand(sample.shape[0])

if self.do_classifier_free_guidance:
latent_model_input = torch.cat([sample] * 2)
else:
latent_model_input = sample
t = t.expand(latent_model_input.shape[0])
return latent_model_input.type(self.dtype), t.type(self.dtype)

def step(self, noise_pred, t, sample, guidance_scale, i):
Expand Down Expand Up @@ -146,6 +149,42 @@ def step(self, noise_pred, t, latents, guidance_scale, i):
return_dict=False,
)[0]

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Only used for cpu scheduling.
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps

@torch.no_grad()
def export_scheduler_model(
Expand Down
Loading

0 comments on commit 5a9aaa0

Please sign in to comment.