Adds SDXL support and CI testing, benchmarks. #271
Conversation
LGTM
test failure seems real
@monorimet I think you need to update the run_unet function with guidance_scale in the definition, add it as an ireert.asdevicearray arg, and update the run_unet test to use guidance_scale instead of the run_torch_unet function (see the sketch after this thread).
Oops, thanks for taking a look. I'll fix it shortly.
There's a correctness difference. I probably need to update the torch implementation. I'll come back to it tomorrow.
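A minimal sketch of the change suggested above. The argument list and the compiled_unet handle are assumptions about the repo's actual runner code, which isn't shown here; only the ireert.asdevicearray wrapping follows the review comment.

import numpy as np
from iree import runtime as ireert

def run_unet(device, sample, timestep, encoder_hidden_states, guidance_scale, compiled_unet):
    # compiled_unet stands in for the loaded vmfb module's entry function.
    inputs = [
        ireert.asdevicearray(device, sample),
        ireert.asdevicearray(device, timestep),
        ireert.asdevicearray(device, encoder_hidden_states),
        # guidance_scale now travels to the compiled module as a device array
        # instead of being applied on the torch side.
        ireert.asdevicearray(device, np.asarray([guidance_scale], dtype=np.float32)),
    ]
    return compiled_unet(*inputs)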
Force-pushed from 155ae04 to 4498486.
Force-pushed from c1aebad to 1cec8dd.
os.remove("stable_diffusion_v1_4_clip.safetensors") | ||
os.remove("stable_diffusion_v1_4_clip.vmfb") | ||
# os.remove(f"{arguments['safe_model_name']}_clip.safetensors") | ||
# os.remove(f"{arguments['safe_model_name']}_clip.vmfb") |
Why are these commented out?
There's some race condition between this cleanup and the actual test execution. Will leave these comments unresolved until test file management is fixed (a sketch of one possible fix follows).
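A hedged sketch of one way to fix the file-management race: give each test its own temporary directory so cleanup can't collide with another test's artifacts. Names and layout here are illustrative, not the repo's actual test structure.

import os
import tempfile
import unittest

class StableDiffusionTest(unittest.TestCase):
    def setUp(self):
        # Every test writes its artifacts (.safetensors, .vmfb) under an
        # isolated directory instead of the shared working directory.
        self._tmp = tempfile.TemporaryDirectory()
        self.artifact_dir = self._tmp.name

    def tearDown(self):
        # Removes the directory and everything in it; no per-file os.remove
        # calls racing against other tests.
        self._tmp.cleanup()

    def artifact_path(self, name):
        return os.path.join(self.artifact_dir, name)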
os.remove("stable_diffusion_v1_4_unet.safetensors") | ||
os.remove("stable_diffusion_v1_4_unet.vmfb") | ||
# os.remove(f"{arguments['safe_model_name']}_unet.safetensors") | ||
# os.remove(f"{arguments['safe_model_name']}_unet.vmfb") |
same as above
os.remove("stable_diffusion_v1_4_vae.safetensors") | ||
os.remove("stable_diffusion_v1_4_vae.vmfb") | ||
# os.remove(f"{arguments['safe_model_name']}_vae.safetensors") | ||
# os.remove(f"{arguments['safe_model_name']}_vae.vmfb") |
same
os.remove("stable_diffusion_v1_4_vae.safetensors") | ||
os.remove("stable_diffusion_v1_4_vae.vmfb") | ||
# os.remove(f"{arguments['safe_model_name']}_vae.safetensors") | ||
# os.remove(f"{arguments['safe_model_name']}_vae.vmfb") |
same
class StableDiffusionTest(unittest.TestCase):
    # def testExportClipModel(self):
Lots of commented code here?
@@ -0,0 +1,190 @@
# Copyright 2023 Nod Labs, Inc
Looks like a lot of repeated code from sd_inference/unet.py. Can we combine the two?
@@ -0,0 +1,163 @@
import argparse
combine with other unet_runner?
Force-pushed from b1459fc to a0879c7.
need to robustify without SDPA decomps and add SDXL scheduler examples
Force-pushed from 4e2801f to 14fc107.
Force-pushed from 0d5e913 to 0b66db8.
didn't they fix sdpa torch lowering?
Yes and no. With lowering to iree_linalg_ext.attention we have numerics issues on CPU after tile-and-decompose-attention, a pending implementation for Vulkan, and tiling for shared memory on ROCm still to do -- see iree-org/iree#16421. So the torch->linalg lowering is OK, but the HAL backends need better tile-and-decompose support.
Force-pushed from 56a8bfe to 4b28a12.
I think we need this commit to be ported over to the upstream fx importer: 30ef1fc
Force-pushed from fa6ba50 to 89274eb.
Force-pushed from 80fe4ce to df1002e.
There seems to be a lot of repeated code between SD and SDXL. It doesn't need to be done now, but can you add TODOs or create an issue to work on combining the two to get rid of that?
@@ -214,6 +214,13 @@ def flat_wrapped_f(*args):
if "functorch_functionalize" in self._passes:
    transformed_f = functorch_functionalize(transformed_f, *flat_pytorch_args)

for node in transformed_f.graph.nodes:  # type: ignore
Isn't this being done in fx_importer.py? If we need it here instead, should we remove it there?
Seems to be covered by torch-mlir's fx importer. Removing it here.
for node in transformed_f.graph.nodes:  # type: ignore
    if node.op == "call_function":
        if node.target == torch._ops.ops.aten.lift_fresh_copy.default:
            print(f"replaced lift_fresh_copy")
remove printout?
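For reference, a sketch of the kind of graph rewrite this loop performs, without the printout. The diff above is truncated, so swapping the target to aten.clone.default is an assumption about the elided body.

import torch
import torch.fx

def replace_lift_fresh_copy(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        if (
            node.op == "call_function"
            and node.target == torch.ops.aten.lift_fresh_copy.default
        ):
            # Swap the op for an equivalent one the importer can handle.
            node.target = torch.ops.aten.clone.default
    gm.graph.lint()
    gm.recompile()
    return gm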
decomp_list = DEFAULT_DECOMPOSITIONS

decomp_list.extend(
Shouldn't the SDPA decomp be flag-guarded?
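One way the suggested flag guard could look; the decompose_attn flag is hypothetical, and the empty DEFAULT_DECOMPOSITIONS is a stand-in for the repo's real list.

import torch

DEFAULT_DECOMPOSITIONS = []  # stand-in; the repo supplies the real list

def get_decompositions(decompose_attn):
    # Copy first so the shared default list isn't mutated in place.
    decomp_list = list(DEFAULT_DECOMPOSITIONS)
    if decompose_attn:
        decomp_list.extend(
            [
                torch.ops.aten._scaled_dot_product_flash_attention.default,
            ]
        )
    return decomp_list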
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    ]
)
# encoder_hidden_states_sizes = (2, 77, 768)
remove commented code?
)

winograd_params = "keys=unet.down_blocks.2.resnets.0.conv2.weight keys=unet.down_blocks.2.resnets.1.conv1.weight keys=unet.down_blocks.2.resnets.1.conv2.weight keys=unet.mid_block.resnets.0.conv1.weight keys=unet.mid_block.resnets.0.conv2.weight keys=unet.mid_block.resnets.1.conv1.weight keys=unet.mid_block.resnets.1.conv2.weight keys=unet.up_blocks.0.resnets.0.conv2.weight keys=unet.up_blocks.0.resnets.1.conv2.weight keys=unet.up_blocks.0.resnets.2.conv2.weight keys=unet.up_blocks.0.resnets.0.conv1.weight keys=unet.up_blocks.0.resnets.1.conv1.weight keys=unet.up_blocks.0.resnets.2.conv1.weight keys=unet.up_blocks.0.upsamplers.0.conv.weight"
We probably shouldn't have all of these winograd_params and flags hardcoded. Probably fine for now, but we might want a TODO to address this.
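A sketch of one way to pull the hardcoded string out: keep the tuned weight names in a single list (the keys below are copied from the diff, with the middle ones deliberately elided) and build the flag string from it. The constant name is an assumption.

WINOGRAD_CONV_KEYS = [
    "unet.down_blocks.2.resnets.0.conv2.weight",
    "unet.down_blocks.2.resnets.1.conv1.weight",
    # ... the remaining keys from the diff above ...
    "unet.up_blocks.0.upsamplers.0.conv.weight",
]
winograd_params = " ".join(f"keys={k}" for k in WINOGRAD_CONV_KEYS)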
@@ -93,6 +118,15 @@ def export_vae_model(
    upload_ir=False,
):
    mapper = {}
    decomp_list = DEFAULT_DECOMPOSITIONS
    decomp_list.extend(
Flag-guard these decomps too, as in the sketch above?
)
self.do_classifier_free_guidance = do_classifier_free_guidance

# self.tokenizer_1 = CLIPTokenizer.from_pretrained(
commented code
full_pipeline_file = (
    pipe_prefix + "f32" if args.precision == "fp32" else pipe_prefix + "f16"
)
# pipeline_vmfb_path = utils.compile_to_vmfb(
commented code
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    ]
)
# encoder_hidden_states_sizes = (2, 77, 768)
commented code
LGTM!