Make pipeline mode names mutually exclusive and fix weights loading
monorimet committed Jun 21, 2024
1 parent 7388e14 commit e399fe7
Showing 3 changed files with 42 additions and 20 deletions.
@@ -65,6 +65,7 @@ def __init__(
         vae_precision: str = "fp32",
         scheduler_id: str = None,  # compatibility only, always uses EulerFlowScheduler
         shift: float = 1.0,
+        custom_vae: str = None,
     ):
         self.hf_model_name = hf_model_name
         # self.scheduler_id = scheduler_id
@@ -122,7 +123,7 @@ def __init__(
         self.external_weights_dir = external_weights_dir
         self.external_weights = external_weights
         self.vae_decomp_attn = vae_decomp_attn
-        self.custom_vae = None
+        self.custom_vae = custom_vae
         self.cpu_scheduling = cpu_scheduling
         self.torch_dtype = torch.float32 if self.precision == "fp32" else torch.float16
         self.vae_precision = vae_precision if vae_precision else self.precision
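Before this change, `self.custom_vae` was hard-coded to `None` in the constructor body, so any caller-supplied custom VAE was silently discarded; the new `custom_vae` parameter threads the value through. A minimal usage sketch, assuming a pipeline class with this constructor (the class name, model id, and other arguments are illustrative, not the exact signature):

# Hypothetical caller; only custom_vae comes from the diff above.
pipe = SDPipeline(
    hf_model_name="stabilityai/stable-diffusion-3-medium-diffusers",
    vae_precision="fp32",
    shift=1.0,
    custom_vae="my_finetuned_vae.safetensors",  # previously ignored, now stored
)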
@@ -32,8 +32,8 @@
     "vae_decode": None,
     "prompt_encoder": None,
     "scheduled_unet": None,
-    "pipeline": None,
-    "full_pipeline": None,
+    "unetloop": None,
+    "fullpipeline": None,
 }

 EMPTY_FLAGS = {
@@ -117,6 +117,12 @@ def __init__(
         self.vae_precision = vae_precision
         self.vae_dtype = "float32" if vae_precision == "fp32" else "float16"
         self.custom_vae = custom_vae
+        if self.custom_vae:
+            self.vae_dir = os.path.join(
+                self.pipeline_dir, utils.create_safe_name(custom_vae, "")
+            )
+            if not os.path.exists(self.vae_dir):
+                os.makedirs(self.vae_dir)
         self.cpu_scheduling = cpu_scheduling
         self.compiled_pipeline = False
         self.split_scheduler = False
@@ -173,6 +179,7 @@ def check_prepared(
     def is_prepared(self, vmfbs, weights):
         missing = []
         dims = f"{str(self.width)}x{str(self.height)}"
+        pipeline_dir = self.pipeline_dir
         for key in vmfbs:
             if key == "scheduled_unet":
                 keywords = [
@@ -189,6 +196,8 @@ def is_prepared(self, vmfbs, weights):
             elif key == "vae_decode":
                 keywords = ["vae", self.vae_precision, dims]
                 device_key = "vae"
+                if self.custom_vae:
+                    pipeline_dir = self.vae_dir
             elif key == "prompt_encoder":
                 keywords = ["prompt_encoder", self.precision, self.max_length]
                 device_key = "clip"
@@ -203,10 +212,10 @@ def is_prepared(self, vmfbs, weights):
                     self.devices[device_key]["target"],
                 ]
             )
-            avail_files = os.listdir(self.pipeline_dir)
+            avail_files = os.listdir(pipeline_dir)
             for filename in avail_files:
                 if all(str(x) in filename for x in keywords):
-                    vmfbs[key] = os.path.join(self.pipeline_dir, filename)
+                    vmfbs[key] = os.path.join(pipeline_dir, filename)
             if not vmfbs[key]:
                 missing.append(key + " vmfb")

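This hunk is what motivates the commit title: vmfb lookup is pure substring matching via `all(str(x) in filename for x in keywords)`, and the old mode keyword "pipeline" is a substring of "full_pipeline", so a unet-loop lookup could resolve to a full-pipeline artifact. The renamed keywords "unetloop" and "fullpipeline" never occur in each other's filenames. A self-contained illustration (filenames invented):

# Invented filenames; demonstrates the substring collision the rename removes.
fullpipe_old = "sdxl_1024x1024_fp16_64_full_pipeline_gfx1100.vmfb"
assert all(k in fullpipe_old for k in ["fp16", "pipeline"])      # old keyword matches the wrong file

fullpipe_new = "sdxl_1024x1024_fp16_64_fullpipeline_gfx1100.vmfb"
assert not all(k in fullpipe_new for k in ["fp16", "unetloop"])  # new names are mutually exclusive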
@@ -432,6 +441,14 @@ def export_submodel(
                 vae_torch = self.get_torch_models("vae_decode")
             else:
                 vae_torch = None
+            if self.custom_vae:
+                vae_external_weight_path = os.path.join(
+                    self.vae_dir,
+                    f"vae_decode_{self.vae_precision}." + self.external_weights,
+                )
+                vae_dir = self.vae_dir
+            else:
+                vae_dir = self.pipeline_dir
             vae_decode_vmfb = vae.export_vae_model(
                 vae_torch,
                 self.hf_model_name,
@@ -448,7 +465,7 @@
                 "decode",
                 self.vae_decomp_attn,
                 exit_on_vmfb=False,
-                pipeline_dir=self.pipeline_dir,
+                pipeline_dir=vae_dir,
                 attn_spec=self.attn_spec,
                 input_mlir=input_mlir["vae_decode"],
                 weights_only=weights_only,
@@ -468,7 +485,9 @@
                 self.devices["clip"]["target"],
                 self.ireec_flags["clip"],
                 exit_on_vmfb=False,
-                pipeline_dir=self.pipeline_dir,
+                pipeline_dir=(
+                    self.pipeline_dir if not self.custom_vae else self.vae_dir
+                ),
                 input_mlir=input_mlir["prompt_encoder"],
                 attn_spec=self.attn_spec,
                 weights_only=weights_only,
@@ -490,7 +509,7 @@
                 f"{str(self.width)}x{str(self.height)}",
                 self.precision,
                 str(self.max_length),
-                "pipeline",
+                "unetloop",
             ]
             pipeline_vmfb = utils.compile_to_vmfb(
                 pipeline_file,
@@ -517,7 +536,7 @@
                 f"{str(self.width)}x{str(self.height)}",
                 self.precision,
                 str(self.max_length),
-                "full_pipeline",
+                "fullpipeline",
             ]
             pipeline_vmfb = utils.compile_to_vmfb(
                 pipeline_file,
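On the weights-loading side, the export path above now routes a custom VAE's vmfb and external weights into a dedicated subdirectory (`self.vae_dir`) named after the custom VAE, so its artifacts can never shadow, or be shadowed by, the default VAE's files in `pipeline_dir`. A rough sketch of the resulting layout, using a stand-in for `utils.create_safe_name` (an assumption; the real helper may normalize names differently):

import os
import re

def safe_name(model_name: str, suffix: str = "") -> str:
    # Stand-in for utils.create_safe_name: collapse a repo id or path into
    # a filesystem-friendly token. Illustrative only.
    return re.sub(r"[^A-Za-z0-9_]+", "_", model_name).strip("_") + suffix

pipeline_dir = "./vmfbs/sdxl_1024x1024_fp16"
custom_vae = "madebyollin/sdxl-vae-fp16-fix"

vae_dir = os.path.join(pipeline_dir, safe_name(custom_vae))
os.makedirs(vae_dir, exist_ok=True)
# The vae_decode vmfb and "vae_decode_fp32.safetensors"-style weights for
# this custom VAE now live under vae_dir instead of pipeline_dir.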
models/turbine_models/custom_models/sdxl_inference/vae.py (24 changes: 13 additions & 11 deletions)
@@ -19,6 +19,7 @@
 import torch
 import torch._dynamo as dynamo
 from diffusers import AutoencoderKL
+import safetensors


 class VaeModel(torch.nn.Module):
@@ -34,6 +35,14 @@ def __init__(
                 hf_model_name,
                 subfolder="vae",
             )
+        elif "safetensors" in custom_vae:
+            custom_vae = safetensors.torch.load_file(custom_vae)
+            # custom vae as a HF state dict
+            self.vae = AutoencoderKL.from_pretrained(
+                hf_model_name,
+                subfolder="vae",
+            )
+            self.vae.load_state_dict(custom_vae)
         elif not isinstance(custom_vae, dict):
             try:
                 # custom HF repo with no vae subfolder
@@ -46,13 +55,6 @@ def __init__(
                     custom_vae,
                     subfolder="vae",
                 )
-        else:
-            # custom vae as a HF state dict
-            self.vae = AutoencoderKL.from_pretrained(
-                hf_model_name,
-                subfolder="vae",
-            )
-            self.vae.load_state_dict(custom_vae)

     def decode(self, inp):
         img = 1 / 0.13025 * inp
@@ -104,10 +106,10 @@ def export_vae_model(
             attn_spec=attn_spec,
         )
         return vmfb_path
-    if precision == "fp32" and device == "rocm":
-        decomp_attn = True
-        external_weights = None
-        print("Decomposing attention and inlining weights for fp32 VAE on ROCm")
+    # if precision == "fp32" and device == "rocm":
+    #     decomp_attn = True
+    #     external_weights = None
+    #     print("Decomposing attention and inlining weights for fp32 VAE on ROCm")
     if device == "cpu":
         decomp_attn = True

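For reference, the new safetensors branch in `VaeModel.__init__` follows the standard diffusers pattern: build the stock `AutoencoderKL` from the base checkpoint, then overwrite its weights with the state dict loaded from disk. A minimal standalone sketch (repo id and file path are illustrative); it imports `safetensors.torch` explicitly, since a bare `import safetensors` only exposes the torch submodule once something else (e.g. diffusers) has imported it:

import safetensors.torch
from diffusers import AutoencoderKL

# Instantiate the base VAE, then swap in custom weights from disk.
vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="vae",
)
state_dict = safetensors.torch.load_file("sdxl_vae_fix.safetensors")  # illustrative path
vae.load_state_dict(state_dict)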
