Merge branch 'ean-unify-sd' into ean-unify-sd-staging
monorimet authored Jul 9, 2024
2 parents 491d8ad + 79a094f commit 46856a2
Showing 3 changed files with 48 additions and 23 deletions.
File 1 of 3:

@@ -222,6 +222,13 @@ def is_valid_file(arg):
##############################################################################

p.add_argument("--batch_size", type=int, default=1, help="Batch size for inference")
+p.add_argument(
+    "--batch_prompt_input",
+    type=bool,
+    default=False,
+    help="If batch size > 1, batch the prompt encoder inputs rather than concatenating the prompt encoder's outputs.",
+)
+
p.add_argument(
"--height", type=int, default=1024, help="Height of Stable Diffusion output image."
)
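Note (not part of the diff): argparse's `type=bool` does not parse boolean strings; it calls `bool()` on the raw argument, so any non-empty value, including `--batch_prompt_input False`, parses as `True`. A minimal sketch of the pitfall and the usual `store_true` alternative:

```python
import argparse

# As in the diff: bool("False") is True, because any non-empty string is truthy.
p = argparse.ArgumentParser()
p.add_argument("--batch_prompt_input", type=bool, default=False)
print(p.parse_args(["--batch_prompt_input", "False"]).batch_prompt_input)  # True

# A common alternative: a flag that is False unless present.
p2 = argparse.ArgumentParser()
p2.add_argument("--batch_prompt_input", action="store_true")
print(p2.parse_args([]).batch_prompt_input)                        # False
print(p2.parse_args(["--batch_prompt_input"]).batch_prompt_input)  # True
```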

File 2 of 3:

@@ -68,6 +68,7 @@ def __init__(
custom_vae: str = "",
cpu_scheduling: bool = False,
vae_precision: str = "fp32",
+batch_prompt_input: bool = False,
):
self.hf_model_name = hf_model_name
self.scheduler_id = scheduler_id
@@ -76,6 +77,7 @@ def __init__(
self.precision = precision
self.max_length = max_length
self.batch_size = batch_size
+self.batch_prompt_input = batch_prompt_input
self.num_inference_steps = num_inference_steps
self.devices = {}
if isinstance(device, dict):
@@ -375,7 +377,7 @@ def export_submodel(
"vmfb",
self.external_weights,
unet_external_weight_path,
-self.devices["unet"]["device"],
+self.devices["unet"]["driver"],
self.devices["unet"]["target"],
self.ireec_flags["unet"],
self.decomp_attn,
@@ -404,7 +406,7 @@ def export_submodel(
"vmfb",
self.external_weights,
unet_external_weight_path,
-self.devices["unet"]["device"],
+self.devices["unet"]["driver"],
self.devices["unet"]["target"],
self.ireec_flags["unet"],
self.decomp_attn,
@@ -429,7 +431,7 @@ def export_submodel(
self.num_inference_steps,
self.precision,
"vmfb",
-self.devices["unet"]["device"],
+self.devices["unet"]["driver"],
self.devices["unet"]["target"],
self.ireec_flags["scheduler"],
exit_on_vmfb=False,
@@ -460,7 +462,7 @@ def export_submodel(
"vmfb",
self.external_weights,
vae_external_weight_path,
self.devices["vae"]["device"],
self.devices["vae"]["driver"],
self.devices["vae"]["target"],
self.ireec_flags["vae"],
"decode",
@@ -482,7 +484,7 @@ def export_submodel(
"vmfb",
self.external_weights,
prompt_encoder_external_weight_path,
self.devices["clip"]["device"],
self.devices["clip"]["driver"],
self.devices["clip"]["target"],
self.ireec_flags["clip"],
exit_on_vmfb=False,
@@ -492,7 +494,8 @@ def export_submodel(
input_mlir=input_mlir["prompt_encoder"],
attn_spec=self.attn_spec,
weights_only=weights_only,
-output_batchsize=self.batch_size,
+batchsize=self.batch_size,
+batch_input=self.batch_prompt_input,
)
return prompt_encoder_vmfb, prompt_encoder_external_weight_path
case "unetloop":
@@ -514,7 +517,7 @@
]
pipeline_vmfb = utils.compile_to_vmfb(
pipeline_file,
-self.devices["unet"]["device"],
+self.devices["unet"]["driver"],
self.devices["unet"]["target"],
self.ireec_flags["unetloop"],
os.path.join(self.pipeline_dir, "_".join(pipeline_keys)),
@@ -541,7 +544,7 @@
]
pipeline_vmfb = utils.compile_to_vmfb(
pipeline_file,
-self.devices["unet"]["device"],
+self.devices["unet"]["driver"],
self.devices["unet"]["target"],
self.ireec_flags["unetloop"],
os.path.join(self.pipeline_dir, "_".join(pipeline_keys)),
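Note: the same one-line change recurs for every submodel above; `utils.compile_to_vmfb` now receives the IREE runtime driver name from the per-submodel device dict instead of the device identifier. The diff does not show how `self.devices` is populated, so the following is a hypothetical sketch of the assumed structure, with illustrative values only:

```python
# Hypothetical sketch: keys "driver" and "target" appear in the diff;
# "device" is the key the old code passed. All values are illustrative.
devices = {
    "unet": {
        "device": "hip://0",   # full runtime device URI (old argument)
        "driver": "hip",       # runtime driver name (new argument)
        "target": "gfx942",    # compiler target architecture
    },
}

# After this change, compilation is keyed on the driver:
driver, target = devices["unet"]["driver"], devices["unet"]["target"]
```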
@@ -1082,6 +1085,7 @@ def numpy_to_pil_image(images):
args.vae_decomp_attn,
custom_vae=None,
vae_precision=args.vae_precision,
+batch_prompt_input=args.batch_prompt_input,
)

vmfbs, weights = sdxl_pipe.check_prepared(mlirs, vmfbs, weights)

File 3 of 3:

@@ -27,6 +27,7 @@ def __init__(
hf_auth_token=None,
do_classifier_free_guidance=True,
batch_size=1,
+batch_input=False,
):
super().__init__()
self.torch_dtype = torch.float16 if precision == "fp16" else torch.float32
@@ -42,6 +43,7 @@ def __init__(
)
self.do_classifier_free_guidance = True
self.batch_size = batch_size
+self.batch_input = batch_input

def forward(
self, text_input_ids_1, text_input_ids_2, uncond_input_ids_1, uncond_input_ids_2
@@ -85,20 +87,25 @@ def forward(
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(
bs_embed * 1, -1
)
-prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1)
+if not self.batch_input:
+    prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1)
add_text_embeds = pooled_prompt_embeds
-add_text_embeds = add_text_embeds.repeat(self.batch_size, 1)
+if not self.batch_input:
+    add_text_embeds = add_text_embeds.repeat(self.batch_size, 1)
if self.do_classifier_free_guidance:
-neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(1, 1).view(
-    1, -1
-)
+if not self.batch_input:
+    neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(
+        1, 1
+    ).view(1, -1)
neg_prompt_embeds = neg_prompt_embeds.repeat(1, 1, 1)
neg_prompt_embeds = neg_prompt_embeds.view(bs_embed * 1, seq_len, -1)
-neg_prompt_embeds = neg_prompt_embeds.repeat(self.batch_size, 1, 1)
+if not self.batch_input:
+    neg_prompt_embeds = neg_prompt_embeds.repeat(self.batch_size, 1, 1)
prompt_embeds = torch.cat([neg_prompt_embeds, prompt_embeds], dim=0)
-neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(
-    self.batch_size, 1
-)
+if not self.batch_input:
+    neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(
+        self.batch_size, 1
+    )
add_text_embeds = torch.cat(
[neg_pooled_prompt_embeds, add_text_embeds], dim=0
)
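Note: the guarded `repeat` calls above are the heart of the change. With `batch_input=False` (the default, matching the old behavior), a single prompt is encoded and its embeddings are tiled out to `batch_size`; with `batch_input=True`, the token ids already carry a leading `batch_size` dimension, so the tiling is skipped. A shape-level sketch in plain torch (sizes are illustrative):

```python
import torch

batch_size, seq_len, dim = 2, 64, 2048  # illustrative sizes

# batch_input=False: encode one prompt, then tile the result to batch_size.
prompt_embeds = torch.randn(1, seq_len, dim)            # encoder output for one prompt
prompt_embeds = prompt_embeds.repeat(batch_size, 1, 1)  # -> (2, 64, 2048)

# batch_input=True: the encoder is traced at batch size 2, so its output
# already has shape (batch_size, seq_len, dim) and no repeat is needed.
batched_embeds = torch.randn(batch_size, seq_len, dim)

assert prompt_embeds.shape == batched_embeds.shape
```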
@@ -163,6 +170,7 @@ def export_prompt_encoder(
input_mlir=None,
attn_spec=None,
weights_only=False,
+batch_input=False,
):
do_classifier_free_guidance = True

@@ -206,7 +214,13 @@
hf_auth_token,
do_classifier_free_guidance,
batch_size=batch_size,
+batch_input=batch_input,
)
+
+input_batchsize = 1
+if batch_input:
+    input_batchsize = batch_size
+
if precision == "fp16":
prompt_encoder_module = prompt_encoder_module.half()
mapper = {}
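Note: the three lines introducing `input_batchsize` could equivalently be written as a conditional expression; a trivial sketch:

```python
# Equivalent to the three-line version added in the diff.
input_batchsize = batch_size if batch_input else 1
```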
@@ -231,19 +245,19 @@ class CompiledClip(CompiledModule):

def encode_prompts(
self,
-t_ids_1=AbstractTensor(1, max_length, dtype=torch.int64),
-t_ids_2=AbstractTensor(1, max_length, dtype=torch.int64),
-uc_ids_1=AbstractTensor(1, max_length, dtype=torch.int64),
-uc_ids_2=AbstractTensor(1, max_length, dtype=torch.int64),
+t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
+t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
+uc_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
+uc_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
):
return jittable(prompt_encoder_module.forward)(
t_ids_1, t_ids_2, uc_ids_1, uc_ids_2
)

def encode_prompts_turbo(
self,
-t_ids_1=AbstractTensor(1, max_length, dtype=torch.int64),
-t_ids_2=AbstractTensor(1, max_length, dtype=torch.int64),
+t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
+t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
):
return jittable(prompt_encoder_module.forward_turbo)(t_ids_1, t_ids_2)

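Note: with this change the exported `encode_prompts` entrypoint is traced at `(input_batchsize, max_length)` rather than a fixed `(1, max_length)`, so callers must supply one token sequence per batched prompt when `batch_input` is set. A hedged sketch of preparing matching int64 inputs with a HuggingFace CLIP tokenizer (checkpoint name, sizes, and prompts are illustrative, not taken from the diff):

```python
import torch
from transformers import CLIPTokenizer

batch_size, max_length = 2, 64  # must match the values used at export time

# Illustrative checkpoint; the pipeline's actual hf_model_name may differ.
tokenizer = CLIPTokenizer.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="tokenizer"
)

prompts = ["a photo of a corgi", "a watercolor of a lighthouse"]
ids = tokenizer(
    prompts,
    padding="max_length",
    max_length=max_length,
    truncation=True,
    return_tensors="pt",
).input_ids

# Shape and dtype match the AbstractTensor signature above.
assert ids.shape == (batch_size, max_length) and ids.dtype == torch.int64
```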
