
Adds SDXL support and CI testing, benchmarks. #271

Merged: 182 commits into main from ean-sd-fp16 on Apr 11, 2024

Conversation

@monorimet (Contributor) commented on Dec 18, 2023:

No description provided.

@aviator19941 (Contributor) previously approved these changes on Dec 18, 2023:

LGTM

@dan-garvey (Member):

test failure seems real

@aviator19941 (Contributor):

@monorimet I think you need to update the run_unet function with guidance_scale in the definition, add it as an ireert.asdevicearray arg, and update the run_unet test to use guidance_scale instead of the run_torch_unet function.
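Something along these lines might work (a sketch only; the runner wiring and the "main" entry-point name are assumptions based on the existing unet_runner pattern, not code from this PR):

from iree import runtime as ireert
import numpy as np

def run_unet(device, sample, timestep, encoder_hidden_states, guidance_scale, runner):
    # guidance_scale now travels to the compiled module as a device array
    # alongside the other inputs instead of being baked into the export.
    inputs = [
        ireert.asdevicearray(runner.config.device, sample),
        ireert.asdevicearray(runner.config.device, timestep),
        ireert.asdevicearray(runner.config.device, encoder_hidden_states),
        # dtype is illustrative; match the model's precision.
        ireert.asdevicearray(runner.config.device, np.asarray([guidance_scale], dtype=np.float32)),
    ]
    return runner.ctx.modules.compiled_unet["main"](*inputs)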

@monorimet (Contributor, Author):

> @monorimet I think you need to update the run_unet function with guidance_scale in the definition, add it as an ireert.asdevicearray arg, and update the run_unet test to use guidance_scale instead of the run_torch_unet function.

Oops, thanks for taking a look. I'll fix it shortly.

@monorimet (Contributor, Author):

There's a correctness difference. I probably need to update the torch implementation. I'll come back to it tomorrow.

os.remove("stable_diffusion_v1_4_clip.safetensors")
os.remove("stable_diffusion_v1_4_clip.vmfb")
# os.remove(f"{arguments['safe_model_name']}_clip.safetensors")
# os.remove(f"{arguments['safe_model_name']}_clip.vmfb")
Reviewer (Contributor):

Why are these commented out?

@monorimet (Author):

There's some race condition between this cleanup and the actual test execution. I'll leave these comments unresolved until test file management is fixed; one possible fix is sketched below.
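One possible shape for the fix (a sketch; assumes the unittest-based tests own these artifact names): defer removal with addCleanup so each file is deleted only after the test that produced it finishes.

import os
import unittest

class StableDiffusionTest(unittest.TestCase):
    def _remove_artifacts(self, *paths):
        for path in paths:
            if os.path.exists(path):
                os.remove(path)

    def testExportClipModel(self):
        # ... export, compile, and run the CLIP module ...
        # addCleanup fires after this test completes, so the .vmfb is no
        # longer in use by the test when it is removed.
        self.addCleanup(
            self._remove_artifacts,
            "stable_diffusion_v1_4_clip.safetensors",
            "stable_diffusion_v1_4_clip.vmfb",
        )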

os.remove("stable_diffusion_v1_4_unet.safetensors")
os.remove("stable_diffusion_v1_4_unet.vmfb")
# os.remove(f"{arguments['safe_model_name']}_unet.safetensors")
# os.remove(f"{arguments['safe_model_name']}_unet.vmfb")
Reviewer (Contributor):

same as above

os.remove("stable_diffusion_v1_4_vae.safetensors")
os.remove("stable_diffusion_v1_4_vae.vmfb")
# os.remove(f"{arguments['safe_model_name']}_vae.safetensors")
# os.remove(f"{arguments['safe_model_name']}_vae.vmfb")
Reviewer (Contributor):

same

os.remove("stable_diffusion_v1_4_vae.safetensors")
os.remove("stable_diffusion_v1_4_vae.vmfb")
# os.remove(f"{arguments['safe_model_name']}_vae.safetensors")
# os.remove(f"{arguments['safe_model_name']}_vae.vmfb")
Reviewer (Contributor):

same



class StableDiffusionTest(unittest.TestCase):
# def testExportClipModel(self):
Reviewer (Contributor):

Lots of commented code here?

@@ -0,0 +1,190 @@
# Copyright 2023 Nod Labs, Inc
Reviewer (Contributor):

Looks like a lot of repeated code from sd_inference/unet.py. Can we combine the two?

@@ -0,0 +1,163 @@
import argparse
Reviewer (Contributor):

combine with other unet_runner?

@monorimet changed the title from "Add precision, max_length to unet, vae and guidance scale as input to unet." to "(WIP) Add some parameters to UNet export/tests + SDXL" on Jan 22, 2024
@monorimet force-pushed the ean-sd-fp16 branch 2 times, most recently from b1459fc to a0879c7 on February 9, 2024
@aviator19941 dismissed their stale review on February 14, 2024:

need to robustify without SDPA decomps and add SDXL scheduler examples

@monorimet force-pushed the ean-sd-fp16 branch 2 times, most recently from 4e2801f to 14fc107 on February 14, 2024
@monorimet changed the title from "(WIP) Add some parameters to UNet export/tests + SDXL" to "Adds SDXL support and CI testing, benchmarks." on Feb 23, 2024
@dan-garvey (Member) left a comment:

Didn't they fix SDPA torch lowering?

@monorimet (Author) commented on Feb 29, 2024:

> Didn't they fix SDPA torch lowering?

Yes and no.

With lowering to iree_linalg_ext.attention we have numerics issues on CPU after TileAndDecomposeAttention, plus a pending implementation for Vulkan and pending shared-memory tiling on ROCm -- see iree-org/iree#16421.

So the torch->linalg lowering is OK, but the HAL backends need better tile-and-decompose support.

@monorimet (Contributor, Author):

I think we need commit 30ef1fc to be ported over to the upstream fx importer.
@aviator19941 / @PhaneeshB, can you help with this?

@monorimet force-pushed the ean-sd-fp16 branch 2 times, most recently from fa6ba50 to 89274eb on April 9, 2024
@monorimet force-pushed the ean-sd-fp16 branch 3 times, most recently from 80fe4ce to df1002e on April 9, 2024
@IanNod (Contributor) left a comment:

There seems to be a lot of repeated code between SD and SDXL. It doesn't need to be done now, but can you add TODOs or create an issue to work on combining the two to get rid of that? (A rough sketch follows.)
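One possible shape for that combining (purely illustrative; the dataclass and helper here are hypothetical, not code from this PR):

from dataclasses import dataclass

# Hypothetical: describe each pipeline variant as data so sd_inference
# and sdxl_inference can share one export/runner code path.
@dataclass
class UnetVariant:
    hf_model_name: str
    cross_attention_dim: int  # 768 for SD 1.x, 2048 for SDXL
    is_sdxl: bool

SD14 = UnetVariant("CompVis/stable-diffusion-v1-4", 768, is_sdxl=False)
SDXL = UnetVariant("stabilityai/stable-diffusion-xl-base-1.0", 2048, is_sdxl=True)

def encoder_hidden_states_shape(variant: UnetVariant, batch: int, max_length: int):
    # The text-embedding width is the main shape difference at this level;
    # SDXL's extra conditioning (time_ids, text_embeds) would hang off
    # is_sdxl in a shared exporter.
    return (batch, max_length, variant.cross_attention_dim)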

.github/workflows/test_models.yml (review thread outdated and resolved)
@@ -214,6 +214,13 @@ def flat_wrapped_f(*args):
if "functorch_functionalize" in self._passes:
transformed_f = functorch_functionalize(transformed_f, *flat_pytorch_args)

for node in transformed_f.graph.nodes: # type: ignore
Reviewer (Contributor):

isn't this being done in fx_importer.py? if we need it here instead should we remove it there?

@monorimet (Author):

Seems to be covered by torch-mlir's fx importer. Removing it here.

for node in transformed_f.graph.nodes: # type: ignore
if node.op == "call_function":
if node.target == torch._ops.ops.aten.lift_fresh_copy.default:
print(f"replaced lift_fresh_copy")
Reviewer (Contributor):

remove printout?


decomp_list = DEFAULT_DECOMPOSITIONS

decomp_list.extend(
Reviewer (Contributor):

Shouldn't the SDPA decomp be flag-guarded?
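Something like this sketch, perhaps (the decompose_attn flag name is hypothetical; DEFAULT_DECOMPOSITIONS is assumed in scope as in the diff above):

import torch

def build_decomp_list(decompose_attn: bool = False):
    # Copy first so the shared DEFAULT_DECOMPOSITIONS list is never
    # mutated in place across exports.
    decomp_list = list(DEFAULT_DECOMPOSITIONS)
    if decompose_attn:
        decomp_list.extend(
            [torch.ops.aten._scaled_dot_product_flash_attention.default]
        )
    return decomp_list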

torch.ops.aten._scaled_dot_product_flash_attention.default,
]
)
# encoder_hidden_states_sizes = (2, 77, 768)
Reviewer (Contributor):

remove commented code?

)

winograd_params = "keys=unet.down_blocks.2.resnets.0.conv2.weight keys=unet.down_blocks.2.resnets.1.conv1.weight keys=unet.down_blocks.2.resnets.1.conv2.weight keys=unet.mid_block.resnets.0.conv1.weight keys=unet.mid_block.resnets.0.conv2.weight keys=unet.mid_block.resnets.1.conv1.weight keys=unet.mid_block.resnets.1.conv2.weight keys=unet.up_blocks.0.resnets.0.conv2.weight keys=unet.up_blocks.0.resnets.1.conv2.weight keys=unet.up_blocks.0.resnets.2.conv2.weight keys=unet.up_blocks.0.resnets.0.conv1.weight keys=unet.up_blocks.0.resnets.1.conv1.weight keys=unet.up_blocks.0.resnets.2.conv1.weight keys=unet.up_blocks.0.upsamplers.0.conv.weight"
Reviewer (Contributor):

We probably shouldn't have all of these winograd_params and flags hardcoded. It's probably fine for now, but we might want a TODO to address this; one possible direction is sketched below.
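One possible direction for the TODO (illustrative only; the names here are hypothetical):

# Hypothetical: keep the Winograd-eligible weight keys as data and
# rebuild the compiler flag string on demand instead of hardcoding it.
WINOGRAD_CONV_KEYS = [
    "unet.down_blocks.2.resnets.0.conv2.weight",
    "unet.down_blocks.2.resnets.1.conv1.weight",
    # ... remaining keys from the hardcoded string ...
    "unet.up_blocks.0.upsamplers.0.conv.weight",
]

def winograd_params() -> str:
    return " ".join(f"keys={k}" for k in WINOGRAD_CONV_KEYS)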

@@ -93,6 +118,15 @@ def export_vae_model(
upload_ir=False,
):
mapper = {}
decomp_list = DEFAULT_DECOMPOSITIONS
decomp_list.extend(
Reviewer (Contributor):

flag guard these decomps?

)
self.do_classifier_free_guidance = do_classifier_free_guidance

# self.tokenizer_1 = CLIPTokenizer.from_pretrained(
Reviewer (Contributor):

commented code

full_pipeline_file = (
pipe_prefix + "f32" if args.precision == "fp32" else pipe_prefix + "f16"
)
# pipeline_vmfb_path = utils.compile_to_vmfb(
Reviewer (Contributor):

commented code

torch.ops.aten._scaled_dot_product_flash_attention.default,
]
)
# encoder_hidden_states_sizes = (2, 77, 768)
Reviewer (Contributor):

commented code

@IanNod (Contributor) left a comment:

LGTM!

@monorimet merged commit 1dea19e into main on Apr 11, 2024
8 checks passed
@monorimet deleted the ean-sd-fp16 branch on April 11, 2024 20:13