Skip to content

Commit

Permalink
SDXL: fix scheduled unet modes
Browse files — Browse the repository at this point in the history
  • Loading branch information
monorimet committed Jun 20, 2024
1 parent 9bbbafc commit 7388e14
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,6 @@ def export_scheduler_model(
f"{height}x{width}",
precision,
str(num_inference_steps),
target_triple,
]
vmfb_name = "_".join(vmfb_names)
safe_name = utils.create_safe_name(hf_model_name, "_" + vmfb_name)
Expand Down
20 changes: 15 additions & 5 deletions models/turbine_models/custom_models/sd_inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,16 @@ def compile_to_vmfb(
flags.pop(idx)
print("Compiling to", device, "with flags:", flags)

# Forces a standard for naming files:
# If safe_name has target triple in it, get rid of target triple in mlir name
#
if target_triple not in safe_name:
safe_vmfb_name = safe_name + "_" + target_triple
safe_mlir_name = safe_name
else:
safe_vmfb_name = safe_name
safe_mlir_name = "".join(safe_name.split(target_triple))

if mlir_source == "file":
flatbuffer_blob = ireec.compile_file(
module_str,
Expand All @@ -249,9 +259,9 @@ def compile_to_vmfb(
)
elif mlir_source == "str":
if save_mlir:
with open(f"{safe_name}.mlir", "w+") as f:
with open(f"{safe_mlir_name}.mlir", "w+") as f:
f.write(module_str)
print("Saved to", safe_name + ".mlir")
print("Saved to", safe_mlir_name + ".mlir")
flatbuffer_blob = ireec.compile_str(
module_str,
target_backends=[device],
Expand All @@ -260,11 +270,11 @@ def compile_to_vmfb(
)
else:
raise ValueError("mlir_source must be either 'file' or 'str'")
with open(f"{safe_name}.vmfb", "wb+") as f:
with open(f"{safe_vmfb_name}.vmfb", "wb+") as f:
f.write(flatbuffer_blob)
print("Saved to", safe_name + ".vmfb")
print(f"Saved to {safe_vmfb_name}.vmfb")
if return_path == True:
return safe_name + ".vmfb"
return safe_vmfb_name + ".vmfb"


def create_safe_name(hf_model_name, model_name_str):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def is_valid_file(arg):

p.add_argument(
"--split_scheduler",
default=True,
default=False,
action="store_true",
help="Use a decoupled unet and scheduler for better QOL.",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def __init__(
self.custom_vae = custom_vae
self.cpu_scheduling = cpu_scheduling
self.compiled_pipeline = False
self.split_scheduler = False
# TODO: set this based on user-inputted guidance scale and negative prompt.
self.do_classifier_free_guidance = True # False if any(x in hf_model_name for x in ["turbo", "lightning"]) else True
self._interrupt = False
Expand Down Expand Up @@ -166,7 +167,7 @@ def check_prepared(
print("There was an error generating the necessary files.")
exit()
else:
print("All necessary files found. Loading pipeline.")
print("All necessary files found.")
return vmfbs, weights

def is_prepared(self, vmfbs, weights):
Expand All @@ -175,10 +176,11 @@ def is_prepared(self, vmfbs, weights):
for key in vmfbs:
if key == "scheduled_unet":
keywords = [
"unet",
"DiffusionModule",
self.scheduler_id,
self.num_inference_steps,
str(self.num_inference_steps),
self.precision,
self.max_length,
dims,
]
device_key = "unet"
Expand All @@ -192,38 +194,44 @@ def is_prepared(self, vmfbs, weights):
device_key = "clip"
else:
keywords = [key, self.precision, self.max_length, dims]
device_key = key
avail_files = os.listdir(self.pipeline_dir)
keywords.append("vmfb")
keywords.append(
utils.create_safe_name(self.hf_model_name.split("/")[-1], "")
device_key = "unet"
keywords.extend(
[
utils.create_safe_name(self.hf_model_name.split("/")[-1], ""),
"vmfb",
"bs" + str(self.batch_size),
self.devices[device_key]["target"],
]
)
keywords.append(self.devices[device_key]["target"])
avail_files = os.listdir(self.pipeline_dir)
for filename in avail_files:
if all(str(x) in filename for x in keywords):
vmfbs[key] = os.path.join(self.pipeline_dir, filename)
if not vmfbs[key]:
missing.append(key + " vmfb")

for w_key in weights:
if any(x in w_key for x in ["pipeline", "scheduler"]):
continue
if weights[w_key] is not None:
continue
if self.external_weights is None:
continue
default_name = os.path.join(
self.external_weights_dir, w_key + "." + self.external_weights
)
if weights[w_key] is None and os.path.exists(default_name):
weights[w_key] = os.path.join(default_name)
elif w_key in ["scheduled_unet"] and os.path.exists(
os.path.join(self.external_weights_dir, "unet." + self.external_weights)
if any(x in w_key for x in ["pipeline", "scheduler"]) or (
self.external_weights is None
):
weights[w_key] = os.path.join(
self.external_weights_dir, "unet." + self.external_weights
continue
elif weights[w_key] is not None:
print("Weights already found for ", w_key, "at: ", weights[w_key])
elif w_key == "vae_decode":
keywords = ["vae", self.vae_precision]
elif w_key in ["prompt_encoder", "clip"]:
keywords = ["prompt_encoder", self.precision]
elif w_key in ["scheduled_unet", "unet"]:
keywords = ["unet", self.precision]
avail_weights = os.listdir(self.external_weights_dir)
for filename in avail_weights:
if all(str(x) in filename for x in keywords):
weights[w_key] = os.path.join(self.external_weights_dir, filename)
if not weights[w_key]:
missing.append(
" ".join([keywords[0], keywords[1], self.external_weights])
)
else:
missing.append(w_key + "." + self.external_weights)

if len(missing) > 0:
print(f"Missing files: " + ", ".join(missing))
return False, vmfbs, weights
Expand Down Expand Up @@ -476,12 +484,20 @@ def export_submodel(
self.max_length,
"unet_loop",
)
pipeline_keys = [
utils.create_safe_name(self.hf_model_name.split("/")[-1], ""),
"bs" + str(self.batch_size),
f"{str(self.width)}x{str(self.height)}",
self.precision,
str(self.max_length),
"pipeline",
]
pipeline_vmfb = utils.compile_to_vmfb(
pipeline_file,
self.devices["unet"]["device"],
self.devices["unet"]["target"],
self.ireec_flags["pipeline"],
os.path.join(self.pipeline_dir, "pipeline"),
os.path.join(self.pipeline_dir, "_".join(pipeline_keys)),
return_path=True,
mlir_source="str",
)
Expand All @@ -495,12 +511,20 @@ def export_submodel(
self.max_length,
"tokens_to_image",
)
pipeline_keys = [
utils.create_safe_name(self.hf_model_name.split("/")[-1], ""),
"bs" + str(self.batch_size),
f"{str(self.width)}x{str(self.height)}",
self.precision,
str(self.max_length),
"full_pipeline",
]
pipeline_vmfb = utils.compile_to_vmfb(
pipeline_file,
self.devices["unet"]["device"],
self.devices["unet"]["target"],
self.ireec_flags["pipeline"],
os.path.join(self.pipeline_dir, "full_pipeline"),
os.path.join(self.pipeline_dir, "_".join(pipeline_keys)),
return_path=True,
mlir_source="str",
)
Expand Down Expand Up @@ -631,9 +655,13 @@ def generate_images(
progress=None,
):
needs_new_scheduler = (
steps and steps != self.num_inference_steps
) or cpu_scheduling != self.cpu_scheduling
(steps and steps != self.num_inference_steps)
or (cpu_scheduling != self.cpu_scheduling)
and self.split_scheduler
)

self.cpu_scheduling = cpu_scheduling

if steps and not self.compiled_pipeline and needs_new_scheduler:
self.num_inference_steps = steps
if (
Expand Down Expand Up @@ -953,13 +981,10 @@ def numpy_to_pil_image(images):

map = empty_pipe_dict
if args.split_scheduler:
map["scheduler"] = None
map["unet"] = None
map.pop("scheduled_unet")
map.pop("pipeline")
map.pop("full_pipeline")
if args.cpu_scheduling:
map.pop("scheduler")
mlirs = copy.deepcopy(map)
vmfbs = copy.deepcopy(map)
weights = copy.deepcopy(map)
Expand Down Expand Up @@ -1002,20 +1027,12 @@ def numpy_to_pil_image(images):
"scheduler": args.ireec_flags,
}
if not args.pipeline_dir:
pipe_id_list = [
args.hf_model_name.split("/")[-1],
str(args.height),
str(args.width),
str(args.max_length),
args.precision,
args.device,
]
if args.decomp_attn:
pipe_id_list.append("decomp")
args.pipeline_dir = os.path.join(
".",
"_".join(pipe_id_list),
utils.create_safe_name(args.hf_model_name, ""),
)
if not os.path.exists(args.pipeline_dir):
os.makedirs(args.pipeline_dir, exist_ok=True)
if args.input_mlir:
user_mlir_list = args.input_mlir.split(",")
else:
Expand All @@ -1027,14 +1044,15 @@ def numpy_to_pil_image(images):
args.external_weights_dir = args.pipeline_dir
sdxl_pipe = SharkSDXLPipeline(
args.hf_model_name,
args.scheduler_id,
args.height,
args.width,
args.precision,
args.max_length,
args.batch_size,
args.num_inference_steps,
devices,
targets,
args.scheduler_id,
ireec_flags,
args.attn_spec,
args.decomp_attn,
Expand All @@ -1045,9 +1063,9 @@ def numpy_to_pil_image(images):
custom_vae=None,
vae_precision=args.vae_precision,
)

vmfbs, weights = sdxl_pipe.check_prepared(mlirs, vmfbs, weights)
if args.cpu_scheduling:
vmfbs["scheduler"] = None

if args.npu_delegate_path:
extra_device_args = {"npu_delegate_path": args.npu_delegate_path}
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,18 @@ def export_scheduled_unet_model(
# else:
# do_classifier_free_guidance = True
do_classifier_free_guidance = True

filename_keys = [
f"bs{batch_size}",
str(max_length),
f"{height}x{width}",
precision,
scheduler_id,
"DiffusionModule",
str(num_inference_steps),
]
safe_name = utils.create_safe_name(
hf_model_name,
f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_scheduled_unet_{str(num_inference_steps)}",
"_".join(filename_keys),
)
if pipeline_dir:
safe_name = os.path.join(pipeline_dir, safe_name)
Expand Down

0 comments on commit 7388e14

Please sign in to comment.